diff --git a/SConstruct b/SConstruct index e866a4c998..0c91399a12 100644 --- a/SConstruct +++ b/SConstruct @@ -598,6 +598,13 @@ if selected_platform in platform_list: ) } ) + env.Append( + BUILDERS={ + "GLSL_HEADER": env.Builder( + action=run_in_subprocess(gles_builders.build_raw_headers), suffix="glsl.gen.h", src_suffix=".glsl" + ) + } + ) scons_cache_path = os.environ.get("SCONS_CACHE") if scons_cache_path != None: diff --git a/core/error_macros.h b/core/error_macros.h index 18c46c9e7d..83f92129a5 100644 --- a/core/error_macros.h +++ b/core/error_macros.h @@ -530,19 +530,19 @@ void _err_print_index_error(const char *p_function, const char *p_file, int p_li * Prints `m_msg`. */ #define ERR_PRINT(m_msg) \ - _err_print_error(FUNCTION_STR, __FILE__, __LINE__, DEBUG_STR(m_msg)) + _err_print_error(FUNCTION_STR, __FILE__, __LINE__, m_msg) /** * Prints `m_msg` once during the application lifetime. */ -#define ERR_PRINT_ONCE(m_msg) \ - if (1) { \ - static bool first_print = true; \ - if (first_print) { \ - _err_print_error(FUNCTION_STR, __FILE__, __LINE__, DEBUG_STR(m_msg)); \ - first_print = false; \ - } \ - } else \ +#define ERR_PRINT_ONCE(m_msg) \ + if (1) { \ + static bool first_print = true; \ + if (first_print) { \ + _err_print_error(FUNCTION_STR, __FILE__, __LINE__, m_msg); \ + first_print = false; \ + } \ + } else \ ((void)0) // Print warning message macros. @@ -553,21 +553,21 @@ void _err_print_index_error(const char *p_function, const char *p_file, int p_li * If warning about deprecated usage, use `WARN_DEPRECATED` or `WARN_DEPRECATED_MSG` instead. */ #define WARN_PRINT(m_msg) \ - _err_print_error(FUNCTION_STR, __FILE__, __LINE__, DEBUG_STR(m_msg), ERR_HANDLER_WARNING) + _err_print_error(FUNCTION_STR, __FILE__, __LINE__, m_msg, ERR_HANDLER_WARNING) /** * Prints `m_msg` once during the application lifetime. * * If warning about deprecated usage, use `WARN_DEPRECATED` or `WARN_DEPRECATED_MSG` instead. */ -#define WARN_PRINT_ONCE(m_msg) \ - if (1) { \ - static bool first_print = true; \ - if (first_print) { \ - _err_print_error(FUNCTION_STR, __FILE__, __LINE__, DEBUG_STR(m_msg), ERR_HANDLER_WARNING); \ - first_print = false; \ - } \ - } else \ +#define WARN_PRINT_ONCE(m_msg) \ + if (1) { \ + static bool first_print = true; \ + if (first_print) { \ + _err_print_error(FUNCTION_STR, __FILE__, __LINE__, m_msg, ERR_HANDLER_WARNING); \ + first_print = false; \ + } \ + } else \ ((void)0) // Print deprecated warning message macros. diff --git a/core/image.cpp b/core/image.cpp index 6f18516ae1..ff8acc54af 100644 --- a/core/image.cpp +++ b/core/image.cpp @@ -3668,6 +3668,10 @@ Ref Image::duplicate(bool p_subresources) const { return copy; } +void Image::set_as_black() { + zeromem(data.ptrw(), data.size()); +} + Image::Image() { width = 0; diff --git a/core/image.h b/core/image.h index 5bd73fa677..07b4f49751 100644 --- a/core/image.h +++ b/core/image.h @@ -376,6 +376,8 @@ public: void set_pixelv(const Point2 &p_dst, const Color &p_color); void set_pixel(int p_x, int p_y, const Color &p_color); + void set_as_black(); + void copy_internals_from(const Ref &p_image) { ERR_FAIL_COND_MSG(p_image.is_null(), "It's not a reference to a valid Image object."); format = p_image->format; diff --git a/core/io/resource_format_binary.cpp b/core/io/resource_format_binary.cpp index 8c7559479b..e0fea143bb 100644 --- a/core/io/resource_format_binary.cpp +++ b/core/io/resource_format_binary.cpp @@ -337,10 +337,14 @@ Error ResourceLoaderBinary::parse_variant(Variant &r_v) { } break; case OBJECT_INTERNAL_RESOURCE: { uint32_t index = f->get_32(); + String path = res_path + "::" + itos(index); + if (use_nocache) { - r_v = internal_resources[index].cache; + if (!internal_index_cache.has(path)) { + WARN_PRINT(String("Couldn't load resource (no cache): " + path).utf8().get_data()); + } + r_v = internal_index_cache[path]; } else { - String path = res_path + "::" + itos(index); RES res = ResourceLoader::load(path); if (res.is_null()) { WARN_PRINT(String("Couldn't load resource: " + path).utf8().get_data()); @@ -720,13 +724,15 @@ Error ResourceLoaderBinary::load() { if (!main) { + path = internal_resources[i].path; + + if (path.begins_with("local://")) { + path = path.replace_first("local://", ""); + subindex = path.to_int(); + path = res_path + "::" + path; + } + if (!use_nocache) { - path = internal_resources[i].path; - if (path.begins_with("local://")) { - path = path.replace_first("local://", ""); - subindex = path.to_int(); - path = res_path + "::" + path; - } if (ResourceCache::has(path)) { //already loaded, don't do anything @@ -769,7 +775,7 @@ Error ResourceLoaderBinary::load() { r->set_subindex(subindex); if (!main) { - internal_resources.write[i].cache = res; + internal_index_cache[path] = res; } int pc = f->get_32(); diff --git a/core/io/resource_format_binary.h b/core/io/resource_format_binary.h index 0f8fc9445b..3c8d916c0a 100644 --- a/core/io/resource_format_binary.h +++ b/core/io/resource_format_binary.h @@ -68,10 +68,10 @@ class ResourceLoaderBinary { struct IntResource { String path; uint64_t offset; - RES cache; }; Vector internal_resources; + Map internal_index_cache; String get_unicode_string(); void _advance_padding(uint32_t p_len); diff --git a/core/io/resource_importer.cpp b/core/io/resource_importer.cpp index 643df53f8c..9e22bdced7 100644 --- a/core/io/resource_importer.cpp +++ b/core/io/resource_importer.cpp @@ -91,7 +91,7 @@ Error ResourceFormatImporter::_get_path_and_type(const String &p_path, PathAndTy r_path_and_type.path = value; path_found = true; //first match must have priority } else if (assign == "type") { - r_path_and_type.type = value; + r_path_and_type.type = ClassDB::get_compatibility_remapped_class(value); } else if (assign == "importer") { r_path_and_type.importer = value; } else if (assign == "group_file") { diff --git a/core/local_vector.h b/core/local_vector.h new file mode 100644 index 0000000000..0b0ef6dfdc --- /dev/null +++ b/core/local_vector.h @@ -0,0 +1,246 @@ +/*************************************************************************/ +/* local_vector.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#ifndef LOCAL_VECTOR_H +#define LOCAL_VECTOR_H + +#include "core/error_macros.h" +#include "core/os/copymem.h" +#include "core/os/memory.h" +#include "core/sort_array.h" +#include "core/vector.h" + +template +class LocalVector { +private: + U count = 0; + U capacity = 0; + T *data = nullptr; + +public: + _FORCE_INLINE_ void push_back(T p_elem) { + if (unlikely(count == capacity)) { + if (capacity == 0) { + capacity = 1; + } else { + capacity <<= 1; + } + data = (T *)memrealloc(data, capacity * sizeof(T)); + CRASH_COND_MSG(!data, "Out of memory"); + } + + if (!__has_trivial_constructor(T) && !force_trivial) { + memnew_placement(&data[count++], T(p_elem)); + } else { + data[count++] = p_elem; + } + } + + void remove(U p_index) { + ERR_FAIL_UNSIGNED_INDEX(p_index, count); + for (U i = p_index; i < count; i++) { + data[i] = data[i + 1]; + } + count--; + if (!__has_trivial_destructor(T) && !force_trivial) { + data[count].~T(); + } + } + + void erase(const T &p_val) { + U idx = find(p_val); + if (idx >= 0) + remove(idx); + } + + void invert() { + + for (U i = 0; i < count / 2; i++) { + SWAP(data[i], data[count - i - 1]); + } + } + + _FORCE_INLINE_ void clear() { resize(0); } + _FORCE_INLINE_ void reset() { + clear(); + if (data) { + memfree(data); + data = nullptr; + capacity = 0; + } + } + _FORCE_INLINE_ bool empty() const { return count == 0; } + _FORCE_INLINE_ void reserve(U p_size) { + p_size = nearest_power_of_2_templated(p_size); + if (p_size > capacity) { + capacity = p_size; + data = (T *)memrealloc(data, capacity * sizeof(T)); + CRASH_COND_MSG(!data, "Out of memory"); + } + } + + _FORCE_INLINE_ U size() const { return count; } + void resize(U p_size) { + + if (p_size < count) { + + if (!__has_trivial_destructor(T) && !force_trivial) { + for (U i = p_size; i < count; i++) { + data[i].~T(); + } + } + count = p_size; + } else if (p_size > count) { + + if (unlikely(p_size > capacity)) { + if (capacity == 0) { + capacity = 1; + } + while (capacity < p_size) { + capacity <<= 1; + } + data = (T *)memrealloc(data, capacity * sizeof(T)); + CRASH_COND_MSG(!data, "Out of memory"); + } + if (!__has_trivial_constructor(T) && !force_trivial) { + for (U i = count; i < p_size; i++) { + memnew_placement(&data[i], T); + } + } + count = p_size; + } + } + _FORCE_INLINE_ const T &operator[](U p_index) const { + CRASH_BAD_UNSIGNED_INDEX(p_index, count); + return data[p_index]; + } + _FORCE_INLINE_ T &operator[](U p_index) { + CRASH_BAD_UNSIGNED_INDEX(p_index, count); + return data[p_index]; + } + + void insert(U p_pos, T p_val) { + ERR_FAIL_UNSIGNED_INDEX(p_pos, count + 1); + if (p_pos == count) { + push_back(p_val); + } else { + resize(count + 1); + for (U i = count; i > p_pos; i--) { + data[i] = data[i - 1]; + } + data[p_pos] = p_val; + } + } + + int64_t find(const T &p_val, U p_from = 0) const { + + for (U i = 0; i < count; i++) { + if (data[i] == p_val) { + return int64_t(i); + } + } + return -1; + } + + template + void sort_custom() { + + U len = count; + if (len == 0) + return; + + SortArray sorter; + sorter.sort(data, len); + } + + void sort() { + + sort_custom<_DefaultComparator>(); + } + + void ordered_insert(T p_val) { + + U i; + for (i = 0; i < count; i++) { + + if (p_val < data[i]) { + break; + }; + }; + insert(i, p_val); + } + + operator Vector() const { + Vector ret; + ret.resize(size()); + T *w = ret.ptrw(); + copymem(w, data, sizeof(T) * count); + return ret; + } + + Vector to_byte_array() const { //useful to pass stuff to gpu or variant + Vector ret; + ret.resize(count * sizeof(T)); + uint8_t *w = ret.ptrw(); + copymem(w, data, sizeof(T) * count); + return ret; + } + + _FORCE_INLINE_ LocalVector() {} + _FORCE_INLINE_ LocalVector(const LocalVector &p_from) { + resize(p_from.size()); + for (U i = 0; i < p_from.count; i++) { + data[i] = p_from.data[i]; + } + } + inline LocalVector &operator=(const LocalVector &p_from) { + resize(p_from.size()); + for (U i = 0; i < p_from.count; i++) { + data[i] = p_from.data[i]; + } + return *this; + } + inline LocalVector &operator=(const Vector &p_from) { + resize(p_from.size()); + for (U i = 0; i < count; i++) { + data[i] = p_from[i]; + } + return *this; + } + + _FORCE_INLINE_ ~LocalVector() { + + if (data) { + reset(); + } + } +}; + +#endif // LOCAL_VECTOR_H diff --git a/core/math/basis.cpp b/core/math/basis.cpp index 87abf2dbc1..6218b7e248 100644 --- a/core/math/basis.cpp +++ b/core/math/basis.cpp @@ -878,3 +878,114 @@ Basis Basis::slerp(const Basis &target, const real_t &t) const { return b; } + +void Basis::rotate_sh(real_t *p_values) { + + // code by John Hable + // http://filmicworlds.com/blog/simple-and-fast-spherical-harmonic-rotation/ + // this code is Public Domain + + const static real_t s_c3 = 0.94617469575; // (3*sqrt(5))/(4*sqrt(pi)) + const static real_t s_c4 = -0.31539156525; // (-sqrt(5))/(4*sqrt(pi)) + const static real_t s_c5 = 0.54627421529; // (sqrt(15))/(4*sqrt(pi)) + + const static real_t s_c_scale = 1.0 / 0.91529123286551084; + const static real_t s_c_scale_inv = 0.91529123286551084; + + const static real_t s_rc2 = 1.5853309190550713 * s_c_scale; + const static real_t s_c4_div_c3 = s_c4 / s_c3; + const static real_t s_c4_div_c3_x2 = (s_c4 / s_c3) * 2.0; + + const static real_t s_scale_dst2 = s_c3 * s_c_scale_inv; + const static real_t s_scale_dst4 = s_c5 * s_c_scale_inv; + + real_t src[9] = { p_values[0], p_values[1], p_values[2], p_values[3], p_values[4], p_values[5], p_values[6], p_values[7], p_values[8] }; + + real_t m00 = elements[0][0]; + real_t m01 = elements[0][1]; + real_t m02 = elements[0][2]; + real_t m10 = elements[1][0]; + real_t m11 = elements[1][1]; + real_t m12 = elements[1][2]; + real_t m20 = elements[2][0]; + real_t m21 = elements[2][1]; + real_t m22 = elements[2][2]; + + p_values[0] = src[0]; + p_values[1] = m11 * src[1] - m12 * src[2] + m10 * src[3]; + p_values[2] = -m21 * src[1] + m22 * src[2] - m20 * src[3]; + p_values[3] = m01 * src[1] - m02 * src[2] + m00 * src[3]; + + real_t sh0 = src[7] + src[8] + src[8] - src[5]; + real_t sh1 = src[4] + s_rc2 * src[6] + src[7] + src[8]; + real_t sh2 = src[4]; + real_t sh3 = -src[7]; + real_t sh4 = -src[5]; + + // Rotations. R0 and R1 just use the raw matrix columns + real_t r2x = m00 + m01; + real_t r2y = m10 + m11; + real_t r2z = m20 + m21; + + real_t r3x = m00 + m02; + real_t r3y = m10 + m12; + real_t r3z = m20 + m22; + + real_t r4x = m01 + m02; + real_t r4y = m11 + m12; + real_t r4z = m21 + m22; + + // dense matrix multiplication one column at a time + + // column 0 + real_t sh0_x = sh0 * m00; + real_t sh0_y = sh0 * m10; + real_t d0 = sh0_x * m10; + real_t d1 = sh0_y * m20; + real_t d2 = sh0 * (m20 * m20 + s_c4_div_c3); + real_t d3 = sh0_x * m20; + real_t d4 = sh0_x * m00 - sh0_y * m10; + + // column 1 + real_t sh1_x = sh1 * m02; + real_t sh1_y = sh1 * m12; + d0 += sh1_x * m12; + d1 += sh1_y * m22; + d2 += sh1 * (m22 * m22 + s_c4_div_c3); + d3 += sh1_x * m22; + d4 += sh1_x * m02 - sh1_y * m12; + + // column 2 + real_t sh2_x = sh2 * r2x; + real_t sh2_y = sh2 * r2y; + d0 += sh2_x * r2y; + d1 += sh2_y * r2z; + d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2); + d3 += sh2_x * r2z; + d4 += sh2_x * r2x - sh2_y * r2y; + + // column 3 + real_t sh3_x = sh3 * r3x; + real_t sh3_y = sh3 * r3y; + d0 += sh3_x * r3y; + d1 += sh3_y * r3z; + d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2); + d3 += sh3_x * r3z; + d4 += sh3_x * r3x - sh3_y * r3y; + + // column 4 + real_t sh4_x = sh4 * r4x; + real_t sh4_y = sh4 * r4y; + d0 += sh4_x * r4y; + d1 += sh4_y * r4z; + d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2); + d3 += sh4_x * r4z; + d4 += sh4_x * r4x - sh4_y * r4y; + + // extra multipliers + p_values[4] = d0; + p_values[5] = -d1; + p_values[6] = d2 * s_scale_dst2; + p_values[7] = -d3; + p_values[8] = d4 * s_scale_dst4; +} diff --git a/core/math/basis.h b/core/math/basis.h index 0261cf67c6..2924a0ddbd 100644 --- a/core/math/basis.h +++ b/core/math/basis.h @@ -159,6 +159,7 @@ public: bool is_rotation() const; Basis slerp(const Basis &target, const real_t &t) const; + void rotate_sh(real_t *p_values); operator String() const; diff --git a/core/math/camera_matrix.cpp b/core/math/camera_matrix.cpp index 76321b0679..5d3ebc9f6d 100644 --- a/core/math/camera_matrix.cpp +++ b/core/math/camera_matrix.cpp @@ -33,6 +33,22 @@ #include "core/math/math_funcs.h" #include "core/print_string.h" +float CameraMatrix::determinant() const { + + return matrix[0][3] * matrix[1][2] * matrix[2][1] * matrix[3][0] - matrix[0][2] * matrix[1][3] * matrix[2][1] * matrix[3][0] - + matrix[0][3] * matrix[1][1] * matrix[2][2] * matrix[3][0] + matrix[0][1] * matrix[1][3] * matrix[2][2] * matrix[3][0] + + matrix[0][2] * matrix[1][1] * matrix[2][3] * matrix[3][0] - matrix[0][1] * matrix[1][2] * matrix[2][3] * matrix[3][0] - + matrix[0][3] * matrix[1][2] * matrix[2][0] * matrix[3][1] + matrix[0][2] * matrix[1][3] * matrix[2][0] * matrix[3][1] + + matrix[0][3] * matrix[1][0] * matrix[2][2] * matrix[3][1] - matrix[0][0] * matrix[1][3] * matrix[2][2] * matrix[3][1] - + matrix[0][2] * matrix[1][0] * matrix[2][3] * matrix[3][1] + matrix[0][0] * matrix[1][2] * matrix[2][3] * matrix[3][1] + + matrix[0][3] * matrix[1][1] * matrix[2][0] * matrix[3][2] - matrix[0][1] * matrix[1][3] * matrix[2][0] * matrix[3][2] - + matrix[0][3] * matrix[1][0] * matrix[2][1] * matrix[3][2] + matrix[0][0] * matrix[1][3] * matrix[2][1] * matrix[3][2] + + matrix[0][1] * matrix[1][0] * matrix[2][3] * matrix[3][2] - matrix[0][0] * matrix[1][1] * matrix[2][3] * matrix[3][2] - + matrix[0][2] * matrix[1][1] * matrix[2][0] * matrix[3][3] + matrix[0][1] * matrix[1][2] * matrix[2][0] * matrix[3][3] + + matrix[0][2] * matrix[1][0] * matrix[2][1] * matrix[3][3] - matrix[0][0] * matrix[1][2] * matrix[2][1] * matrix[3][3] - + matrix[0][1] * matrix[1][0] * matrix[2][2] * matrix[3][3] + matrix[0][0] * matrix[1][1] * matrix[2][2] * matrix[3][3]; +} + void CameraMatrix::set_identity() { for (int i = 0; i < 4; i++) { diff --git a/core/math/camera_matrix.h b/core/math/camera_matrix.h index c10193bc84..5420fa2984 100644 --- a/core/math/camera_matrix.h +++ b/core/math/camera_matrix.h @@ -47,6 +47,7 @@ struct CameraMatrix { real_t matrix[4][4]; + float determinant() const; void set_identity(); void set_zero(); void set_light_bias(); diff --git a/core/math/delaunay.h b/core/math/delaunay.h index 29f84210d2..6f19f3e58a 100644 --- a/core/math/delaunay.h +++ b/core/math/delaunay.h @@ -115,8 +115,6 @@ public: triangles.push_back(Triangle(p_points.size() + 0, p_points.size() + 1, p_points.size() + 2)); for (int i = 0; i < p_points.size(); i++) { - //std::cout << "Traitement du point " << *p << std::endl; - //std::cout << "_triangles contains " << _triangles.size() << " elements" << std::endl; Vector polygon; diff --git a/core/math/delaunay_3d.h b/core/math/delaunay_3d.h new file mode 100644 index 0000000000..6280ec8071 --- /dev/null +++ b/core/math/delaunay_3d.h @@ -0,0 +1,386 @@ +#ifndef DELAUNAY_3D_H +#define DELAUNAY_3D_H + +#include "core/local_vector.h" +#include "core/math/aabb.h" +#include "core/math/camera_matrix.h" +#include "core/math/vector3.h" +#include "core/oa_hash_map.h" +#include "core/os/file_access.h" +#include "core/print_string.h" +#include "core/variant.h" +#include "core/vector.h" +#include "thirdparty/r128/r128.h" + +class Delaunay3D { + struct Simplex; + + enum { + ACCEL_GRID_SIZE = 16 + }; + struct GridPos { + Vector3i pos; + List::Element *E = nullptr; + }; + + struct Simplex { + + uint32_t points[4]; + R128 circum_center_x; + R128 circum_center_y; + R128 circum_center_z; + R128 circum_r2; + LocalVector grid_positions; + List::Element *SE = nullptr; + + _FORCE_INLINE_ Simplex() {} + _FORCE_INLINE_ Simplex(uint32_t p_a, uint32_t p_b, uint32_t p_c, uint32_t p_d) { + points[0] = p_a; + points[1] = p_b; + points[2] = p_c; + points[3] = p_d; + } + }; + + struct Triangle { + uint32_t triangle[3]; + bool bad; + _FORCE_INLINE_ bool operator==(const Triangle &p_triangle) const { + return triangle[0] == p_triangle.triangle[0] && triangle[1] == p_triangle.triangle[1] && triangle[2] == p_triangle.triangle[2]; + } + + _FORCE_INLINE_ Triangle() { bad = false; } + _FORCE_INLINE_ Triangle(uint32_t p_a, uint32_t p_b, uint32_t p_c) { + if (p_a > p_b) + SWAP(p_a, p_b); + if (p_b > p_c) + SWAP(p_b, p_c); + if (p_a > p_b) + SWAP(p_a, p_b); + + bad = false; + triangle[0] = p_a; + triangle[1] = p_b; + triangle[2] = p_c; + } + }; + + struct TriangleHasher { + _FORCE_INLINE_ static uint32_t hash(const Triangle &p_triangle) { + uint32_t h = hash_djb2_one_32(p_triangle.triangle[0]); + h = hash_djb2_one_32(p_triangle.triangle[1], h); + return hash_djb2_one_32(p_triangle.triangle[2], h); + } + }; + + struct FPVal { + }; + + _FORCE_INLINE_ static void circum_sphere_compute(const Vector3 *p_points, Simplex *p_simplex) { + + // the only part in the algorithm where there may be precision errors is this one, so ensure that + // we do it as maximum precision as possible + + R128 v0_x = p_points[p_simplex->points[0]].x; + R128 v0_y = p_points[p_simplex->points[0]].y; + R128 v0_z = p_points[p_simplex->points[0]].z; + R128 v1_x = p_points[p_simplex->points[1]].x; + R128 v1_y = p_points[p_simplex->points[1]].y; + R128 v1_z = p_points[p_simplex->points[1]].z; + R128 v2_x = p_points[p_simplex->points[2]].x; + R128 v2_y = p_points[p_simplex->points[2]].y; + R128 v2_z = p_points[p_simplex->points[2]].z; + R128 v3_x = p_points[p_simplex->points[3]].x; + R128 v3_y = p_points[p_simplex->points[3]].y; + R128 v3_z = p_points[p_simplex->points[3]].z; + + //Create the rows of our "unrolled" 3x3 matrix + R128 row1_x = v1_x - v0_x; + R128 row1_y = v1_y - v0_y; + R128 row1_z = v1_z - v0_z; + + R128 row2_x = v2_x - v0_x; + R128 row2_y = v2_y - v0_y; + R128 row2_z = v2_z - v0_z; + + R128 row3_x = v3_x - v0_x; + R128 row3_y = v3_y - v0_y; + R128 row3_z = v3_z - v0_z; + + R128 sq_lenght1 = row1_x * row1_x + row1_y * row1_y + row1_z * row1_z; + R128 sq_lenght2 = row2_x * row2_x + row2_y * row2_y + row2_z * row2_z; + R128 sq_lenght3 = row3_x * row3_x + row3_y * row3_y + row3_z * row3_z; + + //Compute the determinant of said matrix + R128 determinant = row1_x * (row2_y * row3_z - row3_y * row2_z) - row2_x * (row1_y * row3_z - row3_y * row1_z) + row3_x * (row1_y * row2_z - row2_y * row1_z); + + // Compute the volume of the tetrahedron, and precompute a scalar quantity for re-use in the formula + R128 volume = determinant / R128(6.f); + R128 i12volume = R128(1.f) / (volume * R128(12.f)); + + R128 center_x = v0_x + i12volume * ((row2_y * row3_z - row3_y * row2_z) * sq_lenght1 - (row1_y * row3_z - row3_y * row1_z) * sq_lenght2 + (row1_y * row2_z - row2_y * row1_z) * sq_lenght3); + R128 center_y = v0_y + i12volume * (-(row2_x * row3_z - row3_x * row2_z) * sq_lenght1 + (row1_x * row3_z - row3_x * row1_z) * sq_lenght2 - (row1_x * row2_z - row2_x * row1_z) * sq_lenght3); + R128 center_z = v0_z + i12volume * ((row2_x * row3_y - row3_x * row2_y) * sq_lenght1 - (row1_x * row3_y - row3_x * row1_y) * sq_lenght2 + (row1_x * row2_y - row2_x * row1_y) * sq_lenght3); + + //Once we know the center, the radius is clearly the distance to any vertex + + R128 rel1_x = center_x - v0_x; + R128 rel1_y = center_y - v0_y; + R128 rel1_z = center_z - v0_z; + + R128 radius1 = rel1_x * rel1_x + rel1_y * rel1_y + rel1_z * rel1_z; + + p_simplex->circum_center_x = center_x; + p_simplex->circum_center_y = center_y; + p_simplex->circum_center_z = center_z; + p_simplex->circum_r2 = radius1; + } + + _FORCE_INLINE_ static bool simplex_contains(const Vector3 *p_points, const Simplex &p_simplex, uint32_t p_vertex) { + + R128 v_x = p_points[p_vertex].x; + R128 v_y = p_points[p_vertex].y; + R128 v_z = p_points[p_vertex].z; + + R128 rel2_x = p_simplex.circum_center_x - v_x; + R128 rel2_y = p_simplex.circum_center_y - v_y; + R128 rel2_z = p_simplex.circum_center_z - v_z; + + R128 radius2 = rel2_x * rel2_x + rel2_y * rel2_y + rel2_z * rel2_z; + + return radius2 < (p_simplex.circum_r2 - R128(0.00001)); + } + + static bool simplex_is_coplanar(const Vector3 *p_points, const Simplex &p_simplex) { + + Plane p(p_points[p_simplex.points[0]], p_points[p_simplex.points[1]], p_points[p_simplex.points[2]]); + if (ABS(p.distance_to(p_points[p_simplex.points[3]])) < CMP_EPSILON) { + return true; + } + + CameraMatrix cm; + + cm.matrix[0][0] = p_points[p_simplex.points[0]].x; + cm.matrix[0][1] = p_points[p_simplex.points[1]].x; + cm.matrix[0][2] = p_points[p_simplex.points[2]].x; + cm.matrix[0][3] = p_points[p_simplex.points[3]].x; + + cm.matrix[1][0] = p_points[p_simplex.points[0]].y; + cm.matrix[1][1] = p_points[p_simplex.points[1]].y; + cm.matrix[1][2] = p_points[p_simplex.points[2]].y; + cm.matrix[1][3] = p_points[p_simplex.points[3]].y; + + cm.matrix[2][0] = p_points[p_simplex.points[0]].z; + cm.matrix[2][1] = p_points[p_simplex.points[1]].z; + cm.matrix[2][2] = p_points[p_simplex.points[2]].z; + cm.matrix[2][3] = p_points[p_simplex.points[3]].z; + + cm.matrix[3][0] = 1.0; + cm.matrix[3][1] = 1.0; + cm.matrix[3][2] = 1.0; + cm.matrix[3][3] = 1.0; + + return ABS(cm.determinant()) <= CMP_EPSILON; + } + +public: + struct OutputSimplex { + uint32_t points[4]; + }; + + static Vector tetrahedralize(const Vector &p_points) { + + uint32_t point_count = p_points.size(); + Vector3 *points = (Vector3 *)memalloc(sizeof(Vector3) * (point_count + 4)); + + { + const Vector3 *src_points = p_points.ptr(); + AABB rect; + for (uint32_t i = 0; i < point_count; i++) { + Vector3 point = src_points[i]; + if (i == 0) { + rect.position = point; + } else { + rect.expand_to(point); + } + points[i] = point; + } + + for (uint32_t i = 0; i < point_count; i++) { + points[i] = (points[i] - rect.position) / rect.size; + } + + float delta_max = Math::sqrt(2.0) * 20.0; + Vector3 center = Vector3(0.5, 0.5, 0.5); + + // any simplex that contains everything is good + points[point_count + 0] = center + Vector3(0, 1, 0) * delta_max; + points[point_count + 1] = center + Vector3(0, -1, 1) * delta_max; + points[point_count + 2] = center + Vector3(1, -1, -1) * delta_max; + points[point_count + 3] = center + Vector3(-1, -1, -1) * delta_max; + } + + List acceleration_grid[ACCEL_GRID_SIZE][ACCEL_GRID_SIZE][ACCEL_GRID_SIZE]; + + List simplex_list; + { + //create root simplex + Simplex *root = memnew(Simplex(point_count + 0, point_count + 1, point_count + 2, point_count + 3)); + root->SE = simplex_list.push_back(root); + + for (uint32_t i = 0; i < ACCEL_GRID_SIZE; i++) { + for (uint32_t j = 0; j < ACCEL_GRID_SIZE; j++) { + for (uint32_t k = 0; k < ACCEL_GRID_SIZE; k++) { + GridPos gp; + gp.E = acceleration_grid[i][j][k].push_back(root); + gp.pos = Vector3i(i, j, k); + root->grid_positions.push_back(gp); + } + } + } + + circum_sphere_compute(points, root); + } + + OAHashMap triangles_inserted; + LocalVector triangles; + + for (uint32_t i = 0; i < point_count; i++) { + + bool unique = true; + for (uint32_t j = i + 1; j < point_count; j++) { + if (points[i].is_equal_approx(points[j])) { + unique = false; + break; + } + } + if (!unique) { + continue; + } + + Vector3i grid_pos = Vector3i(points[i] * ACCEL_GRID_SIZE); + grid_pos.x = CLAMP(grid_pos.x, 0, ACCEL_GRID_SIZE - 1); + grid_pos.y = CLAMP(grid_pos.y, 0, ACCEL_GRID_SIZE - 1); + grid_pos.z = CLAMP(grid_pos.z, 0, ACCEL_GRID_SIZE - 1); + + for (List::Element *E = acceleration_grid[grid_pos.x][grid_pos.y][grid_pos.z].front(); E;) { + List::Element *N = E->next(); //may be deleted + + Simplex *simplex = E->get(); + + if (simplex_contains(points, *simplex, i)) { + + static const uint32_t triangle_order[4][3] = { + { 0, 1, 2 }, + { 0, 1, 3 }, + { 0, 2, 3 }, + { 1, 2, 3 }, + }; + + for (uint32_t k = 0; k < 4; k++) { + Triangle t = Triangle(simplex->points[triangle_order[k][0]], simplex->points[triangle_order[k][1]], simplex->points[triangle_order[k][2]]); + uint32_t *p = triangles_inserted.lookup_ptr(t); + if (p) { + triangles[*p].bad = true; + } else { + triangles_inserted.insert(t, triangles.size()); + triangles.push_back(t); + } + } + + //remove simplex and continue + simplex_list.erase(simplex->SE); + + for (uint32_t k = 0; k < simplex->grid_positions.size(); k++) { + Vector3i p = simplex->grid_positions[k].pos; + acceleration_grid[p.x][p.y][p.z].erase(simplex->grid_positions[k].E); + } + memdelete(simplex); + } + E = N; + } + + uint32_t good_triangles = 0; + for (uint32_t j = 0; j < triangles.size(); j++) { + + if (triangles[j].bad) { + continue; + } + Simplex *new_simplex = memnew(Simplex(triangles[j].triangle[0], triangles[j].triangle[1], triangles[j].triangle[2], i)); + circum_sphere_compute(points, new_simplex); + new_simplex->SE = simplex_list.push_back(new_simplex); + { + Vector3 center; + center.x = double(new_simplex->circum_center_x); + center.y = double(new_simplex->circum_center_y); + center.z = double(new_simplex->circum_center_z); + + float radius2 = Math::sqrt(double(new_simplex->circum_r2)); + radius2 += 0.0001; // + Vector3 extents = Vector3(radius2, radius2, radius2); + Vector3i from = Vector3i((center - extents) * ACCEL_GRID_SIZE); + Vector3i to = Vector3i((center + extents) * ACCEL_GRID_SIZE); + from.x = CLAMP(from.x, 0, ACCEL_GRID_SIZE - 1); + from.y = CLAMP(from.y, 0, ACCEL_GRID_SIZE - 1); + from.z = CLAMP(from.z, 0, ACCEL_GRID_SIZE - 1); + to.x = CLAMP(to.x, 0, ACCEL_GRID_SIZE - 1); + to.y = CLAMP(to.y, 0, ACCEL_GRID_SIZE - 1); + to.z = CLAMP(to.z, 0, ACCEL_GRID_SIZE - 1); + + for (int32_t x = from.x; x <= to.x; x++) { + for (int32_t y = from.y; y <= to.y; y++) { + for (int32_t z = from.z; z <= to.z; z++) { + GridPos gp; + gp.pos = Vector3(x, y, z); + gp.E = acceleration_grid[x][y][z].push_back(new_simplex); + new_simplex->grid_positions.push_back(gp); + } + } + } + } + + good_triangles++; + } + + //print_line("at point " + itos(i) + "/" + itos(point_count) + " simplices added " + itos(good_triangles) + "/" + itos(simplex_list.size()) + " - triangles: " + itos(triangles.size())); + triangles.clear(); + triangles_inserted.clear(); + } + + //print_line("end with simplices: " + itos(simplex_list.size())); + Vector ret_simplices; + ret_simplices.resize(simplex_list.size()); + OutputSimplex *ret_simplicesw = ret_simplices.ptrw(); + uint32_t simplices_written = 0; + + for (List::Element *E = simplex_list.front(); E; E = E->next()) { + Simplex *simplex = E->get(); + bool invalid = false; + for (int j = 0; j < 4; j++) { + if (simplex->points[j] >= point_count) { + invalid = true; + break; + } + } + if (invalid || simplex_is_coplanar(points, *simplex)) { + memdelete(simplex); + continue; + } + + ret_simplicesw[simplices_written].points[0] = simplex->points[0]; + ret_simplicesw[simplices_written].points[1] = simplex->points[1]; + ret_simplicesw[simplices_written].points[2] = simplex->points[2]; + ret_simplicesw[simplices_written].points[3] = simplex->points[3]; + simplices_written++; + memdelete(simplex); + } + + ret_simplices.resize(simplices_written); + + memfree(points); + + return ret_simplices; + } +}; + +#endif // DELAUNAY_3D_H diff --git a/core/math/geometry.cpp b/core/math/geometry.cpp index e556eb3b9c..65b80856cc 100644 --- a/core/math/geometry.cpp +++ b/core/math/geometry.cpp @@ -33,6 +33,8 @@ #include "core/print_string.h" #include "thirdparty/misc/clipper.hpp" #include "thirdparty/misc/triangulator.h" +#define STB_RECT_PACK_IMPLEMENTATION +#include "thirdparty/stb_rect_pack/stb_rect_pack.h" #define SCALE_FACTOR 100000.0 // Based on CMP_EPSILON. @@ -1242,3 +1244,195 @@ Vector Geometry::compute_convex_mesh_points(const Plane *p_planes, int return points; } + +Vector Geometry::pack_rects(const Vector &p_sizes, const Size2i &p_atlas_size) { + + Vector nodes; + nodes.resize(p_atlas_size.width); + + stbrp_context context; + stbrp_init_target(&context, p_atlas_size.width, p_atlas_size.height, nodes.ptrw(), p_atlas_size.width); + + Vector rects; + rects.resize(p_sizes.size()); + + for (int i = 0; i < p_sizes.size(); i++) { + rects.write[i].id = 0; + rects.write[i].w = p_sizes[i].width; + rects.write[i].h = p_sizes[i].height; + rects.write[i].x = 0; + rects.write[i].y = 0; + rects.write[i].was_packed = 0; + } + + int res = stbrp_pack_rects(&context, rects.ptrw(), rects.size()); + if (res == 0) { //pack failed + return Vector(); + } + + Vector ret; + ret.resize(p_sizes.size()); + + for (int i = 0; i < p_sizes.size(); i++) { + Point2i r(rects[i].x, rects[i].y); + ret.write[i] = r; + } + + return ret; +} + +Vector Geometry::partial_pack_rects(const Vector &p_sizes, const Size2i &p_atlas_size) { + + Vector nodes; + nodes.resize(p_atlas_size.width); + zeromem(nodes.ptrw(), sizeof(stbrp_node) * nodes.size()); + + stbrp_context context; + stbrp_init_target(&context, p_atlas_size.width, p_atlas_size.height, nodes.ptrw(), p_atlas_size.width); + + Vector rects; + rects.resize(p_sizes.size()); + + for (int i = 0; i < p_sizes.size(); i++) { + rects.write[i].id = i; + rects.write[i].w = p_sizes[i].width; + rects.write[i].h = p_sizes[i].height; + rects.write[i].x = 0; + rects.write[i].y = 0; + rects.write[i].was_packed = 0; + } + + stbrp_pack_rects(&context, rects.ptrw(), rects.size()); + + Vector ret; + ret.resize(p_sizes.size()); + + for (int i = 0; i < p_sizes.size(); i++) { + ret.write[rects[i].id] = Vector3i(rects[i].x, rects[i].y, rects[i].was_packed != 0 ? 1 : 0); + } + + return ret; +} + +#define square(m_s) ((m_s) * (m_s)) +#define INF 1e20 + +/* dt of 1d function using squared distance */ +static void edt(float *f, int stride, int n) { + + float *d = (float *)alloca(sizeof(float) * n + sizeof(int) * n + sizeof(float) * (n + 1)); + int *v = (int *)&(d[n]); + float *z = (float *)&v[n]; + + int k = 0; + v[0] = 0; + z[0] = -INF; + z[1] = +INF; + for (int q = 1; q <= n - 1; q++) { + float s = ((f[q * stride] + square(q)) - (f[v[k] * stride] + square(v[k]))) / (2 * q - 2 * v[k]); + while (s <= z[k]) { + k--; + s = ((f[q * stride] + square(q)) - (f[v[k] * stride] + square(v[k]))) / (2 * q - 2 * v[k]); + } + k++; + v[k] = q; + + z[k] = s; + z[k + 1] = +INF; + } + + k = 0; + for (int q = 0; q <= n - 1; q++) { + while (z[k + 1] < q) + k++; + d[q] = square(q - v[k]) + f[v[k] * stride]; + } + + for (int i = 0; i < n; i++) { + f[i * stride] = d[i]; + } +} + +#undef square + +Vector Geometry::generate_edf(const Vector &p_voxels, const Vector3i &p_size, bool p_negative) { + + uint32_t float_count = p_size.x * p_size.y * p_size.z; + + ERR_FAIL_COND_V((uint32_t)p_voxels.size() != float_count, Vector()); + + float *work_memory = memnew_arr(float, float_count); + for (uint32_t i = 0; i < float_count; i++) { + work_memory[i] = INF; + } + + uint32_t y_mult = p_size.x; + uint32_t z_mult = y_mult * p_size.y; + + //plot solid cells + { + const bool *voxr = p_voxels.ptr(); + for (uint32_t i = 0; i < float_count; i++) { + + bool plot = voxr[i]; + if (p_negative) { + plot = !plot; + } + if (plot) { + work_memory[i] = 0; + } + } + } + + //process in each direction + + //xy->z + + for (int i = 0; i < p_size.x; i++) { + for (int j = 0; j < p_size.y; j++) { + edt(&work_memory[i + j * y_mult], z_mult, p_size.z); + } + } + + //xz->y + + for (int i = 0; i < p_size.x; i++) { + for (int j = 0; j < p_size.z; j++) { + edt(&work_memory[i + j * z_mult], y_mult, p_size.y); + } + } + + //yz->x + for (int i = 0; i < p_size.y; i++) { + for (int j = 0; j < p_size.z; j++) { + edt(&work_memory[i * y_mult + j * z_mult], 1, p_size.x); + } + } + + Vector ret; + ret.resize(float_count); + { + uint32_t *w = ret.ptrw(); + for (uint32_t i = 0; i < float_count; i++) { + w[i] = uint32_t(Math::sqrt(work_memory[i])); + } + } + + return ret; +} + +Vector Geometry::generate_sdf8(const Vector &p_positive, const Vector &p_negative) { + ERR_FAIL_COND_V(p_positive.size() != p_negative.size(), Vector()); + Vector sdf8; + int s = p_positive.size(); + sdf8.resize(s); + + const uint32_t *rpos = p_positive.ptr(); + const uint32_t *rneg = p_negative.ptr(); + int8_t *wsdf = sdf8.ptrw(); + for (int i = 0; i < s; i++) { + int32_t diff = int32_t(rpos[i]) - int32_t(rneg[i]); + wsdf[i] = CLAMP(diff, -128, 127); + } + return sdf8; +} diff --git a/core/math/geometry.h b/core/math/geometry.h index 3bbd1911ee..5a8e21d02b 100644 --- a/core/math/geometry.h +++ b/core/math/geometry.h @@ -1024,6 +1024,249 @@ public: static Vector compute_convex_mesh_points(const Plane *p_planes, int p_plane_count); +#define FINDMINMAX(x0, x1, x2, min, max) \ + min = max = x0; \ + if (x1 < min) \ + min = x1; \ + if (x1 > max) \ + max = x1; \ + if (x2 < min) \ + min = x2; \ + if (x2 > max) \ + max = x2; + + _FORCE_INLINE_ static bool planeBoxOverlap(Vector3 normal, float d, Vector3 maxbox) { + int q; + Vector3 vmin, vmax; + for (q = 0; q <= 2; q++) { + if (normal[q] > 0.0f) { + vmin[q] = -maxbox[q]; + vmax[q] = maxbox[q]; + } else { + vmin[q] = maxbox[q]; + vmax[q] = -maxbox[q]; + } + } + if (normal.dot(vmin) + d > 0.0f) + return false; + if (normal.dot(vmax) + d >= 0.0f) + return true; + + return false; + } + +/*======================== X-tests ========================*/ +#define AXISTEST_X01(a, b, fa, fb) \ + p0 = a * v0.y - b * v0.z; \ + p2 = a * v2.y - b * v2.z; \ + if (p0 < p2) { \ + min = p0; \ + max = p2; \ + } else { \ + min = p2; \ + max = p0; \ + } \ + rad = fa * boxhalfsize.y + fb * boxhalfsize.z; \ + if (min > rad || max < -rad) \ + return false; + +#define AXISTEST_X2(a, b, fa, fb) \ + p0 = a * v0.y - b * v0.z; \ + p1 = a * v1.y - b * v1.z; \ + if (p0 < p1) { \ + min = p0; \ + max = p1; \ + } else { \ + min = p1; \ + max = p0; \ + } \ + rad = fa * boxhalfsize.y + fb * boxhalfsize.z; \ + if (min > rad || max < -rad) \ + return false; + +/*======================== Y-tests ========================*/ +#define AXISTEST_Y02(a, b, fa, fb) \ + p0 = -a * v0.x + b * v0.z; \ + p2 = -a * v2.x + b * v2.z; \ + if (p0 < p2) { \ + min = p0; \ + max = p2; \ + } else { \ + min = p2; \ + max = p0; \ + } \ + rad = fa * boxhalfsize.x + fb * boxhalfsize.z; \ + if (min > rad || max < -rad) \ + return false; + +#define AXISTEST_Y1(a, b, fa, fb) \ + p0 = -a * v0.x + b * v0.z; \ + p1 = -a * v1.x + b * v1.z; \ + if (p0 < p1) { \ + min = p0; \ + max = p1; \ + } else { \ + min = p1; \ + max = p0; \ + } \ + rad = fa * boxhalfsize.x + fb * boxhalfsize.z; \ + if (min > rad || max < -rad) \ + return false; + + /*======================== Z-tests ========================*/ + +#define AXISTEST_Z12(a, b, fa, fb) \ + p1 = a * v1.x - b * v1.y; \ + p2 = a * v2.x - b * v2.y; \ + if (p2 < p1) { \ + min = p2; \ + max = p1; \ + } else { \ + min = p1; \ + max = p2; \ + } \ + rad = fa * boxhalfsize.x + fb * boxhalfsize.y; \ + if (min > rad || max < -rad) \ + return false; + +#define AXISTEST_Z0(a, b, fa, fb) \ + p0 = a * v0.x - b * v0.y; \ + p1 = a * v1.x - b * v1.y; \ + if (p0 < p1) { \ + min = p0; \ + max = p1; \ + } else { \ + min = p1; \ + max = p0; \ + } \ + rad = fa * boxhalfsize.x + fb * boxhalfsize.y; \ + if (min > rad || max < -rad) \ + return false; + + _FORCE_INLINE_ static bool triangle_box_overlap(const Vector3 &boxcenter, const Vector3 boxhalfsize, const Vector3 *triverts) { + + /* use separating axis theorem to test overlap between triangle and box */ + /* need to test for overlap in these directions: */ + /* 1) the {x,y,z}-directions (actually, since we use the AABB of the triangle */ + /* we do not even need to test these) */ + /* 2) normal of the triangle */ + /* 3) crossproduct(edge from tri, {x,y,z}-directin) */ + /* this gives 3x3=9 more tests */ + Vector3 v0, v1, v2; + float min, max, d, p0, p1, p2, rad, fex, fey, fez; + Vector3 normal, e0, e1, e2; + + /* This is the fastest branch on Sun */ + /* move everything so that the boxcenter is in (0,0,0) */ + + v0 = triverts[0] - boxcenter; + v1 = triverts[1] - boxcenter; + v2 = triverts[2] - boxcenter; + + /* compute triangle edges */ + e0 = v1 - v0; /* tri edge 0 */ + e1 = v2 - v1; /* tri edge 1 */ + e2 = v0 - v2; /* tri edge 2 */ + + /* Bullet 3: */ + /* test the 9 tests first (this was faster) */ + fex = Math::abs(e0.x); + fey = Math::abs(e0.y); + fez = Math::abs(e0.z); + AXISTEST_X01(e0.z, e0.y, fez, fey); + AXISTEST_Y02(e0.z, e0.x, fez, fex); + AXISTEST_Z12(e0.y, e0.x, fey, fex); + + fex = Math::abs(e1.x); + fey = Math::abs(e1.y); + fez = Math::abs(e1.z); + AXISTEST_X01(e1.z, e1.y, fez, fey); + AXISTEST_Y02(e1.z, e1.x, fez, fex); + AXISTEST_Z0(e1.y, e1.x, fey, fex); + + fex = Math::abs(e2.x); + fey = Math::abs(e2.y); + fez = Math::abs(e2.z); + AXISTEST_X2(e2.z, e2.y, fez, fey); + AXISTEST_Y1(e2.z, e2.x, fez, fex); + AXISTEST_Z12(e2.y, e2.x, fey, fex); + + /* Bullet 1: */ + /* first test overlap in the {x,y,z}-directions */ + /* find min, max of the triangle each direction, and test for overlap in */ + /* that direction -- this is equivalent to testing a minimal AABB around */ + /* the triangle against the AABB */ + + /* test in X-direction */ + FINDMINMAX(v0.x, v1.x, v2.x, min, max); + if (min > boxhalfsize.x || max < -boxhalfsize.x) + return false; + + /* test in Y-direction */ + FINDMINMAX(v0.y, v1.y, v2.y, min, max); + if (min > boxhalfsize.y || max < -boxhalfsize.y) + return false; + + /* test in Z-direction */ + FINDMINMAX(v0.z, v1.z, v2.z, min, max); + if (min > boxhalfsize.z || max < -boxhalfsize.z) + return false; + + /* Bullet 2: */ + /* test if the box intersects the plane of the triangle */ + /* compute plane equation of triangle: normal*x+d=0 */ + normal = e0.cross(e1); + d = -normal.dot(v0); /* plane eq: normal.x+d=0 */ + return planeBoxOverlap(normal, d, boxhalfsize); /* if true, box and triangle overlaps */ + } + + static Vector pack_rects(const Vector &p_sizes, const Size2i &p_atlas_size); + static Vector partial_pack_rects(const Vector &p_sizes, const Size2i &p_atlas_size); + + static Vector generate_edf(const Vector &p_voxels, const Vector3i &p_size, bool p_negative); + static Vector generate_sdf8(const Vector &p_positive, const Vector &p_negative); + + static Vector3 triangle_get_barycentric_coords(const Vector3 &p_a, const Vector3 &p_b, const Vector3 &p_c, const Vector3 &p_pos) { + Vector3 v0 = p_b - p_a; + Vector3 v1 = p_c - p_a; + Vector3 v2 = p_pos - p_a; + + float d00 = v0.dot(v0); + float d01 = v0.dot(v1); + float d11 = v1.dot(v1); + float d20 = v2.dot(v0); + float d21 = v2.dot(v1); + float denom = (d00 * d11 - d01 * d01); + if (denom == 0) { + return Vector3(); //invalid triangle, return empty + } + float v = (d11 * d20 - d01 * d21) / denom; + float w = (d00 * d21 - d01 * d20) / denom; + float u = 1.0f - v - w; + return Vector3(u, v, w); + } + + static Color tetrahedron_get_barycentric_coords(const Vector3 &p_a, const Vector3 &p_b, const Vector3 &p_c, const Vector3 &p_d, const Vector3 &p_pos) { + Vector3 vap = p_pos - p_a; + Vector3 vbp = p_pos - p_b; + + Vector3 vab = p_b - p_a; + Vector3 vac = p_c - p_a; + Vector3 vad = p_d - p_a; + + Vector3 vbc = p_c - p_b; + Vector3 vbd = p_d - p_b; + // ScTP computes the scalar triple product +#define STP(m_a, m_b, m_c) ((m_a).dot((m_b).cross((m_c)))) + float va6 = STP(vbp, vbd, vbc); + float vb6 = STP(vap, vac, vad); + float vc6 = STP(vap, vad, vab); + float vd6 = STP(vap, vab, vac); + float v6 = 1 / STP(vab, vac, vad); + return Color(va6 * v6, vb6 * v6, vc6 * v6, vd6 * v6); +#undef STP + } + private: static Vector> _polypaths_do_operation(PolyBooleanOperation p_op, const Vector &p_polypath_a, const Vector &p_polypath_b, bool is_a_open = false); static Vector> _polypath_offset(const Vector &p_polypath, real_t p_delta, PolyJoinType p_join_type, PolyEndType p_end_type); diff --git a/core/math/plane.cpp b/core/math/plane.cpp index a3818698bc..26ac0aac47 100644 --- a/core/math/plane.cpp +++ b/core/math/plane.cpp @@ -153,6 +153,10 @@ bool Plane::intersects_segment(const Vector3 &p_begin, const Vector3 &p_end, Vec /* misc */ +bool Plane::is_equal_approx_any_side(const Plane &p_plane) const { + return (normal.is_equal_approx(p_plane.normal) && Math::is_equal_approx(d, p_plane.d)) || (normal.is_equal_approx(-p_plane.normal) && Math::is_equal_approx(d, -p_plane.d)); +} + bool Plane::is_equal_approx(const Plane &p_plane) const { return normal.is_equal_approx(p_plane.normal) && Math::is_equal_approx(d, p_plane.d); diff --git a/core/math/plane.h b/core/math/plane.h index 771c8fc705..d4f23ff2b6 100644 --- a/core/math/plane.h +++ b/core/math/plane.h @@ -69,6 +69,7 @@ public: Plane operator-() const { return Plane(-normal, -d); } bool is_equal_approx(const Plane &p_plane) const; + bool is_equal_approx_any_side(const Plane &p_plane) const; _FORCE_INLINE_ bool operator==(const Plane &p_plane) const; _FORCE_INLINE_ bool operator!=(const Plane &p_plane) const; diff --git a/core/math/r128.cpp b/core/math/r128.cpp new file mode 100644 index 0000000000..fb1e4733ee --- /dev/null +++ b/core/math/r128.cpp @@ -0,0 +1,2 @@ +#define R128_IMPLEMENTATION +#include "thirdparty/r128/r128.h" diff --git a/core/ustring.cpp b/core/ustring.cpp index beafb3018d..3e8a1ddbe3 100644 --- a/core/ustring.cpp +++ b/core/ustring.cpp @@ -548,8 +548,8 @@ signed char String::naturalnocasecmp_to(const String &p_str) const { return -1; /* Compare the numbers */ - this_int = to_int(this_str); - that_int = to_int(that_str); + this_int = to_int(this_str, -1, true); + that_int = to_int(that_str, -1, true); if (this_int < that_int) return -1; @@ -2138,7 +2138,7 @@ double String::to_double(const CharType *p_str, const CharType **r_end) { return built_in_strtod(p_str, (CharType **)r_end); } -int64_t String::to_int(const CharType *p_str, int p_len) { +int64_t String::to_int(const CharType *p_str, int p_len, bool p_clamp) { if (p_len == 0 || !p_str[0]) return 0; @@ -2182,7 +2182,15 @@ int64_t String::to_int(const CharType *p_str, int p_len) { while (*str && str != limit) { number += *(str++); } - ERR_FAIL_V_MSG(sign == 1 ? INT64_MAX : INT64_MIN, "Cannot represent " + number + " as integer, provided value is " + (sign == 1 ? "too big." : "too small.")); + if (p_clamp) { + if (sign == 1) { + return INT64_MAX; + } else { + return INT64_MIN; + } + } else { + ERR_FAIL_V_MSG(sign == 1 ? INT64_MAX : INT64_MIN, "Cannot represent " + number + " as integer, provided value is " + (sign == 1 ? "too big." : "too small.")); + } } integer *= 10; integer += c - '0'; diff --git a/core/ustring.h b/core/ustring.h index ee7e3b1e16..15bc2b323c 100644 --- a/core/ustring.h +++ b/core/ustring.h @@ -254,7 +254,7 @@ public: static int to_int(const char *p_str, int p_len = -1); static double to_double(const char *p_str); static double to_double(const CharType *p_str, const CharType **r_end = nullptr); - static int64_t to_int(const CharType *p_str, int p_len = -1); + static int64_t to_int(const CharType *p_str, int p_len = -1, bool p_clamp = false); String capitalize() const; String camelcase_to_underscore(bool lowercase = true) const; diff --git a/core/vector.h b/core/vector.h index b2133f800b..74e0ab91c0 100644 --- a/core/vector.h +++ b/core/vector.h @@ -39,6 +39,7 @@ #include "core/cowdata.h" #include "core/error_macros.h" +#include "core/os/copymem.h" #include "core/os/memory.h" #include "core/sort_array.h" @@ -125,6 +126,13 @@ public: return *this; } + Vector to_byte_array() const { + Vector ret; + ret.resize(size() * sizeof(T)); + copymem(ret.ptrw(), ptr(), sizeof(T) * size()); + return ret; + } + Vector subarray(int p_from, int p_to) const { if (p_from < 0) { diff --git a/drivers/vulkan/rendering_device_vulkan.cpp b/drivers/vulkan/rendering_device_vulkan.cpp index 491dc4214f..71be891b1d 100644 --- a/drivers/vulkan/rendering_device_vulkan.cpp +++ b/drivers/vulkan/rendering_device_vulkan.cpp @@ -2453,7 +2453,7 @@ Vector RenderingDeviceVulkan::texture_get_data(RID p_texture, uint32_t uint32_t buffer_size = get_image_format_required_size(tex->format, tex->width, tex->height, tex->depth, tex->mipmaps, &width, &height, &depth); //allocate buffer - VkCommandBuffer command_buffer = frames[frame].setup_command_buffer; + VkCommandBuffer command_buffer = frames[frame].draw_command_buffer; //makes more sense to retrieve Buffer tmp_buffer; _buffer_allocate(&tmp_buffer, buffer_size, VK_BUFFER_USAGE_TRANSFER_DST_BIT, VMA_MEMORY_USAGE_CPU_ONLY); @@ -6859,6 +6859,7 @@ void RenderingDeviceVulkan::sync() { context->local_device_sync(local_device); _begin_frame(); + local_device_processing = false; } void RenderingDeviceVulkan::_free_pending_resources(int p_frame) { @@ -6975,6 +6976,12 @@ uint32_t RenderingDeviceVulkan::get_frame_delay() const { return frame_count; } +uint64_t RenderingDeviceVulkan::get_memory_usage() const { + VmaStats stats; + vmaCalculateStats(allocator, &stats); + return stats.total.usedBytes; +} + void RenderingDeviceVulkan::_flush(bool p_current_frame) { if (local_device.is_valid() && !p_current_frame) { @@ -7039,6 +7046,7 @@ void RenderingDeviceVulkan::initialize(VulkanContext *p_context, bool p_local_de if (p_local_device) { frame_count = 1; local_device = p_context->local_device_create(); + device = p_context->local_device_get_vk_device(local_device); } else { frame_count = p_context->get_swapchain_image_count() + 1; //always need one extra to ensure it's unused at any time, without having to use a fence for this. } diff --git a/drivers/vulkan/rendering_device_vulkan.h b/drivers/vulkan/rendering_device_vulkan.h index 6432946fbe..87af5d03d4 100644 --- a/drivers/vulkan/rendering_device_vulkan.h +++ b/drivers/vulkan/rendering_device_vulkan.h @@ -1138,6 +1138,8 @@ public: virtual RenderingDevice *create_local_device(); + virtual uint64_t get_memory_usage() const; + RenderingDeviceVulkan(); ~RenderingDeviceVulkan(); }; diff --git a/drivers/vulkan/vulkan_context.cpp b/drivers/vulkan/vulkan_context.cpp index a7eb3e53b9..9471b4604c 100644 --- a/drivers/vulkan/vulkan_context.cpp +++ b/drivers/vulkan/vulkan_context.cpp @@ -1567,6 +1567,15 @@ void VulkanContext::local_device_push_command_buffers(RID p_local_device, const submit_info.pSignalSemaphores = nullptr; VkResult err = vkQueueSubmit(ld->queue, 1, &submit_info, VK_NULL_HANDLE); + if (err == VK_ERROR_OUT_OF_HOST_MEMORY) { + print_line("out of host memory"); + } + if (err == VK_ERROR_OUT_OF_DEVICE_MEMORY) { + print_line("out of device memory"); + } + if (err == VK_ERROR_DEVICE_LOST) { + print_line("device lost"); + } ERR_FAIL_COND(err); ld->waiting = true; diff --git a/editor/editor_node.cpp b/editor/editor_node.cpp index c37ede4166..1b1ce4ec37 100644 --- a/editor/editor_node.cpp +++ b/editor/editor_node.cpp @@ -158,6 +158,7 @@ #include "editor/plugins/style_box_editor_plugin.h" #include "editor/plugins/text_editor.h" #include "editor/plugins/texture_editor_plugin.h" +#include "editor/plugins/texture_layered_editor_plugin.h" #include "editor/plugins/texture_region_editor_plugin.h" #include "editor/plugins/theme_editor_plugin.h" #include "editor/plugins/tile_map_editor_plugin.h" @@ -381,6 +382,8 @@ void EditorNode::_notification(int p_what) { RS::get_singleton()->shadows_quality_set(shadows_quality); RS::ShadowQuality directional_shadow_quality = RS::ShadowQuality(int(GLOBAL_GET("rendering/quality/directional_shadow/soft_shadow_quality"))); RS::get_singleton()->directional_shadow_quality_set(directional_shadow_quality); + float probe_update_speed = GLOBAL_GET("rendering/lightmapper/probe_capture_update_speed"); + RS::get_singleton()->lightmap_set_probe_capture_update_speed(probe_update_speed); } ResourceImporterTexture::get_singleton()->update_imports(); @@ -713,7 +716,6 @@ void EditorNode::_sources_changed(bool p_exist) { // Reload the global shader variables, but this time // loading texures, as they are now properly imported. - print_line("done scanning, reload textures"); RenderingServer::get_singleton()->global_variables_load_settings(true); // Start preview thread now that it's safe. @@ -5678,7 +5680,7 @@ EditorNode::EditorNode() { import_texture.instance(); ResourceFormatImporter::get_singleton()->add_importer(import_texture); - /* Ref import_cubemap; + Ref import_cubemap; import_cubemap.instance(); import_cubemap->set_mode(ResourceImporterLayeredTexture::MODE_CUBEMAP); ResourceFormatImporter::get_singleton()->add_importer(import_cubemap); @@ -5692,7 +5694,12 @@ EditorNode::EditorNode() { import_cubemap_array.instance(); import_cubemap_array->set_mode(ResourceImporterLayeredTexture::MODE_CUBEMAP_ARRAY); ResourceFormatImporter::get_singleton()->add_importer(import_cubemap_array); -*/ + + /*Ref import_3d; + import_3d.instance(); + import_3d->set_mode(ResourceImporterLayeredTexture::MODE_3D); + ResourceFormatImporter::get_singleton()->add_importer(import_3d);*/ + Ref import_image; import_image.instance(); ResourceFormatImporter::get_singleton()->add_importer(import_image); @@ -6663,7 +6670,7 @@ EditorNode::EditorNode() { add_editor_plugin(memnew(SpriteFramesEditorPlugin(this))); add_editor_plugin(memnew(TextureRegionEditorPlugin(this))); add_editor_plugin(memnew(GIProbeEditorPlugin(this))); - //add_editor_plugin(memnew(BakedLightmapEditorPlugin(this))); + add_editor_plugin(memnew(BakedLightmapEditorPlugin(this))); add_editor_plugin(memnew(Path2DEditorPlugin(this))); add_editor_plugin(memnew(Path3DEditorPlugin(this))); add_editor_plugin(memnew(Line2DEditorPlugin(this))); @@ -6674,6 +6681,7 @@ EditorNode::EditorNode() { add_editor_plugin(memnew(CollisionShape2DEditorPlugin(this))); add_editor_plugin(memnew(CurveEditorPlugin(this))); add_editor_plugin(memnew(TextureEditorPlugin(this))); + add_editor_plugin(memnew(TextureLayeredEditorPlugin(this))); add_editor_plugin(memnew(AudioStreamEditorPlugin(this))); add_editor_plugin(memnew(AudioBusesEditorPlugin(audio_bus_editor))); add_editor_plugin(memnew(Skeleton3DEditorPlugin(this))); diff --git a/editor/import/resource_importer_layered_texture.cpp b/editor/import/resource_importer_layered_texture.cpp index a4cbc81b26..c46cf4c1a8 100644 --- a/editor/import/resource_importer_layered_texture.cpp +++ b/editor/import/resource_importer_layered_texture.cpp @@ -36,9 +36,9 @@ #include "core/io/image_loader.h" #include "editor/editor_file_system.h" #include "editor/editor_node.h" +#include "resource_importer_texture.h" #include "scene/resources/texture.h" -#if 0 String ResourceImporterLayeredTexture::get_importer_name() const { switch (mode) { @@ -51,6 +51,9 @@ String ResourceImporterLayeredTexture::get_importer_name() const { case MODE_CUBEMAP_ARRAY: { return "cubemap_array_texture"; } break; + case MODE_3D: { + return "cubemap_3d_texture"; + } break; } ERR_FAIL_V(""); @@ -68,6 +71,9 @@ String ResourceImporterLayeredTexture::get_visible_name() const { case MODE_CUBEMAP_ARRAY: { return "CubemapArray"; } break; + case MODE_3D: { + return "3D"; + } break; } ERR_FAIL_V(""); @@ -79,13 +85,16 @@ void ResourceImporterLayeredTexture::get_recognized_extensions(List *p_e String ResourceImporterLayeredTexture::get_save_extension() const { switch (mode) { case MODE_CUBEMAP: { - return "cube"; + return "scube"; } break; case MODE_2D_ARRAY: { - return "tex2darr"; + return "stexarray"; } break; case MODE_CUBEMAP_ARRAY: { - return "cubearr"; + return "scubearray"; + } break; + case MODE_3D: { + return "stex3d"; } break; } @@ -96,13 +105,16 @@ String ResourceImporterLayeredTexture::get_resource_type() const { switch (mode) { case MODE_CUBEMAP: { - return "Cubemap"; + return "StreamCubemap"; } break; case MODE_2D_ARRAY: { - return "Texture2DArray"; + return "StreamTexture2DArray"; } break; case MODE_CUBEMAP_ARRAY: { - return "CubemapArray"; + return "StreamCubemapArray"; + } break; + case MODE_3D: { + return "StreamTexture3D"; } break; } ERR_FAIL_V(String()); @@ -110,6 +122,9 @@ String ResourceImporterLayeredTexture::get_resource_type() const { bool ResourceImporterLayeredTexture::get_option_visibility(const String &p_option, const Map &p_options) const { + if (p_option == "compress/lossy_quality" && p_options.has("compress/mode")) { + return int(p_options["compress/mode"]) == COMPRESS_LOSSY; + } return true; } @@ -123,138 +138,109 @@ String ResourceImporterLayeredTexture::get_preset_name(int p_idx) const { void ResourceImporterLayeredTexture::get_import_options(List *r_options, int p_preset) const { - r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/mode", PROPERTY_HINT_ENUM, "Lossless,Video RAM,Uncompressed", PROPERTY_USAGE_DEFAULT | PROPERTY_USAGE_UPDATE_ALL_IF_MODIFIED), 1)); - r_options->push_back(ImportOption(PropertyInfo(Variant::BOOL, "compress/no_bptc_if_rgb"), false)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/mode", PROPERTY_HINT_ENUM, "Lossless,Lossy,Video RAM,Uncompressed,Basis Universal", PROPERTY_USAGE_DEFAULT | PROPERTY_USAGE_UPDATE_ALL_IF_MODIFIED), 1)); + r_options->push_back(ImportOption(PropertyInfo(Variant::FLOAT, "compress/lossy_quality", PROPERTY_HINT_RANGE, "0,1,0.01"), 0.7)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/hdr_compression", PROPERTY_HINT_ENUM, "Disabled,Opaque Only,Always"), 1)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/bptc_ldr", PROPERTY_HINT_ENUM, "Disabled,Enabled,RGBA Only"), 0)); r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/channel_pack", PROPERTY_HINT_ENUM, "sRGB Friendly,Optimized"), 0)); - r_options->push_back(ImportOption(PropertyInfo(Variant::BOOL, "flags/mipmaps"), true)); - if (mode == MODE_2D_ARRAY) { + r_options->push_back(ImportOption(PropertyInfo(Variant::BOOL, "mipmaps/generate"), true)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "mipmaps/limit", PROPERTY_HINT_RANGE, "-1,256"), -1)); + + if (mode == MODE_2D_ARRAY || mode == MODE_3D) { r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "slices/horizontal", PROPERTY_HINT_RANGE, "1,256,1"), 8)); - } - if (mode == MODE_2D_ARRAY || mode == MODE_CUBEMAP_ARRAY) { r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "slices/vertical", PROPERTY_HINT_RANGE, "1,256,1"), 8)); } + if (mode == MODE_CUBEMAP || mode == MODE_CUBEMAP_ARRAY) { + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "slices/arrangement", PROPERTY_HINT_ENUM, "1x6,2x3,3x2,6x1"), 1)); + if (mode == MODE_CUBEMAP_ARRAY) { + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "slices/layout", PROPERTY_HINT_ENUM, "Horizontal,Vertical"), 1)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "slices/amount", PROPERTY_HINT_RANGE, "1,1024,1,or_greater"), 1)); + } + } } -void ResourceImporterLayeredTexture::_save_tex(const Vector > &p_images, const String &p_to_path, int p_compress_mode, Image::CompressMode p_vram_compression, bool p_mipmaps) { - - FileAccess *f = FileAccess::open(p_to_path, FileAccess::WRITE); - f->store_8('G'); - f->store_8('D'); - switch (mode) { - case MODE_2D_ARRAY: f->store_8('A'); break; - case MODE_CUBEMAP: f->store_8('C'); break; - case MODE_CUBEMAP_ARRAY: f->store_8('X'); break; - } - - f->store_8('T'); //godot streamable texture - - f->store_32(p_images[0]->get_width()); - f->store_32(p_images[0]->get_height()); - f->store_32(p_images.size()); //depth - uint32_t flags = 0; - if (p_mipmaps) { - flags |= TEXTURE_FLAGS_MIPMAPS; - } - f->store_32(flags); - if (p_compress_mode != COMPRESS_VIDEO_RAM) { - //vram needs to do a first compression to tell what the format is, for the rest its ok - f->store_32(p_images[0]->get_format()); - f->store_32(p_compress_mode); // 0 - lossless (PNG), 1 - vram, 2 - uncompressed - } - - if ((p_compress_mode == COMPRESS_LOSSLESS) && p_images[0]->get_format() > Image::FORMAT_RGBA8) { - p_compress_mode = COMPRESS_UNCOMPRESSED; //these can't go as lossy - } +void ResourceImporterLayeredTexture::_save_tex(Vector> p_images, const String &p_to_path, int p_compress_mode, float p_lossy, Image::CompressMode p_vram_compression, Image::CompressSource p_csource, Image::UsedChannels used_channels, bool p_mipmaps, bool p_force_po2) { for (int i = 0; i < p_images.size(); i++) { - switch (p_compress_mode) { - case COMPRESS_LOSSLESS: { + if (p_force_po2) { + p_images.write[i]->resize_to_po2(); + } - Ref image = p_images[i]->duplicate(); - if (p_mipmaps) { - image->generate_mipmaps(); - } else { - image->clear_mipmaps(); - } - - int mmc = image->get_mipmap_count() + 1; - f->store_32(mmc); - - for (int j = 0; j < mmc; j++) { - - if (j > 0) { - image->shrink_x2(); - } - - Vector data = Image::lossless_packer(image); - int data_len = data.size(); - f->store_32(data_len); - - const uint8_t* r = data.ptr(); - f->store_buffer(r.ptr(), data_len); - } - - } break; - case COMPRESS_VIDEO_RAM: { - - Ref image = p_images[i]->duplicate(); - image->generate_mipmaps(false); - - Image::CompressSource csource = Image::COMPRESS_SOURCE_LAYERED; - image->compress(p_vram_compression, csource, 0.7); - - if (i == 0) { - //hack so we can properly tell the format - f->store_32(image->get_format()); - f->store_32(p_compress_mode); // 0 - lossless (PNG), 1 - vram, 2 - uncompressed - } - - Vector data = image->get_data(); - int dl = data.size(); - - const uint8_t* r = data.ptr(); - f->store_buffer(r.ptr(), dl); - } break; - case COMPRESS_UNCOMPRESSED: { - - Ref image = p_images[i]->duplicate(); - - if (p_mipmaps) { - image->generate_mipmaps(); - } else { - image->clear_mipmaps(); - } - - Vector data = image->get_data(); - int dl = data.size(); - - const uint8_t* r = data.ptr(); - - f->store_buffer(r.ptr(), dl); - - } break; + if (p_mipmaps) { + p_images.write[i]->generate_mipmaps(); + } else { + p_images.write[i]->clear_mipmaps(); } } - memdelete(f); + FileAccessRef f = FileAccess::open(p_to_path, FileAccess::WRITE); + f->store_8('G'); + f->store_8('S'); + f->store_8('T'); + f->store_8('L'); + + f->store_32(StreamTextureLayered::FORMAT_VERSION); + f->store_32(p_images.size()); + f->store_32(mode); + f->store_32(0); //dataformat + f->store_32(0); //mipmap limit + + //reserverd + f->store_32(0); + f->store_32(0); + f->store_32(0); + + for (int i = 0; i < p_images.size(); i++) { + ResourceImporterTexture::save_to_stex_format(f, p_images[i], ResourceImporterTexture::CompressMode(p_compress_mode), used_channels, p_vram_compression, p_lossy); + } + + f->close(); } Error ResourceImporterLayeredTexture::import(const String &p_source_file, const String &p_save_path, const Map &p_options, List *r_platform_variants, List *r_gen_files, Variant *r_metadata) { int compress_mode = p_options["compress/mode"]; - int no_bptc_if_rgb = p_options["compress/no_bptc_if_rgb"]; - bool mipmaps = p_options["flags/mipmaps"]; + float lossy = p_options["compress/lossy_quality"]; + int hdr_compression = p_options["compress/hdr_compression"]; + int bptc_ldr = p_options["compress/bptc_ldr"]; + bool mipmaps = p_options["mipmaps/generate"]; + //bool mipmap_limit = p_options["mipmaps/limit"]; + int channel_pack = p_options["compress/channel_pack"]; int hslices = (p_options.has("slices/horizontal")) ? int(p_options["slices/horizontal"]) : 0; int vslices = (p_options.has("slices/vertical")) ? int(p_options["slices/vertical"]) : 0; + int arrangement = (p_options.has("slices/arrangement")) ? int(p_options["slices/arrangement"]) : 0; + int layout = (p_options.has("slices/layout")) ? int(p_options["slices/layout"]) : 0; + int amount = (p_options.has("slices/amount")) ? int(p_options["slices/amount"]) : 0; - if (mode == MODE_CUBEMAP) { - hslices = 3; - vslices = 2; - } else if (mode == MODE_CUBEMAP_ARRAY) { - hslices = 3; - vslices *= 2; //put cubemaps vertically + if (mode == MODE_CUBEMAP || mode == MODE_CUBEMAP_ARRAY) { + switch (arrangement) { + case CUBEMAP_FORMAT_1X6: { + hslices = 1; + vslices = 6; + } break; + case CUBEMAP_FORMAT_2X3: { + hslices = 2; + vslices = 3; + } break; + case CUBEMAP_FORMAT_3X2: { + hslices = 3; + vslices = 2; + } break; + case CUBEMAP_FORMAT_6X1: { + hslices = 6; + vslices = 1; + } break; + } + + if (mode == MODE_CUBEMAP_ARRAY) { + if (layout == 0) { + hslices *= amount; + } else { + vslices *= amount; + } + } } Ref image; @@ -263,28 +249,40 @@ Error ResourceImporterLayeredTexture::import(const String &p_source_file, const if (err != OK) return err; - if (compress_mode == COMPRESS_VIDEO_RAM) { + if (compress_mode == COMPRESS_BASIS_UNIVERSAL && image->get_format() >= Image::FORMAT_RF) { + //basis universal does not support float formats, fall back + compress_mode = COMPRESS_VRAM_COMPRESSED; + } + + if (compress_mode == COMPRESS_VRAM_COMPRESSED) { mipmaps = true; } - Vector > slices; - - int slice_w = image->get_width() / hslices; - int slice_h = image->get_height() / vslices; - //optimize - if (compress_mode == COMPRESS_VIDEO_RAM) { + if (compress_mode == COMPRESS_VRAM_COMPRESSED) { //if using video ram, optimize if (channel_pack == 0) { //remove alpha if not needed, so compression is more efficient if (image->get_format() == Image::FORMAT_RGBA8 && !image->detect_alpha()) { image->convert(Image::FORMAT_RGB8); } - } else { + } else if (image->get_format() < Image::FORMAT_RGBA8) { image->optimize_channels(); } } + Image::CompressSource csource = Image::COMPRESS_SOURCE_GENERIC; + if (channel_pack == 0) { + csource = Image::COMPRESS_SOURCE_SRGB; + } + + Image::UsedChannels used_channels = image->detect_used_channels(csource); + + Vector> slices; + + int slice_w = image->get_width() / hslices; + int slice_h = image->get_height() / vslices; + for (int i = 0; i < vslices; i++) { for (int j = 0; j < hslices; j++) { int x = slice_w * j; @@ -301,58 +299,82 @@ Error ResourceImporterLayeredTexture::import(const String &p_source_file, const String extension = get_save_extension(); Array formats_imported; - if (compress_mode == COMPRESS_VIDEO_RAM) { + if (compress_mode == COMPRESS_VRAM_COMPRESSED) { //must import in all formats, in order of priority (so platform choses the best supported one. IE, etc2 over etc). //Android, GLES 2.x bool ok_on_pc = false; - bool encode_bptc = false; + bool is_hdr = (image->get_format() >= Image::FORMAT_RF && image->get_format() <= Image::FORMAT_RGBE9995); + bool is_ldr = (image->get_format() >= Image::FORMAT_L8 && image->get_format() <= Image::FORMAT_RGB565); + bool can_bptc = ProjectSettings::get_singleton()->get("rendering/vram_compression/import_bptc"); + bool can_s3tc = ProjectSettings::get_singleton()->get("rendering/vram_compression/import_s3tc"); - if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_bptc")) { + if (can_bptc) { + formats_imported.push_back("bptc"); //needs to be aded anyway + } + bool can_compress_hdr = hdr_compression > 0; - encode_bptc = true; + if (is_hdr && can_compress_hdr) { - if (no_bptc_if_rgb) { - Image::UsedChannels channels = image->detect_used_channels(); - if (channels != Image::USED_CHANNELS_LA && channels != Image::USED_CHANNELS_RGBA) { - encode_bptc = false; + if (used_channels == Image::USED_CHANNELS_LA || used_channels == Image::USED_CHANNELS_RGBA) { + //can compress hdr, but hdr with alpha is not compressible + + if (hdr_compression == 2) { + //but user selected to compress hdr anyway, so force an alpha-less format. + if (image->get_format() == Image::FORMAT_RGBAF) { + for (int i = 0; i < slices.size(); i++) { + slices.write[i]->convert(Image::FORMAT_RGBF); + } + + } else if (image->get_format() == Image::FORMAT_RGBAH) { + for (int i = 0; i < slices.size(); i++) { + slices.write[i]->convert(Image::FORMAT_RGBH); + } + } + } else { + can_compress_hdr = false; } } - formats_imported.push_back("bptc"); + if (can_compress_hdr) { + + if (!can_bptc) { + + //default to rgbe + if (image->get_format() != Image::FORMAT_RGBE9995) { + for (int i = 0; i < slices.size(); i++) { + slices.write[i]->convert(Image::FORMAT_RGBE9995); + } + } + } + } else { + can_bptc = false; + } } - if (encode_bptc) { - - _save_tex(slices, p_save_path + ".bptc." + extension, compress_mode, Image::COMPRESS_BPTC, mipmaps); - r_platform_variants->push_back("bptc"); - ok_on_pc = true; + if (is_ldr && can_bptc) { + if (bptc_ldr == 0 || (bptc_ldr == 1 && !(used_channels == Image::USED_CHANNELS_LA || used_channels == Image::USED_CHANNELS_RGBA))) { + can_bptc = false; + } } - if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_s3tc")) { - - _save_tex(slices, p_save_path + ".s3tc." + extension, compress_mode, Image::COMPRESS_S3TC, mipmaps); + if (can_bptc || can_s3tc) { + _save_tex(slices, p_save_path + ".s3tc." + extension, compress_mode, lossy, can_bptc ? Image::COMPRESS_BPTC : Image::COMPRESS_S3TC, csource, used_channels, mipmaps, false); r_platform_variants->push_back("s3tc"); - ok_on_pc = true; formats_imported.push_back("s3tc"); + ok_on_pc = true; } if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_etc2")) { - _save_tex(slices, p_save_path + ".etc2." + extension, compress_mode, Image::COMPRESS_ETC2, mipmaps); + _save_tex(slices, p_save_path + ".etc2." + extension, compress_mode, lossy, Image::COMPRESS_ETC2, csource, used_channels, mipmaps, true); r_platform_variants->push_back("etc2"); formats_imported.push_back("etc2"); } - if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_etc")) { - _save_tex(slices, p_save_path + ".etc." + extension, compress_mode, Image::COMPRESS_ETC, mipmaps); - r_platform_variants->push_back("etc"); - formats_imported.push_back("etc"); - } - if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_pvrtc")) { - _save_tex(slices, p_save_path + ".pvrtc." + extension, compress_mode, Image::COMPRESS_PVRTC4, mipmaps); + _save_tex(slices, p_save_path + ".etc2." + extension, compress_mode, lossy, Image::COMPRESS_ETC2, csource, used_channels, mipmaps, true); r_platform_variants->push_back("pvrtc"); formats_imported.push_back("pvrtc"); } @@ -362,12 +384,12 @@ Error ResourceImporterLayeredTexture::import(const String &p_source_file, const } } else { //import normally - _save_tex(slices, p_save_path + "." + extension, compress_mode, Image::COMPRESS_S3TC /*this is ignored */, mipmaps); + _save_tex(slices, p_save_path + "." + extension, compress_mode, lossy, Image::COMPRESS_S3TC /* IGNORED */, csource, used_channels, mipmaps, false); } if (r_metadata) { Dictionary metadata; - metadata["vram_texture"] = compress_mode == COMPRESS_VIDEO_RAM; + metadata["vram_texture"] = compress_mode == COMPRESS_VRAM_COMPRESSED; if (formats_imported.size()) { metadata["imported_formats"] = formats_imported; } @@ -448,4 +470,3 @@ ResourceImporterLayeredTexture::ResourceImporterLayeredTexture() { ResourceImporterLayeredTexture::~ResourceImporterLayeredTexture() { } -#endif diff --git a/editor/import/resource_importer_layered_texture.h b/editor/import/resource_importer_layered_texture.h index 40e5c9023e..18eaf31f6b 100644 --- a/editor/import/resource_importer_layered_texture.h +++ b/editor/import/resource_importer_layered_texture.h @@ -28,7 +28,6 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#if 0 /*************************************************************************/ /* resource_importer_layered_texture.h */ /*************************************************************************/ @@ -65,16 +64,24 @@ #include "core/image.h" #include "core/io/resource_importer.h" - -class StreamTexture; +class StreamTexture2D; class ResourceImporterLayeredTexture : public ResourceImporter { GDCLASS(ResourceImporterLayeredTexture, ResourceImporter); + public: enum Mode { - MODE_CUBEMAP, MODE_2D_ARRAY, - MODE_CUBEMAP_ARRAY + MODE_CUBEMAP, + MODE_CUBEMAP_ARRAY, + MODE_3D, + }; + + enum CubemapFormat { + CUBEMAP_FORMAT_1X6, + CUBEMAP_FORMAT_2X3, + CUBEMAP_FORMAT_3X2, + CUBEMAP_FORMAT_6X1, }; enum TextureFlags { @@ -86,9 +93,9 @@ private: static const char *compression_formats[]; protected: - static void _texture_reimport_srgb(const Ref &p_tex); - static void _texture_reimport_3d(const Ref &p_tex); - static void _texture_reimport_normal(const Ref &p_tex); + static void _texture_reimport_srgb(const Ref &p_tex); + static void _texture_reimport_3d(const Ref &p_tex); + static void _texture_reimport_normal(const Ref &p_tex); static ResourceImporterLayeredTexture *singleton; @@ -102,8 +109,10 @@ public: enum CompressMode { COMPRESS_LOSSLESS, - COMPRESS_VIDEO_RAM, - COMPRESS_UNCOMPRESSED + COMPRESS_LOSSY, + COMPRESS_VRAM_COMPRESSED, + COMPRESS_VRAM_UNCOMPRESSED, + COMPRESS_BASIS_UNIVERSAL }; virtual int get_preset_count() const; @@ -112,7 +121,7 @@ public: virtual void get_import_options(List *r_options, int p_preset = 0) const; virtual bool get_option_visibility(const String &p_option, const Map &p_options) const; - void _save_tex(const Vector > &p_images, const String &p_to_path, int p_compress_mode, Image::CompressMode p_vram_compression, bool p_mipmaps); + void _save_tex(Vector> p_images, const String &p_to_path, int p_compress_mode, float p_lossy, Image::CompressMode p_vram_compression, Image::CompressSource p_csource, Image::UsedChannels used_channels, bool p_mipmaps, bool p_force_po2); virtual Error import(const String &p_source_file, const String &p_save_path, const Map &p_options, List *r_platform_variants, List *r_gen_files = nullptr, Variant *r_metadata = nullptr); @@ -126,6 +135,5 @@ public: ResourceImporterLayeredTexture(); ~ResourceImporterLayeredTexture(); }; -#endif // RESOURCE_IMPORTER_LAYERED_TEXTURE_H -#endif +#endif // RESOURCE_IMPORTER_LAYERED_TEXTURE_H diff --git a/editor/import/resource_importer_scene.cpp b/editor/import/resource_importer_scene.cpp index e7f87acd03..b9341b1fb6 100644 --- a/editor/import/resource_importer_scene.cpp +++ b/editor/import/resource_importer_scene.cpp @@ -354,7 +354,7 @@ Node *ResourceImporterScene::_fix_node(Node *p_node, Node *p_root, Map if (p_light_bake_mode != LIGHT_BAKE_DISABLED) { - mi->set_flag(GeometryInstance3D::FLAG_USE_BAKED_LIGHT, true); + mi->set_gi_mode(GeometryInstance3D::GI_MODE_BAKED); } } @@ -955,7 +955,7 @@ void ResourceImporterScene::_find_meshes(Node *p_node, Map, Trans Transform transform; while (s) { transform = transform * s->get_transform(); - s = s->get_parent_spatial(); + s = Object::cast_to(s->get_parent()); } meshes[mesh] = transform; @@ -1358,8 +1358,9 @@ Error ResourceImporterScene::import(const String &p_source_file, const String &p scene->set_script(Variant(root_script)); } + float root_scale = 1.0; if (Object::cast_to(scene)) { - float root_scale = p_options["nodes/root_scale"]; + root_scale = p_options["nodes/root_scale"]; Object::cast_to(scene)->scale(Vector3(root_scale, root_scale, root_scale)); } diff --git a/editor/import/resource_importer_shader_file.cpp b/editor/import/resource_importer_shader_file.cpp index 3a6215e035..f085341969 100644 --- a/editor/import/resource_importer_shader_file.cpp +++ b/editor/import/resource_importer_shader_file.cpp @@ -103,7 +103,7 @@ Error ResourceImporterShaderFile::import(const String &p_source_file, const Stri Ref shader_file; shader_file.instance(); String base_path = p_source_file.get_base_dir(); - err = shader_file->parse_versions_from_text(file_txt, _include_function, &base_path); + err = shader_file->parse_versions_from_text(file_txt, "", _include_function, &base_path); if (err != OK) { if (!ShaderFileEditor::singleton->is_visible_in_tree()) { diff --git a/editor/import/resource_importer_texture.cpp b/editor/import/resource_importer_texture.cpp index f8ed9304b6..111eab78b4 100644 --- a/editor/import/resource_importer_texture.cpp +++ b/editor/import/resource_importer_texture.cpp @@ -36,7 +36,7 @@ #include "editor/editor_file_system.h" #include "editor/editor_node.h" -void ResourceImporterTexture::_texture_reimport_roughness(const Ref &p_tex, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_channel) { +void ResourceImporterTexture::_texture_reimport_roughness(const Ref &p_tex, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_channel) { MutexLock lock(singleton->mutex); @@ -51,7 +51,7 @@ void ResourceImporterTexture::_texture_reimport_roughness(const Refmake_flags[path].normal_path_for_roughness = p_normal_path; } -void ResourceImporterTexture::_texture_reimport_3d(const Ref &p_tex) { +void ResourceImporterTexture::_texture_reimport_3d(const Ref &p_tex) { MutexLock lock(singleton->mutex); @@ -64,7 +64,7 @@ void ResourceImporterTexture::_texture_reimport_3d(const Ref &p_t singleton->make_flags[path].flags |= MAKE_3D_FLAG; } -void ResourceImporterTexture::_texture_reimport_normal(const Ref &p_tex) { +void ResourceImporterTexture::_texture_reimport_normal(const Ref &p_tex) { MutexLock lock(singleton->mutex); @@ -157,7 +157,7 @@ String ResourceImporterTexture::get_save_extension() const { String ResourceImporterTexture::get_resource_type() const { - return "StreamTexture"; + return "StreamTexture2D"; } bool ResourceImporterTexture::get_option_visibility(const String &p_option, const Map &p_options) const { @@ -207,8 +207,8 @@ void ResourceImporterTexture::get_import_options(List *r_options, r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/mode", PROPERTY_HINT_ENUM, "Lossless,Lossy,VRAM Compressed,VRAM Uncompressed,Basis Universal", PROPERTY_USAGE_DEFAULT | PROPERTY_USAGE_UPDATE_ALL_IF_MODIFIED), p_preset == PRESET_3D ? 2 : 0)); r_options->push_back(ImportOption(PropertyInfo(Variant::FLOAT, "compress/lossy_quality", PROPERTY_HINT_RANGE, "0,1,0.01"), 0.7)); - r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/hdr_mode", PROPERTY_HINT_ENUM, "Enabled,Force RGBE"), 0)); - r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/bptc_ldr", PROPERTY_HINT_ENUM, "Enabled,RGBA Only"), 0)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/hdr_compression", PROPERTY_HINT_ENUM, "Disabled,Opaque Only,Always"), 1)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/bptc_ldr", PROPERTY_HINT_ENUM, "Disabled,Enabled,RGBA Only"), 0)); r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/normal_map", PROPERTY_HINT_ENUM, "Detect,Enable,Disabled"), 0)); r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/channel_pack", PROPERTY_HINT_ENUM, "sRGB Friendly,Optimized"), 0)); r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/streamed"), false)); @@ -225,12 +225,12 @@ void ResourceImporterTexture::get_import_options(List *r_options, r_options->push_back(ImportOption(PropertyInfo(Variant::FLOAT, "svg/scale", PROPERTY_HINT_RANGE, "0.001,100,0.001"), 1.0)); } -void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Ref &p_image, CompressMode p_compress_mode, Image::UsedChannels p_channels, Image::CompressMode p_compress_format, float p_lossy_quality, bool p_force_rgbe) { +void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Ref &p_image, CompressMode p_compress_mode, Image::UsedChannels p_channels, Image::CompressMode p_compress_format, float p_lossy_quality) { switch (p_compress_mode) { case COMPRESS_LOSSLESS: { - f->store_32(StreamTexture::DATA_FORMAT_LOSSLESS); + f->store_32(StreamTexture2D::DATA_FORMAT_LOSSLESS); f->store_16(p_image->get_width()); f->store_16(p_image->get_height()); f->store_32(p_image->get_mipmap_count()); @@ -249,7 +249,7 @@ void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Refstore_32(StreamTexture::DATA_FORMAT_LOSSY); + f->store_32(StreamTexture2D::DATA_FORMAT_LOSSY); f->store_16(p_image->get_width()); f->store_16(p_image->get_height()); f->store_32(p_image->get_mipmap_count()); @@ -269,13 +269,9 @@ void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Ref image = p_image->duplicate(); - if (p_force_rgbe && image->get_format() >= Image::FORMAT_RF && image->get_format() < Image::FORMAT_RGBE9995) { - image->convert(Image::FORMAT_RGBE9995); - } else { - image->compress_from_channels(p_compress_format, p_channels, p_lossy_quality); - } + image->compress_from_channels(p_compress_format, p_channels, p_lossy_quality); - f->store_32(StreamTexture::DATA_FORMAT_IMAGE); + f->store_32(StreamTexture2D::DATA_FORMAT_IMAGE); f->store_16(image->get_width()); f->store_16(image->get_height()); f->store_32(image->get_mipmap_count()); @@ -288,7 +284,7 @@ void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Refstore_32(StreamTexture::DATA_FORMAT_IMAGE); + f->store_32(StreamTexture2D::DATA_FORMAT_IMAGE); f->store_16(p_image->get_width()); f->store_16(p_image->get_height()); f->store_32(p_image->get_mipmap_count()); @@ -303,7 +299,7 @@ void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Refstore_32(StreamTexture::DATA_FORMAT_BASIS_UNIVERSAL); + f->store_32(StreamTexture2D::DATA_FORMAT_BASIS_UNIVERSAL); f->store_16(p_image->get_width()); f->store_16(p_image->get_height()); f->store_32(p_image->get_mipmap_count()); @@ -322,7 +318,7 @@ void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Ref &p_image, const String &p_to_path, CompressMode p_compress_mode, float p_lossy_quality, Image::CompressMode p_vram_compression, bool p_mipmaps, bool p_streamable, bool p_detect_3d, bool p_detect_roughness, bool p_force_rgbe, bool p_detect_normal, bool p_force_normal, bool p_srgb_friendly, bool p_force_po2_for_compressed, uint32_t p_limit_mipmap, const Ref &p_normal, Image::RoughnessChannel p_roughness_channel) { +void ResourceImporterTexture::_save_stex(const Ref &p_image, const String &p_to_path, CompressMode p_compress_mode, float p_lossy_quality, Image::CompressMode p_vram_compression, bool p_mipmaps, bool p_streamable, bool p_detect_3d, bool p_detect_roughness, bool p_detect_normal, bool p_force_normal, bool p_srgb_friendly, bool p_force_po2_for_compressed, uint32_t p_limit_mipmap, const Ref &p_normal, Image::RoughnessChannel p_roughness_channel) { FileAccess *f = FileAccess::open(p_to_path, FileAccess::WRITE); f->store_8('G'); @@ -331,22 +327,22 @@ void ResourceImporterTexture::_save_stex(const Ref &p_image, const String f->store_8('2'); //godot streamable texture 2D //format version - f->store_32(StreamTexture::FORMAT_VERSION); + f->store_32(StreamTexture2D::FORMAT_VERSION); //texture may be resized later, so original size must be saved first f->store_32(p_image->get_width()); f->store_32(p_image->get_height()); uint32_t flags = 0; if (p_streamable) - flags |= StreamTexture::FORMAT_BIT_STREAM; + flags |= StreamTexture2D::FORMAT_BIT_STREAM; if (p_mipmaps) - flags |= StreamTexture::FORMAT_BIT_HAS_MIPMAPS; //mipmaps bit + flags |= StreamTexture2D::FORMAT_BIT_HAS_MIPMAPS; //mipmaps bit if (p_detect_3d) - flags |= StreamTexture::FORMAT_BIT_DETECT_3D; + flags |= StreamTexture2D::FORMAT_BIT_DETECT_3D; if (p_detect_roughness) - flags |= StreamTexture::FORMAT_BIT_DETECT_ROUGNESS; + flags |= StreamTexture2D::FORMAT_BIT_DETECT_ROUGNESS; if (p_detect_normal) - flags |= StreamTexture::FORMAT_BIT_DETECT_NORMAL; + flags |= StreamTexture2D::FORMAT_BIT_DETECT_NORMAL; f->store_32(flags); f->store_32(p_limit_mipmap); @@ -385,10 +381,6 @@ void ResourceImporterTexture::_save_stex(const Ref &p_image, const String image->generate_mipmap_roughness(p_roughness_channel, p_normal); } - if (p_force_rgbe && image->get_format() >= Image::FORMAT_RF && image->get_format() < Image::FORMAT_RGBE9995) { - image->convert(Image::FORMAT_RGBE9995); - } - Image::CompressSource csource = Image::COMPRESS_SOURCE_GENERIC; if (p_force_normal) { csource = Image::COMPRESS_SOURCE_NORMAL; @@ -398,7 +390,7 @@ void ResourceImporterTexture::_save_stex(const Ref &p_image, const String Image::UsedChannels used_channels = image->detect_used_channels(csource); - save_to_stex_format(f, image, p_compress_mode, used_channels, p_vram_compression, p_lossy_quality, p_force_rgbe); + save_to_stex_format(f, image, p_compress_mode, used_channels, p_vram_compression, p_lossy_quality); memdelete(f); } @@ -418,7 +410,7 @@ Error ResourceImporterTexture::import(const String &p_source_file, const String bool hdr_as_srgb = p_options["process/HDR_as_SRGB"]; int normal = p_options["compress/normal_map"]; float scale = p_options["svg/scale"]; - bool force_rgbe = int(p_options["compress/hdr_mode"]) == 1; + int hdr_compression = p_options["compress/hdr_compression"]; int bptc_ldr = p_options["compress/bptc_ldr"]; int roughness = p_options["roughness/mode"]; String normal_map = p_options["roughness/src_normal"]; @@ -501,30 +493,49 @@ Error ResourceImporterTexture::import(const String &p_source_file, const String bool can_s3tc = ProjectSettings::get_singleton()->get("rendering/vram_compression/import_s3tc"); if (can_bptc) { - Image::UsedChannels channels = image->detect_used_channels(); - if (is_hdr) { - - if (channels == Image::USED_CHANNELS_LA || channels == Image::USED_CHANNELS_RGBA) { - can_bptc = false; - } - } else if (is_ldr) { - - //handle "RGBA Only" setting - if (bptc_ldr == 1 && channels != Image::USED_CHANNELS_LA && channels != Image::USED_CHANNELS_RGBA) { - can_bptc = false; - } - } - + //add to the list anyway formats_imported.push_back("bptc"); } - if (!can_bptc && is_hdr && !force_rgbe) { - //convert to ldr if this can't be stored hdr - image->convert(Image::FORMAT_RGBA8); + bool can_compress_hdr = hdr_compression > 0; + bool has_alpha = image->detect_alpha() != Image::ALPHA_NONE; + + if (is_hdr && can_compress_hdr) { + + if (has_alpha) { + //can compress hdr, but hdr with alpha is not compressible + if (hdr_compression == 2) { + //but user selected to compress hdr anyway, so force an alpha-less format. + if (image->get_format() == Image::FORMAT_RGBAF) { + image->convert(Image::FORMAT_RGBF); + } else if (image->get_format() == Image::FORMAT_RGBAH) { + image->convert(Image::FORMAT_RGBH); + } + } else { + can_compress_hdr = false; + } + } + + if (can_compress_hdr) { + if (!can_bptc) { + //fallback to RGBE99995 + if (image->get_format() != Image::FORMAT_RGBE9995) { + image->convert(Image::FORMAT_RGBE9995); + } + } + } else { + can_bptc = false; + } + } + + if (is_ldr && can_bptc) { + if (bptc_ldr == 0 || (bptc_ldr == 1 && !has_alpha)) { + can_bptc = false; + } } if (can_bptc || can_s3tc) { - _save_stex(image, p_save_path + ".s3tc.stex", compress_mode, lossy, can_bptc ? Image::COMPRESS_BPTC : Image::COMPRESS_S3TC, mipmaps, stream, detect_3d, detect_roughness, force_rgbe, detect_normal, force_normal, srgb_friendly_pack, false, mipmap_limit, normal_image, roughness_channel); + _save_stex(image, p_save_path + ".s3tc.stex", compress_mode, lossy, can_bptc ? Image::COMPRESS_BPTC : Image::COMPRESS_S3TC, mipmaps, stream, detect_3d, detect_roughness, detect_normal, force_normal, srgb_friendly_pack, false, mipmap_limit, normal_image, roughness_channel); r_platform_variants->push_back("s3tc"); formats_imported.push_back("s3tc"); ok_on_pc = true; @@ -532,20 +543,20 @@ Error ResourceImporterTexture::import(const String &p_source_file, const String if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_etc2")) { - _save_stex(image, p_save_path + ".etc2.stex", compress_mode, lossy, Image::COMPRESS_ETC2, mipmaps, stream, detect_3d, detect_roughness, force_rgbe, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); + _save_stex(image, p_save_path + ".etc2.stex", compress_mode, lossy, Image::COMPRESS_ETC2, mipmaps, stream, detect_3d, detect_roughness, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); r_platform_variants->push_back("etc2"); formats_imported.push_back("etc2"); } if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_etc")) { - _save_stex(image, p_save_path + ".etc.stex", compress_mode, lossy, Image::COMPRESS_ETC, mipmaps, stream, detect_3d, detect_roughness, force_rgbe, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); + _save_stex(image, p_save_path + ".etc.stex", compress_mode, lossy, Image::COMPRESS_ETC, mipmaps, stream, detect_3d, detect_roughness, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); r_platform_variants->push_back("etc"); formats_imported.push_back("etc"); } if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_pvrtc")) { - _save_stex(image, p_save_path + ".pvrtc.stex", compress_mode, lossy, Image::COMPRESS_PVRTC4, mipmaps, stream, detect_3d, detect_roughness, force_rgbe, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); + _save_stex(image, p_save_path + ".pvrtc.stex", compress_mode, lossy, Image::COMPRESS_PVRTC4, mipmaps, stream, detect_3d, detect_roughness, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); r_platform_variants->push_back("pvrtc"); formats_imported.push_back("pvrtc"); } @@ -555,7 +566,7 @@ Error ResourceImporterTexture::import(const String &p_source_file, const String } } else { //import normally - _save_stex(image, p_save_path + ".stex", compress_mode, lossy, Image::COMPRESS_S3TC /*this is ignored */, mipmaps, stream, detect_3d, detect_roughness, force_rgbe, detect_normal, force_normal, srgb_friendly_pack, false, mipmap_limit, normal_image, roughness_channel); + _save_stex(image, p_save_path + ".stex", compress_mode, lossy, Image::COMPRESS_S3TC /*this is ignored */, mipmaps, stream, detect_3d, detect_roughness, detect_normal, force_normal, srgb_friendly_pack, false, mipmap_limit, normal_image, roughness_channel); } if (r_metadata) { @@ -635,9 +646,9 @@ ResourceImporterTexture *ResourceImporterTexture::singleton = nullptr; ResourceImporterTexture::ResourceImporterTexture() { singleton = this; - StreamTexture::request_3d_callback = _texture_reimport_3d; - StreamTexture::request_roughness_callback = _texture_reimport_roughness; - StreamTexture::request_normal_callback = _texture_reimport_normal; + StreamTexture2D::request_3d_callback = _texture_reimport_3d; + StreamTexture2D::request_roughness_callback = _texture_reimport_roughness; + StreamTexture2D::request_normal_callback = _texture_reimport_normal; } ResourceImporterTexture::~ResourceImporterTexture() { diff --git a/editor/import/resource_importer_texture.h b/editor/import/resource_importer_texture.h index e1c71ff1b8..da8ce3c0a8 100644 --- a/editor/import/resource_importer_texture.h +++ b/editor/import/resource_importer_texture.h @@ -37,7 +37,7 @@ #include "scene/resources/texture.h" #include "servers/rendering_server.h" -class StreamTexture; +class StreamTexture2D; class ResourceImporterTexture : public ResourceImporter { GDCLASS(ResourceImporterTexture, ResourceImporter); @@ -72,17 +72,17 @@ protected: Map make_flags; - static void _texture_reimport_roughness(const Ref &p_tex, const String &p_normal_path, RenderingServer::TextureDetectRoughnessChannel p_channel); - static void _texture_reimport_3d(const Ref &p_tex); - static void _texture_reimport_normal(const Ref &p_tex); + static void _texture_reimport_roughness(const Ref &p_tex, const String &p_normal_path, RenderingServer::TextureDetectRoughnessChannel p_channel); + static void _texture_reimport_3d(const Ref &p_tex); + static void _texture_reimport_normal(const Ref &p_tex); static ResourceImporterTexture *singleton; static const char *compression_formats[]; - void _save_stex(const Ref &p_image, const String &p_to_path, CompressMode p_compress_mode, float p_lossy_quality, Image::CompressMode p_vram_compression, bool p_mipmaps, bool p_streamable, bool p_detect_3d, bool p_detect_srgb, bool p_force_rgbe, bool p_detect_normal, bool p_force_normal, bool p_srgb_friendly, bool p_force_po2_for_compressed, uint32_t p_limit_mipmap, const Ref &p_normal, Image::RoughnessChannel p_roughness_channel); + void _save_stex(const Ref &p_image, const String &p_to_path, CompressMode p_compress_mode, float p_lossy_quality, Image::CompressMode p_vram_compression, bool p_mipmaps, bool p_streamable, bool p_detect_3d, bool p_detect_srgb, bool p_detect_normal, bool p_force_normal, bool p_srgb_friendly, bool p_force_po2_for_compressed, uint32_t p_limit_mipmap, const Ref &p_normal, Image::RoughnessChannel p_roughness_channel); public: - void save_to_stex_format(FileAccess *f, const Ref &p_image, CompressMode p_compress_mode, Image::UsedChannels p_channels, Image::CompressMode p_compress_format, float p_lossy_quality, bool p_force_rgbe); + static void save_to_stex_format(FileAccess *f, const Ref &p_image, CompressMode p_compress_mode, Image::UsedChannels p_channels, Image::CompressMode p_compress_format, float p_lossy_quality); static ResourceImporterTexture *get_singleton() { return singleton; } virtual String get_importer_name() const; diff --git a/editor/node_3d_editor_gizmos.cpp b/editor/node_3d_editor_gizmos.cpp index 2a399087b2..cb0d9fa02b 100644 --- a/editor/node_3d_editor_gizmos.cpp +++ b/editor/node_3d_editor_gizmos.cpp @@ -41,6 +41,7 @@ #include "scene/3d/gi_probe.h" #include "scene/3d/gpu_particles_3d.h" #include "scene/3d/light_3d.h" +#include "scene/3d/lightmap_probe.h" #include "scene/3d/listener_3d.h" #include "scene/3d/mesh_instance_3d.h" #include "scene/3d/navigation_region_3d.h" @@ -3069,136 +3070,296 @@ void GIProbeGizmoPlugin::redraw(EditorNode3DGizmo *p_gizmo) { } //// -#if 0 -BakedIndirectLightGizmoPlugin::BakedIndirectLightGizmoPlugin() { - Color gizmo_color = EDITOR_DEF("editors/3d_gizmos/gizmo_colors/baked_indirect_light", Color(0.5, 0.6, 1)); - create_material("baked_indirect_light_material", gizmo_color); +BakedLightmapGizmoPlugin::BakedLightmapGizmoPlugin() { + Color gizmo_color = EDITOR_DEF("editors/3d_gizmos/gizmo_colors/lightmap_lines", Color(0.5, 0.6, 1)); gizmo_color.a = 0.1; - create_material("baked_indirect_light_internal_material", gizmo_color); + create_material("lightmap_lines", gizmo_color); - create_icon_material("baked_indirect_light_icon", Node3DEditor::get_singleton()->get_icon("GizmoBakedLightmap", "EditorIcons")); - create_handle_material("handles"); + Ref mat = memnew(StandardMaterial3D); + mat->set_shading_mode(StandardMaterial3D::SHADING_MODE_UNSHADED); + mat->set_cull_mode(StandardMaterial3D::CULL_DISABLED); + mat->set_flag(StandardMaterial3D::FLAG_ALBEDO_FROM_VERTEX_COLOR, true); + mat->set_flag(StandardMaterial3D::FLAG_SRGB_VERTEX_COLOR, false); + + add_material("lightmap_probe_material", mat); + + create_icon_material("baked_indirect_light_icon", Node3DEditor::get_singleton()->get_theme_icon("GizmoBakedLightmap", "EditorIcons")); } -String BakedIndirectLightGizmoPlugin::get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const { - - switch (p_idx) { - case 0: return "Extents X"; - case 1: return "Extents Y"; - case 2: return "Extents Z"; - } +String BakedLightmapGizmoPlugin::get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const { return ""; } -Variant BakedIndirectLightGizmoPlugin::get_handle_value(EditorNode3DGizmo *p_gizmo, int p_idx) const { +Variant BakedLightmapGizmoPlugin::get_handle_value(EditorNode3DGizmo *p_gizmo, int p_idx) const { - BakedLightmap *baker = Object::cast_to(p_gizmo->get_spatial_node()); - return baker->get_extents(); + return Variant(); } -void BakedIndirectLightGizmoPlugin::set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera *p_camera, const Point2 &p_point) { - - BakedLightmap *baker = Object::cast_to(p_gizmo->get_spatial_node()); - - Transform gt = baker->get_global_transform(); - Transform gi = gt.affine_inverse(); - - Vector3 extents = baker->get_extents(); - - Vector3 ray_from = p_camera->project_ray_origin(p_point); - Vector3 ray_dir = p_camera->project_ray_normal(p_point); - - Vector3 sg[2] = { gi.xform(ray_from), gi.xform(ray_from + ray_dir * 16384) }; - - Vector3 axis; - axis[p_idx] = 1.0; - - Vector3 ra, rb; - Geometry::get_closest_points_between_segments(Vector3(), axis * 16384, sg[0], sg[1], ra, rb); - float d = ra[p_idx]; - if (Node3DEditor::get_singleton()->is_snap_enabled()) { - d = Math::stepify(d, Node3DEditor::get_singleton()->get_translate_snap()); - } - - if (d < 0.001) - d = 0.001; - - extents[p_idx] = d; - baker->set_extents(extents); +void BakedLightmapGizmoPlugin::set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera3D *p_camera, const Point2 &p_point) { } -void BakedIndirectLightGizmoPlugin::commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, const Variant &p_restore, bool p_cancel) { - - BakedLightmap *baker = Object::cast_to(p_gizmo->get_spatial_node()); - - Vector3 restore = p_restore; - - if (p_cancel) { - baker->set_extents(restore); - return; - } - - UndoRedo *ur = Node3DEditor::get_singleton()->get_undo_redo(); - ur->create_action(TTR("Change Probe Extents")); - ur->add_do_method(baker, "set_extents", baker->get_extents()); - ur->add_undo_method(baker, "set_extents", restore); - ur->commit_action(); +void BakedLightmapGizmoPlugin::commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, const Variant &p_restore, bool p_cancel) { } -bool BakedIndirectLightGizmoPlugin::has_gizmo(Spatial *p_spatial) { +bool BakedLightmapGizmoPlugin::has_gizmo(Node3D *p_spatial) { return Object::cast_to(p_spatial) != nullptr; } -String BakedIndirectLightGizmoPlugin::get_name() const { +String BakedLightmapGizmoPlugin::get_name() const { return "BakedLightmap"; } -int BakedIndirectLightGizmoPlugin::get_priority() const { +int BakedLightmapGizmoPlugin::get_priority() const { return -1; } -void BakedIndirectLightGizmoPlugin::redraw(EditorNode3DGizmo *p_gizmo) { +void BakedLightmapGizmoPlugin::redraw(EditorNode3DGizmo *p_gizmo) { - BakedLightmap *baker = Object::cast_to(p_gizmo->get_spatial_node()); - - Ref material = get_material("baked_indirect_light_material", p_gizmo); Ref icon = get_material("baked_indirect_light_icon", p_gizmo); - Ref material_internal = get_material("baked_indirect_light_internal_material", p_gizmo); + BakedLightmap *baker = Object::cast_to(p_gizmo->get_spatial_node()); + Ref data = baker->get_light_data(); + + p_gizmo->add_unscaled_billboard(icon, 0.05); + + if (data.is_null()) { + return; + } + + Ref material_lines = get_material("lightmap_lines", p_gizmo); + Ref material_probes = get_material("lightmap_probe_material", p_gizmo); p_gizmo->clear(); Vector lines; - Vector3 extents = baker->get_extents(); + Set lines_found; - AABB aabb = AABB(-extents, extents * 2); - - for (int i = 0; i < 12; i++) { - Vector3 a, b; - aabb.get_edge(i, a, b); - lines.push_back(a); - lines.push_back(b); + Vector points = data->get_capture_points(); + if (points.size() == 0) { + return; + } + Vector sh = data->get_capture_sh(); + if (sh.size() != points.size() * 9) { + return; } - p_gizmo->add_lines(lines, material); + Vector tetrahedrons = data->get_capture_tetrahedra(); - Vector handles; + for (int i = 0; i < tetrahedrons.size(); i += 4) { - for (int i = 0; i < 3; i++) { + for (int j = 0; j < 4; j++) { + for (int k = j + 1; k < 4; k++) { - Vector3 ax; - ax[i] = aabb.position[i] + aabb.size[i]; - handles.push_back(ax); + Vector2i pair; + pair.x = tetrahedrons[i + j]; + pair.y = tetrahedrons[i + k]; + + if (pair.y < pair.x) { + SWAP(pair.x, pair.y); + } + if (lines_found.has(pair)) { + continue; + } + lines_found.insert(pair); + lines.push_back(points[pair.x]); + lines.push_back(points[pair.y]); + } + } } - if (p_gizmo->is_selected()) { - p_gizmo->add_solid_box(material_internal, aabb.get_size()); + p_gizmo->add_lines(lines, material_lines); + + int stack_count = 8; + int sector_count = 16; + + float sector_step = 2 * Math_PI / sector_count; + float stack_step = Math_PI / stack_count; + + Vector vertices; + Vector colors; + Vector indices; + float radius = 0.3; + + for (int p = 0; p < points.size(); p++) { + + int vertex_base = vertices.size(); + Vector3 sh_col[9]; + for (int i = 0; i < 9; i++) { + sh_col[i].x = sh[p * 9 + i].r; + sh_col[i].y = sh[p * 9 + i].g; + sh_col[i].z = sh[p * 9 + i].b; + } + + for (int i = 0; i <= stack_count; ++i) { + float stack_angle = Math_PI / 2 - i * stack_step; // starting from pi/2 to -pi/2 + float xy = radius * Math::cos(stack_angle); // r * cos(u) + float z = radius * Math::sin(stack_angle); // r * sin(u) + + // add (sector_count+1) vertices per stack + // the first and last vertices have same position and normal, but different tex coords + for (int j = 0; j <= sector_count; ++j) { + float sector_angle = j * sector_step; // starting from 0 to 2pi + + // vertex position (x, y, z) + float x = xy * Math::cos(sector_angle); // r * cos(u) * cos(v) + float y = xy * Math::sin(sector_angle); // r * cos(u) * sin(v) + + Vector3 n = Vector3(x, z, y); + vertices.push_back(points[p] + n); + n.normalize(); + + const float c1 = 0.429043; + const float c2 = 0.511664; + const float c3 = 0.743125; + const float c4 = 0.886227; + const float c5 = 0.247708; + Vector3 light = (c1 * sh_col[8] * (n.x * n.x - n.y * n.y) + + c3 * sh_col[6] * n.z * n.z + + c4 * sh_col[0] - + c5 * sh_col[6] + + 2.0 * c1 * sh_col[4] * n.x * n.y + + 2.0 * c1 * sh_col[7] * n.x * n.z + + 2.0 * c1 * sh_col[5] * n.y * n.z + + 2.0 * c2 * sh_col[3] * n.x + + 2.0 * c2 * sh_col[1] * n.y + + 2.0 * c2 * sh_col[2] * n.z); + + colors.push_back(Color(light.x, light.y, light.z, 1)); + } + } + + for (int i = 0; i < stack_count; ++i) { + int k1 = i * (sector_count + 1); // beginning of current stack + int k2 = k1 + sector_count + 1; // beginning of next stack + + for (int j = 0; j < sector_count; ++j, ++k1, ++k2) { + // 2 triangles per sector excluding first and last stacks + // k1 => k2 => k1+1 + if (i != 0) { + indices.push_back(vertex_base + k1); + indices.push_back(vertex_base + k2); + indices.push_back(vertex_base + k1 + 1); + } + + // k1+1 => k2 => k2+1 + if (i != (stack_count - 1)) { + indices.push_back(vertex_base + k1 + 1); + indices.push_back(vertex_base + k2); + indices.push_back(vertex_base + k2 + 1); + } + } + } } - p_gizmo->add_unscaled_billboard(icon, 0.05); - p_gizmo->add_handles(handles, get_material("handles")); + Array array; + array.resize(RS::ARRAY_MAX); + array[RS::ARRAY_VERTEX] = vertices; + array[RS::ARRAY_INDEX] = indices; + array[RS::ARRAY_COLOR] = colors; + + Ref mesh; + mesh.instance(); + mesh->add_surface_from_arrays(Mesh::PRIMITIVE_TRIANGLES, array, Array(), Dictionary(), 0); //no compression + mesh->surface_set_material(0, material_probes); + + p_gizmo->add_mesh(mesh); +} +///////// + +LightmapProbeGizmoPlugin::LightmapProbeGizmoPlugin() { + Color gizmo_color = EDITOR_DEF("editors/3d_gizmos/gizmo_colors/lightprobe_lines", Color(0.5, 0.6, 1)); + + gizmo_color.a = 0.3; + create_material("lightprobe_lines", gizmo_color); +} + +String LightmapProbeGizmoPlugin::get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const { + + return ""; +} +Variant LightmapProbeGizmoPlugin::get_handle_value(EditorNode3DGizmo *p_gizmo, int p_idx) const { + + return Variant(); +} +void LightmapProbeGizmoPlugin::set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera3D *p_camera, const Point2 &p_point) { +} + +void LightmapProbeGizmoPlugin::commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, const Variant &p_restore, bool p_cancel) { +} + +bool LightmapProbeGizmoPlugin::has_gizmo(Node3D *p_spatial) { + return Object::cast_to(p_spatial) != nullptr; +} + +String LightmapProbeGizmoPlugin::get_name() const { + return "LightmapProbe"; +} + +int LightmapProbeGizmoPlugin::get_priority() const { + return -1; +} + +void LightmapProbeGizmoPlugin::redraw(EditorNode3DGizmo *p_gizmo) { + + Ref material_lines = get_material("lightprobe_lines", p_gizmo); + + p_gizmo->clear(); + + Vector lines; + + int stack_count = 8; + int sector_count = 16; + + float sector_step = 2 * Math_PI / sector_count; + float stack_step = Math_PI / stack_count; + + Vector vertices; + float radius = 0.2; + + for (int i = 0; i <= stack_count; ++i) { + float stack_angle = Math_PI / 2 - i * stack_step; // starting from pi/2 to -pi/2 + float xy = radius * Math::cos(stack_angle); // r * cos(u) + float z = radius * Math::sin(stack_angle); // r * sin(u) + + // add (sector_count+1) vertices per stack + // the first and last vertices have same position and normal, but different tex coords + for (int j = 0; j <= sector_count; ++j) { + float sector_angle = j * sector_step; // starting from 0 to 2pi + + // vertex position (x, y, z) + float x = xy * Math::cos(sector_angle); // r * cos(u) * cos(v) + float y = xy * Math::sin(sector_angle); // r * cos(u) * sin(v) + + Vector3 n = Vector3(x, z, y); + vertices.push_back(n); + } + } + + for (int i = 0; i < stack_count; ++i) { + int k1 = i * (sector_count + 1); // beginning of current stack + int k2 = k1 + sector_count + 1; // beginning of next stack + + for (int j = 0; j < sector_count; ++j, ++k1, ++k2) { + // 2 triangles per sector excluding first and last stacks + // k1 => k2 => k1+1 + if (i != 0) { + lines.push_back(vertices[k1]); + lines.push_back(vertices[k2]); + lines.push_back(vertices[k1]); + lines.push_back(vertices[k1 + 1]); + } + + if (i != (stack_count - 1)) { + lines.push_back(vertices[k1 + 1]); + lines.push_back(vertices[k2]); + lines.push_back(vertices[k2]); + lines.push_back(vertices[k2 + 1]); + } + } + } + + p_gizmo->add_lines(lines, material_lines); } -#endif //// CollisionShape3DGizmoPlugin::CollisionShape3DGizmoPlugin() { diff --git a/editor/node_3d_editor_gizmos.h b/editor/node_3d_editor_gizmos.h index 6432feeecb..c25fff528c 100644 --- a/editor/node_3d_editor_gizmos.h +++ b/editor/node_3d_editor_gizmos.h @@ -321,25 +321,42 @@ public: GIProbeGizmoPlugin(); }; -#if 0 -class BakedIndirectLightGizmoPlugin : public EditorNode3DGizmoPlugin { +class BakedLightmapGizmoPlugin : public EditorNode3DGizmoPlugin { - GDCLASS(BakedIndirectLightGizmoPlugin, EditorNode3DGizmoPlugin); + GDCLASS(BakedLightmapGizmoPlugin, EditorNode3DGizmoPlugin); public: - bool has_gizmo(Spatial *p_spatial); + bool has_gizmo(Node3D *p_spatial); String get_name() const; int get_priority() const; void redraw(EditorNode3DGizmo *p_gizmo); String get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const; Variant get_handle_value(EditorNode3DGizmo *p_gizmo, int p_idx) const; - void set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera *p_camera, const Point2 &p_point); + void set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera3D *p_camera, const Point2 &p_point); void commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, const Variant &p_restore, bool p_cancel = false); - BakedIndirectLightGizmoPlugin(); + BakedLightmapGizmoPlugin(); }; -#endif + +class LightmapProbeGizmoPlugin : public EditorNode3DGizmoPlugin { + + GDCLASS(LightmapProbeGizmoPlugin, EditorNode3DGizmoPlugin); + +public: + bool has_gizmo(Node3D *p_spatial); + String get_name() const; + int get_priority() const; + void redraw(EditorNode3DGizmo *p_gizmo); + + String get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const; + Variant get_handle_value(EditorNode3DGizmo *p_gizmo, int p_idx) const; + void set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera3D *p_camera, const Point2 &p_point); + void commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, const Variant &p_restore, bool p_cancel = false); + + LightmapProbeGizmoPlugin(); +}; + class CollisionShape3DGizmoPlugin : public EditorNode3DGizmoPlugin { GDCLASS(CollisionShape3DGizmoPlugin, EditorNode3DGizmoPlugin); diff --git a/editor/plugins/baked_lightmap_editor_plugin.cpp b/editor/plugins/baked_lightmap_editor_plugin.cpp index ba161244d6..f754dd4725 100644 --- a/editor/plugins/baked_lightmap_editor_plugin.cpp +++ b/editor/plugins/baked_lightmap_editor_plugin.cpp @@ -28,23 +28,36 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#if 0 #include "baked_lightmap_editor_plugin.h" -void BakedLightmapEditorPlugin::_bake() { +void BakedLightmapEditorPlugin::_bake_select_file(const String &p_file) { if (lightmap) { BakedLightmap::BakeError err; if (get_tree()->get_edited_scene_root() && get_tree()->get_edited_scene_root() == lightmap) { - err = lightmap->bake(lightmap); + err = lightmap->bake(lightmap, p_file, bake_func_step); } else { - err = lightmap->bake(lightmap->get_parent()); + err = lightmap->bake(lightmap->get_parent(), p_file, bake_func_step); } + bake_func_end(); + switch (err) { - case BakedLightmap::BAKE_ERROR_NO_SAVE_PATH: - EditorNode::get_singleton()->show_warning(TTR("Can't determine a save path for lightmap images.\nSave your scene (for images to be saved in the same dir), or pick a save path from the BakedLightmap properties.")); - break; + case BakedLightmap::BAKE_ERROR_NO_SAVE_PATH: { + String scene_path = lightmap->get_filename(); + if (scene_path == String()) { + scene_path = lightmap->get_owner()->get_filename(); + } + if (scene_path == String()) { + EditorNode::get_singleton()->show_warning(TTR("Can't determine a save path for lightmap images.\nSave your scene and try again.")); + break; + } + scene_path = scene_path.get_basename() + ".lmbake"; + + file_dialog->set_current_path(scene_path); + file_dialog->popup_centered_ratio(); + + } break; case BakedLightmap::BAKE_ERROR_NO_MESHES: EditorNode::get_singleton()->show_warning(TTR("No meshes to bake. Make sure they contain an UV2 channel and that the 'Bake Light' flag is on.")); break; @@ -57,6 +70,11 @@ void BakedLightmapEditorPlugin::_bake() { } } +void BakedLightmapEditorPlugin::_bake() { + + _bake_select_file(""); +} + void BakedLightmapEditorPlugin::edit(Object *p_object) { BakedLightmap *s = Object::cast_to(p_object); @@ -83,23 +101,20 @@ void BakedLightmapEditorPlugin::make_visible(bool p_visible) { EditorProgress *BakedLightmapEditorPlugin::tmp_progress = nullptr; -void BakedLightmapEditorPlugin::bake_func_begin(int p_steps) { +bool BakedLightmapEditorPlugin::bake_func_step(float p_progress, const String &p_description, void *, bool p_refresh) { - ERR_FAIL_COND(tmp_progress != nullptr); - - tmp_progress = memnew(EditorProgress("bake_lightmaps", TTR("Bake Lightmaps"), p_steps, true)); -} - -bool BakedLightmapEditorPlugin::bake_func_step(int p_step, const String &p_description) { - - ERR_FAIL_COND_V(tmp_progress == nullptr, false); - return tmp_progress->step(p_description, p_step, false); + if (!tmp_progress) { + tmp_progress = memnew(EditorProgress("bake_lightmaps", TTR("Bake Lightmaps"), 1000, false)); + ERR_FAIL_COND_V(tmp_progress == nullptr, false); + } + return tmp_progress->step(p_description, p_progress * 1000, p_refresh); } void BakedLightmapEditorPlugin::bake_func_end() { - ERR_FAIL_COND(tmp_progress == nullptr); - memdelete(tmp_progress); - tmp_progress = nullptr; + if (tmp_progress != nullptr) { + memdelete(tmp_progress); + tmp_progress = nullptr; + } } void BakedLightmapEditorPlugin::_bind_methods() { @@ -111,18 +126,20 @@ BakedLightmapEditorPlugin::BakedLightmapEditorPlugin(EditorNode *p_node) { editor = p_node; bake = memnew(ToolButton); - bake->set_icon(editor->get_gui_base()->get_icon("Bake", "EditorIcons")); + bake->set_icon(editor->get_gui_base()->get_theme_icon("Bake", "EditorIcons")); bake->set_text(TTR("Bake Lightmaps")); bake->hide(); - bake->connect("pressed", this, "_bake"); + bake->connect("pressed", Callable(this, "_bake")); add_control_to_container(CONTAINER_SPATIAL_EDITOR_MENU, bake); lightmap = nullptr; - BakedLightmap::bake_begin_function = bake_func_begin; - BakedLightmap::bake_step_function = bake_func_step; - BakedLightmap::bake_end_function = bake_func_end; + file_dialog = memnew(EditorFileDialog); + file_dialog->set_file_mode(EditorFileDialog::FILE_MODE_SAVE_FILE); + file_dialog->add_filter("*.lmbake ; LightMap Bake"); + file_dialog->set_title(TTR("Select lightmap bake file:")); + file_dialog->connect("file_selected", callable_mp(this, &BakedLightmapEditorPlugin::_bake_select_file)); + bake->add_child(file_dialog); } BakedLightmapEditorPlugin::~BakedLightmapEditorPlugin() { } -#endif diff --git a/editor/plugins/baked_lightmap_editor_plugin.h b/editor/plugins/baked_lightmap_editor_plugin.h index 818cdfe8fa..2dbc09fc1d 100644 --- a/editor/plugins/baked_lightmap_editor_plugin.h +++ b/editor/plugins/baked_lightmap_editor_plugin.h @@ -28,7 +28,6 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#if 0 #ifndef BAKED_LIGHTMAP_EDITOR_PLUGIN_H #define BAKED_LIGHTMAP_EDITOR_PLUGIN_H @@ -46,11 +45,12 @@ class BakedLightmapEditorPlugin : public EditorPlugin { ToolButton *bake; EditorNode *editor; + EditorFileDialog *file_dialog; static EditorProgress *tmp_progress; - static void bake_func_begin(int p_steps); - static bool bake_func_step(int p_step, const String &p_description); + static bool bake_func_step(float p_progress, const String &p_description, void *, bool p_refresh); static void bake_func_end(); + void _bake_select_file(const String &p_file); void _bake(); protected: @@ -67,5 +67,4 @@ public: ~BakedLightmapEditorPlugin(); }; -#endif // BAKED_LIGHTMAP_EDITOR_PLUGIN_H #endif diff --git a/editor/plugins/canvas_item_editor_plugin.cpp b/editor/plugins/canvas_item_editor_plugin.cpp index 65c0763e63..b5fcf82d76 100644 --- a/editor/plugins/canvas_item_editor_plugin.cpp +++ b/editor/plugins/canvas_item_editor_plugin.cpp @@ -6161,7 +6161,7 @@ bool CanvasItemEditorViewport::can_drop_data(const Point2 &p_point, const Varian type == "ViewportTexture" || type == "CurveTexture" || type == "GradientTexture" || - type == "StreamTexture" || + type == "StreamTexture2D" || type == "AtlasTexture" || type == "LargeTexture") { Ref texture = Ref(Object::cast_to(*res)); diff --git a/editor/plugins/node_3d_editor_plugin.cpp b/editor/plugins/node_3d_editor_plugin.cpp index 1bf5999906..69f8efa86e 100644 --- a/editor/plugins/node_3d_editor_plugin.cpp +++ b/editor/plugins/node_3d_editor_plugin.cpp @@ -6008,7 +6008,8 @@ void Node3DEditor::_register_all_gizmos() { add_gizmo_plugin(Ref(memnew(ReflectionProbeGizmoPlugin))); add_gizmo_plugin(Ref(memnew(DecalGizmoPlugin))); add_gizmo_plugin(Ref(memnew(GIProbeGizmoPlugin))); - // add_gizmo_plugin(Ref(memnew(BakedIndirectLightGizmoPlugin))); + add_gizmo_plugin(Ref(memnew(BakedLightmapGizmoPlugin))); + add_gizmo_plugin(Ref(memnew(LightmapProbeGizmoPlugin))); add_gizmo_plugin(Ref(memnew(CollisionShape3DGizmoPlugin))); add_gizmo_plugin(Ref(memnew(CollisionPolygon3DGizmoPlugin))); add_gizmo_plugin(Ref(memnew(NavigationRegion3DGizmoPlugin))); diff --git a/editor/plugins/texture_editor_plugin.cpp b/editor/plugins/texture_editor_plugin.cpp index c1184c1c89..7a3e571f16 100644 --- a/editor/plugins/texture_editor_plugin.cpp +++ b/editor/plugins/texture_editor_plugin.cpp @@ -84,8 +84,8 @@ void TextureEditor::_notification(int p_what) { String format; if (Object::cast_to(*texture)) { format = Image::get_format_name(Object::cast_to(*texture)->get_format()); - } else if (Object::cast_to(*texture)) { - format = Image::get_format_name(Object::cast_to(*texture)->get_format()); + } else if (Object::cast_to(*texture)) { + format = Image::get_format_name(Object::cast_to(*texture)->get_format()); } else { format = texture->get_class(); } @@ -144,7 +144,7 @@ TextureEditor::~TextureEditor() { // bool EditorInspectorPluginTexture::can_handle(Object *p_object) { - return Object::cast_to(p_object) != nullptr || Object::cast_to(p_object) != nullptr || Object::cast_to(p_object) != nullptr || Object::cast_to(p_object) != nullptr || Object::cast_to(p_object) != nullptr; + return Object::cast_to(p_object) != nullptr || Object::cast_to(p_object) != nullptr || Object::cast_to(p_object) != nullptr || Object::cast_to(p_object) != nullptr || Object::cast_to(p_object) != nullptr; } void EditorInspectorPluginTexture::parse_begin(Object *p_object) { diff --git a/editor/plugins/texture_layered_editor_plugin.cpp b/editor/plugins/texture_layered_editor_plugin.cpp new file mode 100644 index 0000000000..6d716951b3 --- /dev/null +++ b/editor/plugins/texture_layered_editor_plugin.cpp @@ -0,0 +1,286 @@ +/*************************************************************************/ +/* texture_editor_plugin.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "texture_layered_editor_plugin.h" + +#include "core/io/resource_loader.h" +#include "core/project_settings.h" +#include "editor/editor_settings.h" + +void TextureLayeredEditor::_gui_input(Ref p_event) { + Ref mm = p_event; + if (mm.is_valid() && mm->get_button_mask() & BUTTON_MASK_LEFT) { + y_rot += -mm->get_relative().x * 0.01; + x_rot += mm->get_relative().y * 0.01; + _update_material(); + } +} + +void TextureLayeredEditor::_texture_rect_draw() { + texture_rect->draw_rect(Rect2(Point2(), texture_rect->get_size()), Color(1, 1, 1, 1)); +} + +void TextureLayeredEditor::_notification(int p_what) { + + if (p_what == NOTIFICATION_READY) { + + //get_scene()->connect("node_removed",this,"_node_removed"); + } + if (p_what == NOTIFICATION_RESIZED) { + _texture_rect_update_area(); + } + + if (p_what == NOTIFICATION_DRAW) { + + Ref checkerboard = get_theme_icon("Checkerboard", "EditorIcons"); + Size2 size = get_size(); + + draw_texture_rect(checkerboard, Rect2(Point2(), size), true); + } +} + +void TextureLayeredEditor::_changed_callback(Object *p_changed, const char *p_prop) { + + if (!is_visible()) + return; + update(); +} + +void TextureLayeredEditor::_update_material() { + + materials[0]->set_shader_param("layer", layer->get_value()); + materials[2]->set_shader_param("layer", layer->get_value()); + materials[texture->get_layered_type()]->set_shader_param("tex", texture->get_rid()); + + Vector3 v(1, 1, 1); + v.normalize(); + + Basis b; + b.rotate(Vector3(1, 0, 0), x_rot); + b.rotate(Vector3(0, 1, 0), y_rot); + + materials[1]->set_shader_param("normal", v); + materials[1]->set_shader_param("rot", b); + materials[2]->set_shader_param("normal", v); + materials[2]->set_shader_param("rot", b); + + String format = Image::get_format_name(texture->get_format()); + + String text; + if (texture->get_layered_type() == TextureLayered::LAYERED_TYPE_2D_ARRAY) { + text = itos(texture->get_width()) + "x" + itos(texture->get_height()) + " (x " + itos(texture->get_layers()) + ")" + format; + } else if (texture->get_layered_type() == TextureLayered::LAYERED_TYPE_CUBEMAP) { + text = itos(texture->get_width()) + "x" + itos(texture->get_height()) + " " + format; + } else if (texture->get_layered_type() == TextureLayered::LAYERED_TYPE_CUBEMAP_ARRAY) { + text = itos(texture->get_width()) + "x" + itos(texture->get_height()) + " (x " + itos(texture->get_layers() / 6) + ")" + format; + } + + info->set_text(text); +} + +void TextureLayeredEditor::_make_shaders() { + String shader_2d_array = "" + "shader_type canvas_item;\n" + "uniform sampler2DArray tex;\n" + "uniform float layer;\n" + "void fragment() {\n" + " COLOR = textureLod(tex,vec3(UV,layer),0.0);\n" + "}"; + + shaders[0].instance(); + shaders[0]->set_code(shader_2d_array); + + String shader_cube = "" + "shader_type canvas_item;\n" + "uniform samplerCube tex;\n" + "uniform vec3 normal;\n" + "uniform mat3 rot;\n" + "void fragment() {\n" + " vec3 n = rot * normalize(vec3(normal.xy*(UV * 2.0 - 1.0),normal.z));\n" + " COLOR = textureLod(tex,n,0.0);\n" + "}"; + + shaders[1].instance(); + shaders[1]->set_code(shader_cube); + + String shader_cube_array = "" + "shader_type canvas_item;\n" + "uniform samplerCubeArray tex;\n" + "uniform vec3 normal;\n" + "uniform mat3 rot;\n" + "uniform float layer;\n" + "void fragment() {\n" + " vec3 n = rot * normalize(vec3(normal.xy*(UV * 2.0 - 1.0),normal.z));\n" + " COLOR = textureLod(tex,vec4(n,layer),0.0);\n" + "}"; + + shaders[2].instance(); + shaders[2]->set_code(shader_cube_array); + + for (int i = 0; i < 3; i++) { + materials[i].instance(); + materials[i]->set_shader(shaders[i]); + } +} + +void TextureLayeredEditor::_texture_rect_update_area() { + + Size2 size = get_size(); + int tex_width = texture->get_width() * size.height / texture->get_height(); + int tex_height = size.height; + + if (tex_width > size.width) { + tex_width = size.width; + tex_height = texture->get_height() * tex_width / texture->get_width(); + } + + // Prevent the texture from being unpreviewable after the rescale, so that we can still see something + if (tex_height <= 0) + tex_height = 1; + if (tex_width <= 0) + tex_width = 1; + + int ofs_x = (size.width - tex_width) / 2; + int ofs_y = (size.height - tex_height) / 2; + + texture_rect->set_position(Vector2(ofs_x, ofs_y)); + texture_rect->set_size(Vector2(tex_width, tex_height)); +} + +void TextureLayeredEditor::edit(Ref p_texture) { + + if (!texture.is_null()) + texture->remove_change_receptor(this); + + texture = p_texture; + + if (!texture.is_null()) { + + if (shaders[0].is_null()) { + _make_shaders(); + } + + texture->add_change_receptor(this); + update(); + texture_rect->set_material(materials[texture->get_layered_type()]); + setting = true; + if (texture->get_layered_type() == TextureLayered::LAYERED_TYPE_2D_ARRAY) { + layer->set_max(texture->get_layers() - 1); + layer->set_value(0); + layer->show(); + } else if (texture->get_layered_type() == TextureLayered::LAYERED_TYPE_CUBEMAP_ARRAY) { + layer->set_max(texture->get_layers() / 6 - 1); + layer->set_value(0); + layer->show(); + } else { + layer->hide(); + } + x_rot = 0; + y_rot = 0; + _update_material(); + setting = false; + _texture_rect_update_area(); + } else { + hide(); + } +} + +void TextureLayeredEditor::_bind_methods() { + + ClassDB::bind_method(D_METHOD("_gui_input"), &TextureLayeredEditor::_gui_input); + ClassDB::bind_method(D_METHOD("_layer_changed"), &TextureLayeredEditor::_layer_changed); +} + +TextureLayeredEditor::TextureLayeredEditor() { + + set_texture_repeat(TextureRepeat::TEXTURE_REPEAT_ENABLED); + set_custom_minimum_size(Size2(1, 150)); + texture_rect = memnew(Control); + texture_rect->connect("draw", callable_mp(this, &TextureLayeredEditor::_texture_rect_draw)); + texture_rect->set_mouse_filter(MOUSE_FILTER_IGNORE); + add_child(texture_rect); + + layer = memnew(SpinBox); + layer->set_step(1); + layer->set_max(100); + add_child(layer); + layer->set_anchor(MARGIN_RIGHT, 1); + layer->set_anchor(MARGIN_LEFT, 1); + layer->set_h_grow_direction(GROW_DIRECTION_BEGIN); + layer->set_modulate(Color(1, 1, 1, 0.8)); + info = memnew(Label); + add_child(info); + info->set_anchor(MARGIN_RIGHT, 1); + info->set_anchor(MARGIN_LEFT, 1); + info->set_anchor(MARGIN_BOTTOM, 1); + info->set_anchor(MARGIN_TOP, 1); + info->set_h_grow_direction(GROW_DIRECTION_BEGIN); + info->set_v_grow_direction(GROW_DIRECTION_BEGIN); + info->add_theme_color_override("font_color", Color(1, 1, 1, 1)); + info->add_theme_color_override("font_color_shadow", Color(0, 0, 0, 0.5)); + info->add_theme_color_override("font_color_shadow", Color(0, 0, 0, 0.5)); + info->add_theme_constant_override("shadow_as_outline", 1); + info->add_theme_constant_override("shadow_offset_x", 2); + info->add_theme_constant_override("shadow_offset_y", 2); + + setting = false; + layer->connect("value_changed", Callable(this, "_layer_changed")); +} + +TextureLayeredEditor::~TextureLayeredEditor() { + if (!texture.is_null()) { + texture->remove_change_receptor(this); + } +} +// +bool EditorInspectorPluginLayeredTexture::can_handle(Object *p_object) { + + return Object::cast_to(p_object) != nullptr; +} + +void EditorInspectorPluginLayeredTexture::parse_begin(Object *p_object) { + + TextureLayered *texture = Object::cast_to(p_object); + if (!texture) { + return; + } + Ref m(texture); + + TextureLayeredEditor *editor = memnew(TextureLayeredEditor); + editor->edit(m); + add_custom_control(editor); +} + +TextureLayeredEditorPlugin::TextureLayeredEditorPlugin(EditorNode *p_node) { + + Ref plugin; + plugin.instance(); + add_inspector_plugin(plugin); +} diff --git a/editor/plugins/texture_layered_editor_plugin.h b/editor/plugins/texture_layered_editor_plugin.h new file mode 100644 index 0000000000..e8503e845e --- /dev/null +++ b/editor/plugins/texture_layered_editor_plugin.h @@ -0,0 +1,95 @@ +/*************************************************************************/ +/* texture_editor_plugin.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#ifndef TEXTURE_LAYERED_EDITOR_PLUGIN_H +#define TEXTURE_LAYERED_EDITOR_PLUGIN_H + +#include "editor/editor_node.h" +#include "editor/editor_plugin.h" +#include "scene/resources/shader.h" +#include "scene/resources/texture.h" +class TextureLayeredEditor : public Control { + + GDCLASS(TextureLayeredEditor, Control); + + SpinBox *layer; + Label *info; + Ref texture; + + Ref shaders[3]; + Ref materials[3]; + + float x_rot = 0; + float y_rot = 0; + Control *texture_rect; + + void _make_shaders(); + + void _update_material(); + bool setting; + void _layer_changed(double) { + if (!setting) + _update_material(); + } + + void _texture_rect_update_area(); + void _texture_rect_draw(); + +protected: + void _notification(int p_what); + void _gui_input(Ref p_event); + void _changed_callback(Object *p_changed, const char *p_prop); + static void _bind_methods(); + +public: + void edit(Ref p_texture); + TextureLayeredEditor(); + ~TextureLayeredEditor(); +}; + +class EditorInspectorPluginLayeredTexture : public EditorInspectorPlugin { + GDCLASS(EditorInspectorPluginLayeredTexture, EditorInspectorPlugin); + +public: + virtual bool can_handle(Object *p_object); + virtual void parse_begin(Object *p_object); +}; + +class TextureLayeredEditorPlugin : public EditorPlugin { + + GDCLASS(TextureLayeredEditorPlugin, EditorPlugin); + +public: + virtual String get_name() const { return "TextureLayered"; } + + TextureLayeredEditorPlugin(EditorNode *p_node); +}; + +#endif // TEXTURE_EDITOR_PLUGIN_H diff --git a/gles_builders.py b/gles_builders.py index 6ff2f4248b..85a8d7aa15 100644 --- a/gles_builders.py +++ b/gles_builders.py @@ -740,5 +740,69 @@ def build_rd_headers(target, source, env): build_rd_header(str(x)) +class RAWHeaderStruct: + def __init__(self): + self.code = "" + + +def include_file_in_raw_header(filename, header_data, depth): + fs = open(filename, "r") + line = fs.readline() + text = "" + + while line: + + while line.find("#include ") != -1: + includeline = line.replace("#include ", "").strip()[1:-1] + + import os.path + + included_file = os.path.relpath(os.path.dirname(filename) + "/" + includeline) + include_file_in_raw_header(included_file, header_data, depth + 1) + + line = fs.readline() + + header_data.code += line + line = fs.readline() + + fs.close() + + +def build_raw_header(filename): + header_data = RAWHeaderStruct() + include_file_in_raw_header(filename, header_data, 0) + + out_file = filename + ".gen.h" + fd = open(out_file, "w") + + enum_constants = [] + + fd.write("/* WARNING, THIS FILE WAS GENERATED, DO NOT EDIT */\n") + + out_file_base = out_file.replace(".glsl.gen.h", "_shader_glsl") + out_file_base = out_file_base[out_file_base.rfind("/") + 1 :] + out_file_base = out_file_base[out_file_base.rfind("\\") + 1 :] + out_file_ifdef = out_file_base.replace(".", "_").upper() + fd.write("#ifndef " + out_file_ifdef + "_RAW_H\n") + fd.write("#define " + out_file_ifdef + "_RAW_H\n") + fd.write("\n") + fd.write("static const char " + out_file_base + "[] = {\n") + for c in header_data.code: + fd.write(str(ord(c)) + ",") + fd.write("\t\t0};\n\n") + fd.write("#endif\n") + fd.close() + + +def build_rd_headers(target, source, env): + for x in source: + build_rd_header(str(x)) + + +def build_raw_headers(target, source, env): + for x in source: + build_raw_header(str(x)) + + if __name__ == "__main__": subprocess_main(globals()) diff --git a/main/main.cpp b/main/main.cpp index b6afd9160c..65c6fdd397 100644 --- a/main/main.cpp +++ b/main/main.cpp @@ -2063,9 +2063,11 @@ bool Main::start() { } if (project_manager || editor) { - // Hide console window if requested (Windows-only). - bool hide_console = EditorSettings::get_singleton()->get_setting("interface/editor/hide_console_window"); - DisplayServer::get_singleton()->console_set_visible(!hide_console); + if (DisplayServer::get_singleton()->has_feature(DisplayServer::FEATURE_CONSOLE_WINDOW)) { + // Hide console window if requested (Windows-only). + bool hide_console = EditorSettings::get_singleton()->get_setting("interface/editor/hide_console_window"); + DisplayServer::get_singleton()->console_set_visible(!hide_console); + } // Load SSL Certificates from Editor Settings (or builtin) Crypto::load_default_certificates(EditorSettings::get_singleton()->get_setting("network/ssl/editor_ssl_certificates").operator String()); diff --git a/main/tests/test_math.cpp b/main/tests/test_math.cpp index b6ef573b36..fbd1aa275a 100644 --- a/main/tests/test_math.cpp +++ b/main/tests/test_math.cpp @@ -32,8 +32,10 @@ #include "core/math/basis.h" #include "core/math/camera_matrix.h" +#include "core/math/delaunay_3d.h" #include "core/math/math_funcs.h" #include "core/math/transform.h" +#include "core/method_ptrcall.h" #include "core/os/file_access.h" #include "core/os/keyboard.h" #include "core/os/os.h" @@ -45,8 +47,6 @@ #include "scene/resources/texture.h" #include "servers/rendering/shader_language.h" -#include "core/method_ptrcall.h" - namespace TestMath { class GetClassAndNamespace { @@ -414,6 +414,55 @@ uint32_t ihash3(uint32_t a) { MainLoop *test() { + { + Vector points; + points.push_back(Vector3(0, 0, 0)); + points.push_back(Vector3(0, 0, 1)); + points.push_back(Vector3(0, 1, 0)); + points.push_back(Vector3(0, 1, 1)); + points.push_back(Vector3(1, 1, 0)); + points.push_back(Vector3(1, 0, 0)); + points.push_back(Vector3(1, 0, 1)); + points.push_back(Vector3(1, 1, 1)); + + for (int i = 0; i < 800; i++) { + points.push_back(Vector3(Math::randf() * 2.0 - 1.0, Math::randf() * 2.0 - 1.0, Math::randf() * 2.0 - 1.0) * Vector3(25, 30, 33)); + } + + Vector os = Delaunay3D::tetrahedralize(points); + print_line("simplices in the end: " + itos(os.size())); + for (int i = 0; i < os.size(); i++) { + print_line("Simplex " + itos(i) + ": "); + print_line(points[os[i].points[0]]); + print_line(points[os[i].points[1]]); + print_line(points[os[i].points[2]]); + print_line(points[os[i].points[3]]); + } + + { + FileAccessRef f = FileAccess::open("res://bsp.obj", FileAccess::WRITE); + for (int i = 0; i < os.size(); i++) { + f->store_line("o Simplex" + itos(i)); + for (int j = 0; j < 4; j++) { + f->store_line(vformat("v %f %f %f", points[os[i].points[j]].x, points[os[i].points[j]].y, points[os[i].points[j]].z)); + } + static const int face_order[4][3] = { + { 1, 2, 3 }, + { 1, 3, 4 }, + { 1, 2, 4 }, + { 2, 3, 4 } + }; + + for (int j = 0; j < 4; j++) { + f->store_line(vformat("f %d %d %d", 4 * i + face_order[j][0], 4 * i + face_order[j][1], 4 * i + face_order[j][2])); + } + } + f->close(); + } + + return nullptr; + } + { float r = 1; float g = 0.5; diff --git a/modules/denoise/SCsub b/modules/denoise/SCsub new file mode 100644 index 0000000000..8cf91b7dbd --- /dev/null +++ b/modules/denoise/SCsub @@ -0,0 +1,119 @@ +#!/usr/bin/env python + +import resource_to_cpp +from platform_methods import run_in_subprocess + +Import("env") +Import("env_modules") + +env_oidn = env_modules.Clone() + +# Thirdparty source files +thirdparty_dir = "#thirdparty/oidn/" +thirdparty_sources = [ + "core/api.cpp", + "core/device.cpp", + "core/filter.cpp", + "core/network.cpp", + "core/autoencoder.cpp", + "core/transfer_function.cpp", + "weights/rtlightmap_hdr.cpp", + "mkl-dnn/src/common/batch_normalization.cpp", + "mkl-dnn/src/common/concat.cpp", + "mkl-dnn/src/common/convolution.cpp", + "mkl-dnn/src/common/convolution_pd.cpp", + "mkl-dnn/src/common/deconvolution.cpp", + "mkl-dnn/src/common/eltwise.cpp", + "mkl-dnn/src/common/engine.cpp", + "mkl-dnn/src/common/inner_product.cpp", + "mkl-dnn/src/common/inner_product_pd.cpp", + "mkl-dnn/src/common/lrn.cpp", + "mkl-dnn/src/common/memory.cpp", + "mkl-dnn/src/common/memory_desc_wrapper.cpp", + "mkl-dnn/src/common/mkldnn_debug.cpp", + "mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp", + "mkl-dnn/src/common/pooling.cpp", + "mkl-dnn/src/common/primitive.cpp", + "mkl-dnn/src/common/primitive_attr.cpp", + "mkl-dnn/src/common/primitive_desc.cpp", + "mkl-dnn/src/common/primitive_exec_types.cpp", + "mkl-dnn/src/common/primitive_iterator.cpp", + "mkl-dnn/src/common/query.cpp", + "mkl-dnn/src/common/reorder.cpp", + "mkl-dnn/src/common/rnn.cpp", + "mkl-dnn/src/common/scratchpad.cpp", + "mkl-dnn/src/common/shuffle.cpp", + "mkl-dnn/src/common/softmax.cpp", + "mkl-dnn/src/common/stream.cpp", + "mkl-dnn/src/common/sum.cpp", + "mkl-dnn/src/common/utils.cpp", + "mkl-dnn/src/common/verbose.cpp", + "mkl-dnn/src/cpu/cpu_barrier.cpp", + "mkl-dnn/src/cpu/cpu_concat.cpp", + "mkl-dnn/src/cpu/cpu_engine.cpp", + "mkl-dnn/src/cpu/cpu_memory.cpp", + "mkl-dnn/src/cpu/cpu_reducer.cpp", + "mkl-dnn/src/cpu/cpu_reorder.cpp", + "mkl-dnn/src/cpu/cpu_sum.cpp", + "mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_avx2_convolution.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp", + "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp", + "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp", + "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp", + "mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_sse42_convolution.cpp", + "mkl-dnn/src/cpu/jit_transpose_src_utils.cpp", + "mkl-dnn/src/cpu/jit_uni_eltwise.cpp", + "mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_uni_pooling.cpp", + "mkl-dnn/src/cpu/jit_uni_reorder.cpp", + "mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp", + "mkl-dnn/src/cpu/jit_utils/jit_utils.cpp", + "mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c", + "common/platform.cpp", + "common/thread.cpp", + "common/tensor.cpp", +] +thirdparty_sources = [thirdparty_dir + file for file in thirdparty_sources] + +thirdparty_include_dirs = [ + "", + "include", + "mkl-dnn/include", + "mkl-dnn/src", + "mkl-dnn/src/common", + "mkl-dnn/src/cpu/xbyak", + "mkl-dnn/src/cpu", +] +thirdparty_include_dirs = [thirdparty_dir + file for file in thirdparty_include_dirs] + + +env_oidn.Prepend(CPPPATH=thirdparty_include_dirs) +env_oidn.Append( + CPPDEFINES=[ + "MKLDNN_THR=MKLDNN_THR_SEQ", + "OIDN_STATIC_LIB", + "__STDC_CONSTANT_MACROS", + "__STDC_LIMIT_MACROS", + "DISABLE_VERBOSE", + "MKLDNN_ENABLE_CONCURRENT_EXEC", + "NDEBUG", + ] +) + +env_thirdparty = env_oidn.Clone() +env_thirdparty.disable_warnings() +env_thirdparty.add_source_files(env.modules_sources, thirdparty_sources) + +weights_in_path = thirdparty_dir + "weights/rtlightmap_hdr.tza" +weights_out_path = thirdparty_dir + "weights/rtlightmap_hdr.cpp" + +env_thirdparty.Depends(weights_out_path, weights_in_path) +env_thirdparty.CommandNoCache(weights_out_path, weights_in_path, resource_to_cpp.tza_to_cpp) + +env_oidn.add_source_files(env.modules_sources, "denoise_wrapper.cpp") +env_modules.add_source_files(env.modules_sources, ["register_types.cpp", "lightmap_denoiser.cpp"]) diff --git a/modules/denoise/config.py b/modules/denoise/config.py new file mode 100644 index 0000000000..53b8f2f2e3 --- /dev/null +++ b/modules/denoise/config.py @@ -0,0 +1,6 @@ +def can_build(env, platform): + return env["tools"] + + +def configure(env): + pass diff --git a/modules/denoise/denoise_wrapper.cpp b/modules/denoise/denoise_wrapper.cpp new file mode 100644 index 0000000000..feeeaef507 --- /dev/null +++ b/modules/denoise/denoise_wrapper.cpp @@ -0,0 +1,34 @@ +#include "denoise_wrapper.h" +#include "thirdparty/oidn/include/OpenImageDenoise/oidn.h" +#include + +void *oidn_denoiser_init() { + OIDNDeviceImpl *device = oidnNewDevice(OIDN_DEVICE_TYPE_CPU); + oidnCommitDevice(device); + return device; +} + +bool oidn_denoise(void *deviceptr, float *p_floats, int p_width, int p_height) { + OIDNDeviceImpl *device = (OIDNDeviceImpl *)deviceptr; + OIDNFilter filter = oidnNewFilter(device, "RTLightmap"); + oidnSetSharedFilterImage(filter, "color", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0); + oidnSetSharedFilterImage(filter, "output", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0); + oidnSetFilter1b(filter, "hdr", true); + //oidnSetFilter1f(filter, "hdrScale", 1.0f); + oidnCommitFilter(filter); + oidnExecuteFilter(filter); + + const char *msg; + bool success = true; + if (oidnGetDeviceError(device, &msg) != OIDN_ERROR_NONE) { + printf("LightmapDenoiser: %s\n", msg); + success = false; + } + + oidnReleaseFilter(filter); + return success; +} + +void oidn_denoiser_finish(void *device) { + oidnReleaseDevice((OIDNDeviceImpl *)device); +} diff --git a/modules/denoise/denoise_wrapper.h b/modules/denoise/denoise_wrapper.h new file mode 100644 index 0000000000..3aef326e22 --- /dev/null +++ b/modules/denoise/denoise_wrapper.h @@ -0,0 +1,8 @@ +#ifndef DENOISE_WRAPPER_H +#define DENOISE_WRAPPER_H + +void *oidn_denoiser_init(); +bool oidn_denoise(void *device, float *p_floats, int p_width, int p_height); +void oidn_denoiser_finish(void *device); + +#endif // DENOISE_WRAPPER_H diff --git a/modules/denoise/lightmap_denoiser.cpp b/modules/denoise/lightmap_denoiser.cpp new file mode 100644 index 0000000000..c821b22d85 --- /dev/null +++ b/modules/denoise/lightmap_denoiser.cpp @@ -0,0 +1,63 @@ +/*************************************************************************/ +/* lightmap_denoiser.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "lightmap_denoiser.h" +#include "denoise_wrapper.h" + +LightmapDenoiser *LightmapDenoiserOIDN::create_oidn_denoiser() { + return memnew(LightmapDenoiserOIDN); +} + +void LightmapDenoiserOIDN::make_default_denoiser() { + create_function = create_oidn_denoiser; +} + +Ref LightmapDenoiserOIDN::denoise_image(const Ref &p_image) { + + Ref img = p_image->duplicate(); + + img->convert(Image::FORMAT_RGBF); + + Vector data = img->get_data(); + if (!oidn_denoise(device, (float *)data.ptrw(), img->get_width(), img->get_height())) { + return p_image; + } + + img->create(img->get_width(), img->get_height(), false, img->get_format(), data); + return img; +} + +LightmapDenoiserOIDN::LightmapDenoiserOIDN() { + device = oidn_denoiser_init(); +} + +LightmapDenoiserOIDN::~LightmapDenoiserOIDN() { + oidn_denoiser_finish(device); +} diff --git a/modules/denoise/lightmap_denoiser.h b/modules/denoise/lightmap_denoiser.h new file mode 100644 index 0000000000..ac0cc8b9db --- /dev/null +++ b/modules/denoise/lightmap_denoiser.h @@ -0,0 +1,57 @@ +/*************************************************************************/ +/* lightmap_denoiser.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#ifndef LIGHTMAP_DENOISER_H +#define LIGHTMAP_DENOISER_H + +#include "core/object.h" +#include "scene/3d/lightmapper.h" + +struct OIDNDeviceImpl; + +class LightmapDenoiserOIDN : public LightmapDenoiser { + + GDCLASS(LightmapDenoiserOIDN, LightmapDenoiser); + +protected: + void *device = nullptr; + +public: + static LightmapDenoiser *create_oidn_denoiser(); + + Ref denoise_image(const Ref &p_image); + + static void make_default_denoiser(); + + LightmapDenoiserOIDN(); + ~LightmapDenoiserOIDN(); +}; + +#endif // LIGHTMAP_DENOISER_H diff --git a/modules/denoise/register_types.cpp b/modules/denoise/register_types.cpp new file mode 100644 index 0000000000..b6b92701c8 --- /dev/null +++ b/modules/denoise/register_types.cpp @@ -0,0 +1,41 @@ +/*************************************************************************/ +/* register_types.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "register_types.h" +#include "core/engine.h" +#include "lightmap_denoiser.h" + +void register_denoise_types() { + + LightmapDenoiserOIDN::make_default_denoiser(); +} + +void unregister_denoise_types() { +} diff --git a/modules/denoise/register_types.h b/modules/denoise/register_types.h new file mode 100644 index 0000000000..2ffc36ee2c --- /dev/null +++ b/modules/denoise/register_types.h @@ -0,0 +1,32 @@ +/*************************************************************************/ +/* register_types.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +void register_denoise_types(); +void unregister_denoise_types(); diff --git a/modules/denoise/resource_to_cpp.py b/modules/denoise/resource_to_cpp.py new file mode 100644 index 0000000000..4c0b67f701 --- /dev/null +++ b/modules/denoise/resource_to_cpp.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python + +## ======================================================================== ## +## Copyright 2009-2019 Intel Corporation ## +## ## +## Licensed under the Apache License, Version 2.0 (the "License"); ## +## you may not use this file except in compliance with the License. ## +## You may obtain a copy of the License at ## +## ## +## http://www.apache.org/licenses/LICENSE-2.0 ## +## ## +## Unless required by applicable law or agreed to in writing, software ## +## distributed under the License is distributed on an "AS IS" BASIS, ## +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ## +## See the License for the specific language governing permissions and ## +## limitations under the License. ## +## ======================================================================== ## + +import os +import sys +import argparse +from array import array + +# Generates a C++ file from the specified binary resource file +def generate(in_path, out_path): + + namespace = "oidn::weights" + scopes = namespace.split("::") + + file_name = os.path.basename(in_path) + var_name = os.path.splitext(file_name)[0] + + with open(in_path, "rb") as in_file, open(out_path, "w") as out_file: + # Header + out_file.write("// Generated from: %s\n" % file_name) + out_file.write("#include \n\n") + + # Open the namespaces + for s in scopes: + out_file.write("namespace %s {\n" % s) + if scopes: + out_file.write("\n") + + # Read the file + in_data = array("B", in_file.read()) + + # Write the size + out_file.write("//const size_t %s_size = %d;\n\n" % (var_name, len(in_data))) + + # Write the data + out_file.write("unsigned char %s[] = {" % var_name) + for i in range(len(in_data)): + c = in_data[i] + if i > 0: + out_file.write(",") + if (i + 1) % 20 == 1: + out_file.write("\n") + out_file.write("%d" % c) + out_file.write("\n};\n") + + # Close the namespaces + if scopes: + out_file.write("\n") + for scope in reversed(scopes): + out_file.write("} // namespace %s\n" % scope) + + +def tza_to_cpp(target, source, env): + for x in zip(source, target): + generate(str(x[0]), str(x[1])) diff --git a/modules/lightmapper_rd/SCsub b/modules/lightmapper_rd/SCsub new file mode 100644 index 0000000000..2f04f1833e --- /dev/null +++ b/modules/lightmapper_rd/SCsub @@ -0,0 +1,12 @@ +#!/usr/bin/env python + +Import("env") +Import("env_modules") + +env_lightmapper_rd = env_modules.Clone() +env_lightmapper_rd.GLSL_HEADER("lm_raster.glsl") +env_lightmapper_rd.GLSL_HEADER("lm_compute.glsl") +env_lightmapper_rd.GLSL_HEADER("lm_blendseams.glsl") + +# Godot source files +env_lightmapper_rd.add_source_files(env.modules_sources, "*.cpp") diff --git a/modules/lightmapper_rd/config.py b/modules/lightmapper_rd/config.py new file mode 100644 index 0000000000..d22f9454ed --- /dev/null +++ b/modules/lightmapper_rd/config.py @@ -0,0 +1,6 @@ +def can_build(env, platform): + return True + + +def configure(env): + pass diff --git a/modules/lightmapper_rd/lightmapper_rd.cpp b/modules/lightmapper_rd/lightmapper_rd.cpp new file mode 100644 index 0000000000..6983c222c0 --- /dev/null +++ b/modules/lightmapper_rd/lightmapper_rd.cpp @@ -0,0 +1,1754 @@ +#include "lightmapper_rd.h" +#include "core/math/geometry.h" +#include "core/project_settings.h" +#include "lm_blendseams.glsl.gen.h" +#include "lm_compute.glsl.gen.h" +#include "lm_raster.glsl.gen.h" +#include "servers/rendering/rendering_device_binds.h" + +//uncomment this if you want to see textures from all the process saved +//#define DEBUG_TEXTURES + +void LightmapperRD::add_mesh(const MeshData &p_mesh) { + ERR_FAIL_COND(p_mesh.albedo_on_uv2.is_null() || p_mesh.albedo_on_uv2->empty()); + ERR_FAIL_COND(p_mesh.emission_on_uv2.is_null() || p_mesh.emission_on_uv2->empty()); + ERR_FAIL_COND(p_mesh.albedo_on_uv2->get_width() != p_mesh.emission_on_uv2->get_width()); + ERR_FAIL_COND(p_mesh.albedo_on_uv2->get_height() != p_mesh.emission_on_uv2->get_height()); + ERR_FAIL_COND(p_mesh.points.size() == 0); + MeshInstance mi; + mi.data = p_mesh; + mesh_instances.push_back(mi); +} + +void LightmapperRD::add_directional_light(bool p_static, const Vector3 &p_direction, const Color &p_color, float p_energy, float p_angular_distance) { + Light l; + l.type = LIGHT_TYPE_DIRECTIONAL; + l.direction[0] = p_direction.x; + l.direction[1] = p_direction.y; + l.direction[2] = p_direction.z; + l.color[0] = p_color.r; + l.color[1] = p_color.g; + l.color[2] = p_color.b; + l.energy = p_energy; + l.static_bake = p_static; + l.size = p_angular_distance; + lights.push_back(l); +} +void LightmapperRD::add_omni_light(bool p_static, const Vector3 &p_position, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_size) { + Light l; + l.type = LIGHT_TYPE_OMNI; + l.position[0] = p_position.x; + l.position[1] = p_position.y; + l.position[2] = p_position.z; + l.range = p_range; + l.attenuation = p_attenuation; + l.color[0] = p_color.r; + l.color[1] = p_color.g; + l.color[2] = p_color.b; + l.energy = p_energy; + l.static_bake = p_static; + l.size = p_size; + lights.push_back(l); +} +void LightmapperRD::add_spot_light(bool p_static, const Vector3 &p_position, const Vector3 p_direction, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_spot_angle, float p_spot_attenuation, float p_size) { + + Light l; + l.type = LIGHT_TYPE_SPOT; + l.position[0] = p_position.x; + l.position[1] = p_position.y; + l.position[2] = p_position.z; + l.direction[0] = p_direction.x; + l.direction[1] = p_direction.y; + l.direction[2] = p_direction.z; + l.range = p_range; + l.attenuation = p_attenuation; + l.spot_angle = Math::deg2rad(p_spot_angle); + l.spot_attenuation = p_spot_attenuation; + l.color[0] = p_color.r; + l.color[1] = p_color.g; + l.color[2] = p_color.b; + l.energy = p_energy; + l.static_bake = p_static; + l.size = p_size; + lights.push_back(l); +} + +void LightmapperRD::add_probe(const Vector3 &p_position) { + Probe probe; + probe.position[0] = p_position.x; + probe.position[1] = p_position.y; + probe.position[2] = p_position.z; + probe.position[3] = 0; + probe_positions.push_back(probe); +} + +void LightmapperRD::_plot_triangle_into_triangle_index_list(int p_size, const Vector3i &p_ofs, const AABB &p_bounds, const Vector3 p_points[3], uint32_t p_triangle_index, LocalVector &triangles, uint32_t p_grid_size) { + + int half_size = p_size / 2; + + for (int i = 0; i < 8; i++) { + + AABB aabb = p_bounds; + aabb.size *= 0.5; + Vector3i n = p_ofs; + + if (i & 1) { + aabb.position.x += aabb.size.x; + n.x += half_size; + } + if (i & 2) { + aabb.position.y += aabb.size.y; + n.y += half_size; + } + if (i & 4) { + aabb.position.z += aabb.size.z; + n.z += half_size; + } + + { + Vector3 qsize = aabb.size * 0.5; //quarter size, for fast aabb test + + if (!Geometry::triangle_box_overlap(aabb.position + qsize, qsize, p_points)) { + //does not fit in child, go on + continue; + } + } + + if (half_size == 1) { + //got to the end + TriangleSort ts; + ts.cell_index = n.x + (n.y * p_grid_size) + (n.z * p_grid_size * p_grid_size); + ts.triangle_index = p_triangle_index; + triangles.push_back(ts); + } else { + _plot_triangle_into_triangle_index_list(half_size, n, aabb, p_points, p_triangle_index, triangles, p_grid_size); + } + } +} + +Lightmapper::BakeError LightmapperRD::_blit_meshes_into_atlas(int p_max_texture_size, Vector> &albedo_images, Vector> &emission_images, AABB &bounds, Size2i &atlas_size, int &atlas_slices, BakeStepFunc p_step_function, void *p_bake_userdata) { + + Vector sizes; + + for (int m_i = 0; m_i < mesh_instances.size(); m_i++) { + + MeshInstance &mi = mesh_instances.write[m_i]; + Size2i s = Size2i(mi.data.albedo_on_uv2->get_width(), mi.data.albedo_on_uv2->get_height()); + sizes.push_back(s); + atlas_size.width = MAX(atlas_size.width, s.width); + atlas_size.height = MAX(atlas_size.height, s.height); + } + + int max = nearest_power_of_2_templated(atlas_size.width); + max = MAX(max, nearest_power_of_2_templated(atlas_size.height)); + + if (max > p_max_texture_size) { + return BAKE_ERROR_LIGHTMAP_TOO_SMALL; + } + + if (p_step_function) { + p_step_function(0.1, TTR("Determining optimal atlas size"), p_bake_userdata, true); + } + + atlas_size = Size2i(max, max); + + Size2i best_atlas_size; + int best_atlas_slices = 0; + int best_atlas_memory = 0x7FFFFFFF; + Vector best_atlas_offsets; + + //determine best texture array atlas size by bruteforce fitting + while (atlas_size.x <= p_max_texture_size && atlas_size.y <= p_max_texture_size) { + + Vector source_sizes = sizes; + Vector source_indices; + source_indices.resize(source_sizes.size()); + for (int i = 0; i < source_indices.size(); i++) { + source_indices.write[i] = i; + } + Vector atlas_offsets; + atlas_offsets.resize(source_sizes.size()); + + int slices = 0; + + while (source_sizes.size() > 0) { + + Vector offsets = Geometry::partial_pack_rects(source_sizes, atlas_size); + Vector new_indices; + Vector new_sources; + for (int i = 0; i < offsets.size(); i++) { + Vector3i ofs = offsets[i]; + int sidx = source_indices[i]; + if (ofs.z > 0) { + //valid + ofs.z = slices; + atlas_offsets.write[sidx] = ofs; + } else { + new_indices.push_back(sidx); + new_sources.push_back(source_sizes[i]); + } + } + + source_sizes = new_sources; + source_indices = new_indices; + slices++; + } + + int mem_used = atlas_size.x * atlas_size.y * slices; + if (mem_used < best_atlas_memory) { + best_atlas_size = atlas_size; + best_atlas_offsets = atlas_offsets; + best_atlas_slices = slices; + best_atlas_memory = mem_used; + } + + if (atlas_size.width == atlas_size.height) { + atlas_size.width *= 2; + } else { + atlas_size.height *= 2; + } + } + atlas_size = best_atlas_size; + atlas_slices = best_atlas_slices; + + // apply the offsets and slice to all images, and also blit albedo and emission + albedo_images.resize(atlas_slices); + emission_images.resize(atlas_slices); + + if (p_step_function) { + p_step_function(0.2, TTR("Blitting albedo and emission"), p_bake_userdata, true); + } + + for (int i = 0; i < atlas_slices; i++) { + Ref albedo; + albedo.instance(); + albedo->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBA8); + albedo->set_as_black(); + albedo_images.write[i] = albedo; + + Ref emission; + emission.instance(); + emission->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH); + emission->set_as_black(); + emission_images.write[i] = emission; + } + + //assign uv positions + + for (int m_i = 0; m_i < mesh_instances.size(); m_i++) { + + MeshInstance &mi = mesh_instances.write[m_i]; + mi.offset.x = best_atlas_offsets[m_i].x; + mi.offset.y = best_atlas_offsets[m_i].y; + mi.slice = best_atlas_offsets[m_i].z; + albedo_images.write[mi.slice]->blit_rect(mi.data.albedo_on_uv2, Rect2(Vector2(), Size2i(mi.data.albedo_on_uv2->get_width(), mi.data.albedo_on_uv2->get_height())), mi.offset); + emission_images.write[mi.slice]->blit_rect(mi.data.emission_on_uv2, Rect2(Vector2(), Size2i(mi.data.emission_on_uv2->get_width(), mi.data.emission_on_uv2->get_height())), mi.offset); + } + + return BAKE_OK; +} + +void LightmapperRD::_create_acceleration_structures(RenderingDevice *rd, Size2i atlas_size, int atlas_slices, AABB &bounds, int grid_size, Vector &probe_positions, GenerateProbes p_generate_probes, Vector &slice_triangle_count, Vector &slice_seam_count, RID &vertex_buffer, RID &triangle_buffer, RID &box_buffer, RID &lights_buffer, RID &triangle_cell_indices_buffer, RID &probe_positions_buffer, RID &grid_texture, RID &grid_texture_sdf, RID &seams_buffer, BakeStepFunc p_step_function, void *p_bake_userdata) { + + HashMap vertex_map; + + //fill triangles array and vertex array + LocalVector triangles; + LocalVector vertex_array; + LocalVector box_array; + LocalVector seams; + + slice_triangle_count.resize(atlas_slices); + slice_seam_count.resize(atlas_slices); + + for (int i = 0; i < atlas_slices; i++) { + slice_triangle_count.write[i] = 0; + slice_seam_count.write[i] = 0; + } + + bounds = AABB(); + + for (int m_i = 0; m_i < mesh_instances.size(); m_i++) { + + if (p_step_function) { + float p = float(m_i + 1) / mesh_instances.size() * 0.1; + p_step_function(0.3 + p, vformat(TTR("Plotting mesh into acceleration structure %d/%d"), m_i + 1, mesh_instances.size()), p_bake_userdata, false); + } + + HashMap edges; + + MeshInstance &mi = mesh_instances.write[m_i]; + + Vector2 uv_scale = Vector2(mi.data.albedo_on_uv2->get_width(), mi.data.albedo_on_uv2->get_height()) / Vector2(atlas_size); + Vector2 uv_offset = Vector2(mi.offset) / Vector2(atlas_size); + if (m_i == 0) { + bounds.position = mi.data.points[0]; + } + + for (int i = 0; i < mi.data.points.size(); i += 3) { + + Vector3 vtxs[3] = { mi.data.points[i + 0], mi.data.points[i + 1], mi.data.points[i + 2] }; + Vector2 uvs[3] = { mi.data.uv2[i + 0] * uv_scale + uv_offset, mi.data.uv2[i + 1] * uv_scale + uv_offset, mi.data.uv2[i + 2] * uv_scale + uv_offset }; + Vector3 normal[3] = { mi.data.normal[i + 0], mi.data.normal[i + 1], mi.data.normal[i + 2] }; + + AABB taabb; + Triangle t; + t.slice = mi.slice; + for (int k = 0; k < 3; k++) { + + bounds.expand_to(vtxs[k]); + + Vertex v; + v.position[0] = vtxs[k].x; + v.position[1] = vtxs[k].y; + v.position[2] = vtxs[k].z; + v.uv[0] = uvs[k].x; + v.uv[1] = uvs[k].y; + v.normal_xy[0] = normal[k].x; + v.normal_xy[1] = normal[k].y; + v.normal_z = normal[k].z; + + uint32_t *indexptr = vertex_map.getptr(v); + + if (indexptr) { + t.indices[k] = *indexptr; + } else { + uint32_t new_index = vertex_map.size(); + t.indices[k] = new_index; + vertex_map[v] = new_index; + vertex_array.push_back(v); + } + + if (k == 0) { + taabb.position = vtxs[k]; + } else { + taabb.expand_to(vtxs[k]); + } + } + + //compute seams that will need to be blended later + for (int k = 0; k < 3; k++) { + int n = (k + 1) % 3; + + Edge edge(vtxs[k], vtxs[n], normal[k], normal[n]); + Vector2i edge_indices(t.indices[k], t.indices[n]); + EdgeUV2 uv2(uvs[k], uvs[n], edge_indices); + + if (edge.b == edge.a) { + continue; //degenerate, somehow + } + if (edge.b < edge.a) { + SWAP(edge.a, edge.b); + SWAP(edge.na, edge.nb); + SWAP(uv2.a, uv2.b); + SWAP(edge_indices.x, edge_indices.y); + } + + EdgeUV2 *euv2 = edges.getptr(edge); + if (!euv2) { + edges[edge] = uv2; + } else { + if (*euv2 == uv2) { + continue; // seam shared UV space, no need to blend + } + if (euv2->seam_found) { + continue; //bad geometry + } + + Seam seam; + seam.a = edge_indices; + seam.b = euv2->indices; + seam.slice = mi.slice; + seams.push_back(seam); + slice_seam_count.write[mi.slice]++; + euv2->seam_found = true; + } + } + + Box box; + box.min_bounds[0] = taabb.position.x; + box.min_bounds[1] = taabb.position.y; + box.min_bounds[2] = taabb.position.z; + box.max_bounds[0] = taabb.position.x + MAX(taabb.size.x, 0.0001); + box.max_bounds[1] = taabb.position.y + MAX(taabb.size.y, 0.0001); + box.max_bounds[2] = taabb.position.z + MAX(taabb.size.z, 0.0001); + box.pad0 = box.pad1 = 0; //make valgrind not complain + box_array.push_back(box); + + triangles.push_back(t); + slice_triangle_count.write[t.slice]++; + } + } + + //also consider probe positions for bounds + for (int i = 0; i < probe_positions.size(); i++) { + Vector3 pp(probe_positions[i].position[0], probe_positions[i].position[1], probe_positions[i].position[2]); + bounds.expand_to(pp); + } + bounds.grow_by(0.1); //grow a bit to avoid numerical error + + triangles.sort(); //sort by slice + seams.sort(); + + if (p_step_function) { + p_step_function(0.4, TTR("Optimizing acceleration structure"), p_bake_userdata, true); + } + + //fill list of triangles in grid + LocalVector triangle_sort; + for (uint32_t i = 0; i < triangles.size(); i++) { + + const Triangle &t = triangles[i]; + Vector3 face[3] = { + Vector3(vertex_array[t.indices[0]].position[0], vertex_array[t.indices[0]].position[1], vertex_array[t.indices[0]].position[2]), + Vector3(vertex_array[t.indices[1]].position[0], vertex_array[t.indices[1]].position[1], vertex_array[t.indices[1]].position[2]), + Vector3(vertex_array[t.indices[2]].position[0], vertex_array[t.indices[2]].position[1], vertex_array[t.indices[2]].position[2]) + }; + _plot_triangle_into_triangle_index_list(grid_size, Vector3i(), bounds, face, i, triangle_sort, grid_size); + } + //sort it + triangle_sort.sort(); + + Vector triangle_indices; + triangle_indices.resize(triangle_sort.size()); + Vector grid_indices; + grid_indices.resize(grid_size * grid_size * grid_size * 2); + zeromem(grid_indices.ptrw(), grid_indices.size() * sizeof(uint32_t)); + Vector solid; + solid.resize(grid_size * grid_size * grid_size); + zeromem(solid.ptrw(), solid.size() * sizeof(bool)); + + { + uint32_t *tiw = triangle_indices.ptrw(); + uint32_t last_cell = 0xFFFFFFFF; + uint32_t *giw = grid_indices.ptrw(); + bool *solidw = solid.ptrw(); + for (uint32_t i = 0; i < triangle_sort.size(); i++) { + uint32_t cell = triangle_sort[i].cell_index; + if (cell != last_cell) { + //cell changed, update pointer to indices + giw[cell * 2 + 1] = i; + last_cell = cell; + solidw[cell] = true; + } + tiw[i] = triangle_sort[i].triangle_index; + giw[cell * 2]++; //update counter + last_cell = cell; + } + } +#if 0 + for (int i = 0; i < grid_size; i++) { + for (int j = 0; j < grid_size; j++) { + for (int k = 0; k < grid_size; k++) { + uint32_t index = i * (grid_size * grid_size) + j * grid_size + k; + grid_indices.write[index * 2] = float(i) / grid_size * 255; + grid_indices.write[index * 2 + 1] = float(j) / grid_size * 255; + } + } + } +#endif + +#if 0 + for (int i = 0; i < grid_size; i++) { + Vector grid_usage; + grid_usage.resize(grid_size * grid_size); + for (int j = 0; j < grid_usage.size(); j++) { + uint32_t ofs = i * grid_size * grid_size + j; + uint32_t count = grid_indices[ofs * 2]; + grid_usage.write[j] = count > 0 ? 255 : 0; + } + + Ref img; + img.instance(); + img->create(grid_size, grid_size, false, Image::FORMAT_L8, grid_usage); + img->save_png("res://grid_layer_" + itos(1000 + i).substr(1, 3) + ".png"); + } +#endif + if (p_step_function) { + p_step_function(0.45, TTR("Generating Signed Distance Field"), p_bake_userdata, true); + } + + //generate SDF for raytracing + Vector euclidean_pos = Geometry::generate_edf(solid, Vector3i(grid_size, grid_size, grid_size), false); + Vector euclidean_neg = Geometry::generate_edf(solid, Vector3i(grid_size, grid_size, grid_size), true); + Vector sdf8 = Geometry::generate_sdf8(euclidean_pos, euclidean_neg); + + /*****************************/ + /*** CREATE GPU STRUCTURES ***/ + /*****************************/ + + lights.sort(); + + Vector seam_buffer_vec; + seam_buffer_vec.resize(seams.size() * 2); + for (uint32_t i = 0; i < seams.size(); i++) { + seam_buffer_vec.write[i * 2 + 0] = seams[i].a; + seam_buffer_vec.write[i * 2 + 1] = seams[i].b; + } + + { //buffers + Vector vb = vertex_array.to_byte_array(); + vertex_buffer = rd->storage_buffer_create(vb.size(), vb); + + Vector tb = triangles.to_byte_array(); + triangle_buffer = rd->storage_buffer_create(tb.size(), tb); + + Vector bb = box_array.to_byte_array(); + box_buffer = rd->storage_buffer_create(bb.size(), bb); + + Vector tib = triangle_indices.to_byte_array(); + triangle_cell_indices_buffer = rd->storage_buffer_create(tib.size(), tib); + + Vector lb = lights.to_byte_array(); + if (lb.size() == 0) { + lb.resize(sizeof(Light)); //even if no lights, the buffer must exist + } + lights_buffer = rd->storage_buffer_create(lb.size(), lb); + + Vector sb = seam_buffer_vec.to_byte_array(); + if (sb.size() == 0) { + sb.resize(sizeof(Vector2i) * 2); //even if no seams, the buffer must exist + } + seams_buffer = rd->storage_buffer_create(sb.size(), sb); + + Vector pb = probe_positions.to_byte_array(); + if (pb.size() == 0) { + pb.resize(sizeof(Probe)); + } + probe_positions_buffer = rd->storage_buffer_create(pb.size(), pb); + } + + { //grid + + RD::TextureFormat tf; + tf.width = grid_size; + tf.height = grid_size; + tf.depth = grid_size; + tf.type = RD::TEXTURE_TYPE_3D; + tf.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT; + + Vector> texdata; + texdata.resize(1); + //grid and indices + tf.format = RD::DATA_FORMAT_R32G32_UINT; + texdata.write[0] = grid_indices.to_byte_array(); + grid_texture = rd->texture_create(tf, RD::TextureView(), texdata); + //sdf + tf.format = RD::DATA_FORMAT_R8_SNORM; + texdata.write[0] = sdf8.to_byte_array(); + grid_texture_sdf = rd->texture_create(tf, RD::TextureView(), texdata); + } +} + +void LightmapperRD::_raster_geometry(RenderingDevice *rd, Size2i atlas_size, int atlas_slices, int grid_size, AABB bounds, float p_bias, Vector slice_triangle_count, RID position_tex, RID unocclude_tex, RID normal_tex, RID raster_depth_buffer, RID rasterize_shader, RID raster_base_uniform) { + + Vector framebuffers; + + for (int i = 0; i < atlas_slices; i++) { + RID slice_pos_tex = rd->texture_create_shared_from_slice(RD::TextureView(), position_tex, i, 0); + RID slice_unoc_tex = rd->texture_create_shared_from_slice(RD::TextureView(), unocclude_tex, i, 0); + RID slice_norm_tex = rd->texture_create_shared_from_slice(RD::TextureView(), normal_tex, i, 0); + Vector fb; + fb.push_back(slice_pos_tex); + fb.push_back(slice_norm_tex); + fb.push_back(slice_unoc_tex); + fb.push_back(raster_depth_buffer); + framebuffers.push_back(rd->framebuffer_create(fb)); + } + + RD::PipelineDepthStencilState ds; + ds.enable_depth_test = true; + ds.enable_depth_write = true; + ds.depth_compare_operator = RD::COMPARE_OP_LESS; //so it does render same pixel twice + + RID raster_pipeline = rd->render_pipeline_create(rasterize_shader, rd->framebuffer_get_format(framebuffers[0]), RD::INVALID_FORMAT_ID, RD::RENDER_PRIMITIVE_TRIANGLES, RD::PipelineRasterizationState(), RD::PipelineMultisampleState(), ds, RD::PipelineColorBlendState::create_disabled(3), 0); + RID raster_pipeline_wire; + { + + RD::PipelineRasterizationState rw; + rw.wireframe = true; + raster_pipeline_wire = rd->render_pipeline_create(rasterize_shader, rd->framebuffer_get_format(framebuffers[0]), RD::INVALID_FORMAT_ID, RD::RENDER_PRIMITIVE_TRIANGLES, rw, RD::PipelineMultisampleState(), ds, RD::PipelineColorBlendState::create_disabled(3), 0); + } + + uint32_t triangle_offset = 0; + Vector clear_colors; + clear_colors.push_back(Color(0, 0, 0, 0)); + clear_colors.push_back(Color(0, 0, 0, 0)); + clear_colors.push_back(Color(0, 0, 0, 0)); + + for (int i = 0; i < atlas_slices; i++) { + + RasterPushConstant raster_push_constant; + raster_push_constant.atlas_size[0] = atlas_size.x; + raster_push_constant.atlas_size[1] = atlas_size.y; + raster_push_constant.base_triangle = triangle_offset; + raster_push_constant.to_cell_offset[0] = bounds.position.x; + raster_push_constant.to_cell_offset[1] = bounds.position.y; + raster_push_constant.to_cell_offset[2] = bounds.position.z; + raster_push_constant.bias = p_bias; + raster_push_constant.to_cell_size[0] = (1.0 / bounds.size.x) * float(grid_size); + raster_push_constant.to_cell_size[1] = (1.0 / bounds.size.y) * float(grid_size); + raster_push_constant.to_cell_size[2] = (1.0 / bounds.size.z) * float(grid_size); + raster_push_constant.grid_size[0] = grid_size; + raster_push_constant.grid_size[1] = grid_size; + raster_push_constant.grid_size[2] = grid_size; + raster_push_constant.uv_offset[0] = 0; + raster_push_constant.uv_offset[1] = 0; + + RD::DrawListID draw_list = rd->draw_list_begin(framebuffers[i], RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_READ, RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_DISCARD, clear_colors); + //draw opaque + rd->draw_list_bind_render_pipeline(draw_list, raster_pipeline); + rd->draw_list_bind_uniform_set(draw_list, raster_base_uniform, 0); + rd->draw_list_set_push_constant(draw_list, &raster_push_constant, sizeof(RasterPushConstant)); + rd->draw_list_draw(draw_list, false, 1, slice_triangle_count[i] * 3); + //draw wire + rd->draw_list_bind_render_pipeline(draw_list, raster_pipeline_wire); + rd->draw_list_bind_uniform_set(draw_list, raster_base_uniform, 0); + rd->draw_list_set_push_constant(draw_list, &raster_push_constant, sizeof(RasterPushConstant)); + rd->draw_list_draw(draw_list, false, 1, slice_triangle_count[i] * 3); + + rd->draw_list_end(); + + triangle_offset += slice_triangle_count[i]; + } +} + +LightmapperRD::BakeError LightmapperRD::bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function, void *p_bake_userdata) { + + if (p_step_function) { + p_step_function(0.0, TTR("Begin Bake"), p_bake_userdata, true); + } + bake_textures.clear(); + int grid_size = 128; + + /* STEP 1: Fetch material textures and compute the bounds */ + + AABB bounds; + Size2i atlas_size; + int atlas_slices; + Vector> albedo_images; + Vector> emission_images; + + BakeError bake_error = _blit_meshes_into_atlas(p_max_texture_size, albedo_images, emission_images, bounds, atlas_size, atlas_slices, p_step_function, p_bake_userdata); + if (bake_error != BAKE_OK) { + return bake_error; + } + +#ifdef DEBUG_TEXTURES + for (int i = 0; i < atlas_slices; i++) { + albedo_images[i]->save_png("res://0_albedo_" + itos(i) + ".png"); + emission_images[i]->save_png("res://0_emission_" + itos(i) + ".png"); + } +#endif + + RenderingDevice *rd = RenderingDevice::get_singleton()->create_local_device(); + + RID albedo_array_tex; + RID emission_array_tex; + RID normal_tex; + RID position_tex; + RID unocclude_tex; + RID light_source_tex; + RID light_dest_tex; + RID light_accum_tex; + RID light_accum_tex2; + RID light_primary_dynamic_tex; + RID light_environment_tex; + +#define FREE_TEXTURES \ + rd->free(albedo_array_tex); \ + rd->free(emission_array_tex); \ + rd->free(normal_tex); \ + rd->free(position_tex); \ + rd->free(unocclude_tex); \ + rd->free(light_source_tex); \ + rd->free(light_accum_tex2); \ + rd->free(light_accum_tex); \ + rd->free(light_primary_dynamic_tex); \ + rd->free(light_environment_tex); + + { // create all textures + + Vector> albedo_data; + Vector> emission_data; + for (int i = 0; i < atlas_slices; i++) { + albedo_data.push_back(albedo_images[i]->get_data()); + emission_data.push_back(emission_images[i]->get_data()); + } + + RD::TextureFormat tf; + tf.width = atlas_size.width; + tf.height = atlas_size.height; + tf.array_layers = atlas_slices; + tf.type = RD::TEXTURE_TYPE_2D_ARRAY; + tf.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT; + tf.format = RD::DATA_FORMAT_R8G8B8A8_UNORM; + + albedo_array_tex = rd->texture_create(tf, RD::TextureView(), albedo_data); + + tf.format = RD::DATA_FORMAT_R16G16B16A16_SFLOAT; + + emission_array_tex = rd->texture_create(tf, RD::TextureView(), emission_data); + + //this will be rastered to + tf.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_COLOR_ATTACHMENT_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT | RD::TEXTURE_USAGE_STORAGE_BIT; + normal_tex = rd->texture_create(tf, RD::TextureView()); + tf.format = RD::DATA_FORMAT_R32G32B32A32_SFLOAT; + position_tex = rd->texture_create(tf, RD::TextureView()); + unocclude_tex = rd->texture_create(tf, RD::TextureView()); + + tf.format = RD::DATA_FORMAT_R16G16B16A16_SFLOAT; + tf.usage_bits = RD::TEXTURE_USAGE_COLOR_ATTACHMENT_BIT | RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_STORAGE_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT | RD::TEXTURE_USAGE_CAN_COPY_TO_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT; + + light_source_tex = rd->texture_create(tf, RD::TextureView()); + rd->texture_clear(light_source_tex, Color(0, 0, 0, 0), 0, 1, 0, atlas_slices); + light_primary_dynamic_tex = rd->texture_create(tf, RD::TextureView()); + rd->texture_clear(light_primary_dynamic_tex, Color(0, 0, 0, 0), 0, 1, 0, atlas_slices); + + if (p_bake_sh) { + tf.array_layers *= 4; + } + light_accum_tex = rd->texture_create(tf, RD::TextureView()); + rd->texture_clear(light_accum_tex, Color(0, 0, 0, 0), 0, 1, 0, tf.array_layers); + light_dest_tex = rd->texture_create(tf, RD::TextureView()); + rd->texture_clear(light_dest_tex, Color(0, 0, 0, 0), 0, 1, 0, tf.array_layers); + light_accum_tex2 = light_dest_tex; + + //env + { + Ref panorama_tex; + if (p_environment_panorama.is_valid()) { + panorama_tex = p_environment_panorama; + panorama_tex->convert(Image::FORMAT_RGBAF); + } else { + panorama_tex.instance(); + panorama_tex->create(8, 8, false, Image::FORMAT_RGBAF); + for (int i = 0; i < 8; i++) { + for (int j = 0; j < 8; j++) { + panorama_tex->set_pixel(i, j, Color(0, 0, 0, 1)); + } + } + } + + RD::TextureFormat tfp; + tfp.width = panorama_tex->get_width(); + tfp.height = panorama_tex->get_height(); + tfp.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT; + tfp.format = RD::DATA_FORMAT_R32G32B32A32_SFLOAT; + + Vector> tdata; + tdata.push_back(panorama_tex->get_data()); + light_environment_tex = rd->texture_create(tfp, RD::TextureView(), tdata); + +#ifdef DEBUG_TEXTURES + panorama_tex->convert(Image::FORMAT_RGB8); + panorama_tex->save_png("res://0_panorama.png"); +#endif + } + } + + /* STEP 2: create the acceleration structure for the GPU*/ + + Vector slice_triangle_count; + RID vertex_buffer; + RID triangle_buffer; + RID box_buffer; + RID lights_buffer; + RID triangle_cell_indices_buffer; + RID grid_texture; + RID grid_texture_sdf; + RID seams_buffer; + RID probe_positions_buffer; + + Vector slice_seam_count; + +#define FREE_BUFFERS \ + rd->free(vertex_buffer); \ + rd->free(triangle_buffer); \ + rd->free(box_buffer); \ + rd->free(lights_buffer); \ + rd->free(triangle_cell_indices_buffer); \ + rd->free(grid_texture); \ + rd->free(grid_texture_sdf); \ + rd->free(seams_buffer); \ + rd->free(probe_positions_buffer); + + _create_acceleration_structures(rd, atlas_size, atlas_slices, bounds, grid_size, probe_positions, p_generate_probes, slice_triangle_count, slice_seam_count, vertex_buffer, triangle_buffer, box_buffer, lights_buffer, triangle_cell_indices_buffer, probe_positions_buffer, grid_texture, grid_texture_sdf, seams_buffer, p_step_function, p_bake_userdata); + + if (p_step_function) { + p_step_function(0.47, TTR("Preparing shaders"), p_bake_userdata, true); + } + + //shaders + Ref raster_shader; + raster_shader.instance(); + Error err = raster_shader->parse_versions_from_text(lm_raster_shader_glsl); + if (err != OK) { + raster_shader->print_errors("raster_shader"); + + FREE_TEXTURES + FREE_BUFFERS + + memdelete(rd); + } + ERR_FAIL_COND_V(err != OK, BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); + + RID rasterize_shader = rd->shader_create_from_bytecode(raster_shader->get_bytecode()); + + ERR_FAIL_COND_V(rasterize_shader.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); //this is a bug check, though, should not happen + + RID sampler; + { + RD::SamplerState s; + s.mag_filter = RD::SAMPLER_FILTER_LINEAR; + s.min_filter = RD::SAMPLER_FILTER_LINEAR; + s.max_lod = 0; + + sampler = rd->sampler_create(s); + } + + Vector base_uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 1; + u.ids.push_back(vertex_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 2; + u.ids.push_back(triangle_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 3; + u.ids.push_back(box_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 4; + u.ids.push_back(triangle_cell_indices_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 5; + u.ids.push_back(lights_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 6; + u.ids.push_back(seams_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 7; + u.ids.push_back(probe_positions_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 8; + u.ids.push_back(grid_texture); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 9; + u.ids.push_back(grid_texture_sdf); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 10; + u.ids.push_back(albedo_array_tex); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 11; + u.ids.push_back(emission_array_tex); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_SAMPLER; + u.binding = 12; + u.ids.push_back(sampler); + base_uniforms.push_back(u); + } + } + + RID raster_base_uniform = rd->uniform_set_create(base_uniforms, rasterize_shader, 0); + RID raster_depth_buffer; + { + RD::TextureFormat tf; + tf.width = atlas_size.width; + tf.height = atlas_size.height; + tf.depth = 1; + tf.type = RD::TEXTURE_TYPE_2D; + tf.usage_bits = RD::TEXTURE_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT; + tf.format = RD::DATA_FORMAT_D32_SFLOAT; + + raster_depth_buffer = rd->texture_create(tf, RD::TextureView()); + } + + rd->submit(); + rd->sync(); + + /* STEP 3: Raster the geometry to UV2 coords in the atlas textures GPU*/ + + _raster_geometry(rd, atlas_size, atlas_slices, grid_size, bounds, p_bias, slice_triangle_count, position_tex, unocclude_tex, normal_tex, raster_depth_buffer, rasterize_shader, raster_base_uniform); + +#ifdef DEBUG_TEXTURES + + for (int i = 0; i < atlas_slices; i++) { + Vector s = rd->texture_get_data(position_tex, i); + Ref img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAF, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://1_position_" + itos(i) + ".png"); + + s = rd->texture_get_data(normal_tex, i); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://1_normal_" + itos(i) + ".png"); + } +#endif + +#define FREE_RASTER_RESOURCES \ + rd->free(rasterize_shader); \ + rd->free(sampler); \ + rd->free(raster_depth_buffer); + + /* Plot direct light */ + + Ref compute_shader; + compute_shader.instance(); + err = compute_shader->parse_versions_from_text(lm_compute_shader_glsl, p_bake_sh ? "\n#define USE_SH_LIGHTMAPS\n" : ""); + if (err != OK) { + + FREE_TEXTURES + FREE_BUFFERS + FREE_RASTER_RESOURCES + memdelete(rd); + compute_shader->print_errors("compute_shader"); + } + ERR_FAIL_COND_V(err != OK, BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); + + //unoccluder + RID compute_shader_unocclude = rd->shader_create_from_bytecode(compute_shader->get_bytecode("unocclude")); + ERR_FAIL_COND_V(compute_shader_unocclude.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); // internal check, should not happen + RID compute_shader_unocclude_pipeline = rd->compute_pipeline_create(compute_shader_unocclude); + + //direct light + RID compute_shader_primary = rd->shader_create_from_bytecode(compute_shader->get_bytecode("primary")); + ERR_FAIL_COND_V(compute_shader_primary.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); // internal check, should not happen + RID compute_shader_primary_pipeline = rd->compute_pipeline_create(compute_shader_primary); + + //indirect light + RID compute_shader_secondary = rd->shader_create_from_bytecode(compute_shader->get_bytecode("secondary")); + ERR_FAIL_COND_V(compute_shader_secondary.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); //internal check, should not happen + RID compute_shader_secondary_pipeline = rd->compute_pipeline_create(compute_shader_secondary); + + //dilate + RID compute_shader_dilate = rd->shader_create_from_bytecode(compute_shader->get_bytecode("dilate")); + ERR_FAIL_COND_V(compute_shader_dilate.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); //internal check, should not happen + RID compute_shader_dilate_pipeline = rd->compute_pipeline_create(compute_shader_dilate); + + //dilate + RID compute_shader_light_probes = rd->shader_create_from_bytecode(compute_shader->get_bytecode("light_probes")); + ERR_FAIL_COND_V(compute_shader_light_probes.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); //internal check, should not happen + RID compute_shader_light_probes_pipeline = rd->compute_pipeline_create(compute_shader_light_probes); + + RID compute_base_uniform_set = rd->uniform_set_create(base_uniforms, compute_shader_primary, 0); + +#define FREE_COMPUTE_RESOURCES \ + rd->free(compute_shader_unocclude); \ + rd->free(compute_shader_primary); \ + rd->free(compute_shader_secondary); \ + rd->free(compute_shader_dilate); \ + rd->free(compute_shader_light_probes); + + PushConstant push_constant; + { + //set defaults + push_constant.atlas_size[0] = atlas_size.width; + push_constant.atlas_size[1] = atlas_size.height; + push_constant.world_size[0] = bounds.size.x; + push_constant.world_size[1] = bounds.size.y; + push_constant.world_size[2] = bounds.size.z; + push_constant.to_cell_offset[0] = bounds.position.x; + push_constant.to_cell_offset[1] = bounds.position.y; + push_constant.to_cell_offset[2] = bounds.position.z; + push_constant.bias = p_bias; + push_constant.to_cell_size[0] = (1.0 / bounds.size.x) * float(grid_size); + push_constant.to_cell_size[1] = (1.0 / bounds.size.y) * float(grid_size); + push_constant.to_cell_size[2] = (1.0 / bounds.size.z) * float(grid_size); + push_constant.light_count = lights.size(); + push_constant.grid_size = grid_size; + push_constant.atlas_slice = 0; + push_constant.region_ofs[0] = 0; + push_constant.region_ofs[1] = 0; + push_constant.environment_xform[0] = p_environment_transform.elements[0][0]; + push_constant.environment_xform[1] = p_environment_transform.elements[1][0]; + push_constant.environment_xform[2] = p_environment_transform.elements[2][0]; + push_constant.environment_xform[3] = 0; + push_constant.environment_xform[4] = p_environment_transform.elements[0][1]; + push_constant.environment_xform[5] = p_environment_transform.elements[1][1]; + push_constant.environment_xform[6] = p_environment_transform.elements[2][1]; + push_constant.environment_xform[7] = 0; + push_constant.environment_xform[8] = p_environment_transform.elements[0][2]; + push_constant.environment_xform[9] = p_environment_transform.elements[1][2]; + push_constant.environment_xform[10] = p_environment_transform.elements[2][2]; + push_constant.environment_xform[11] = 0; + } + + Vector3i group_size((atlas_size.x - 1) / 8 + 1, (atlas_size.y - 1) / 8 + 1, 1); + rd->submit(); + rd->sync(); + + if (p_step_function) { + p_step_function(0.49, TTR("Un-occluding geometry"), p_bake_userdata, true); + } + + /* UNOCCLUDE */ + { + + Vector uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 0; + u.ids.push_back(position_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 1; + u.ids.push_back(unocclude_tex); //will be unused + uniforms.push_back(u); + } + } + + RID unocclude_uniform_set = rd->uniform_set_create(uniforms, compute_shader_unocclude, 1); + + RD::ComputeListID compute_list = rd->compute_list_begin(); + rd->compute_list_bind_compute_pipeline(compute_list, compute_shader_unocclude_pipeline); + rd->compute_list_bind_uniform_set(compute_list, compute_base_uniform_set, 0); + rd->compute_list_bind_uniform_set(compute_list, unocclude_uniform_set, 1); + + for (int i = 0; i < atlas_slices; i++) { + push_constant.atlas_slice = i; + rd->compute_list_set_push_constant(compute_list, &push_constant, sizeof(PushConstant)); + rd->compute_list_dispatch(compute_list, group_size.x, group_size.y, group_size.z); + //no barrier, let them run all together + } + rd->compute_list_end(); //done + } + + if (p_step_function) { + p_step_function(0.5, TTR("Plot direct lighting"), p_bake_userdata, true); + } + + /* PRIMARY (direct) LIGHT PASS */ + { + + Vector uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 0; + u.ids.push_back(light_source_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 1; + u.ids.push_back(light_dest_tex); //will be unused + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 2; + u.ids.push_back(position_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 3; + u.ids.push_back(normal_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 4; + u.ids.push_back(light_accum_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 5; + u.ids.push_back(light_primary_dynamic_tex); + uniforms.push_back(u); + } + } + + RID light_uniform_set = rd->uniform_set_create(uniforms, compute_shader_primary, 1); + + RD::ComputeListID compute_list = rd->compute_list_begin(); + rd->compute_list_bind_compute_pipeline(compute_list, compute_shader_primary_pipeline); + rd->compute_list_bind_uniform_set(compute_list, compute_base_uniform_set, 0); + rd->compute_list_bind_uniform_set(compute_list, light_uniform_set, 1); + + for (int i = 0; i < atlas_slices; i++) { + push_constant.atlas_slice = i; + rd->compute_list_set_push_constant(compute_list, &push_constant, sizeof(PushConstant)); + rd->compute_list_dispatch(compute_list, group_size.x, group_size.y, group_size.z); + //no barrier, let them run all together + } + rd->compute_list_end(); //done + } + +#ifdef DEBUG_TEXTURES + + for (int i = 0; i < atlas_slices; i++) { + Vector s = rd->texture_get_data(light_source_tex, i); + Ref img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://2_light_primary_" + itos(i) + ".png"); + } +#endif + + /* SECONDARY (indirect) LIGHT PASS(ES) */ + if (p_step_function) { + p_step_function(0.6, TTR("Integrate indirect lighting"), p_bake_userdata, true); + } + + if (p_bounces > 0) { + + Vector uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 0; + u.ids.push_back(light_dest_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 1; + u.ids.push_back(light_source_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 2; + u.ids.push_back(position_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 3; + u.ids.push_back(normal_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 4; + u.ids.push_back(light_accum_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 5; + u.ids.push_back(unocclude_tex); //reuse unocclude tex + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 6; + u.ids.push_back(light_environment_tex); //reuse unocclude tex + uniforms.push_back(u); + } + } + + RID secondary_uniform_set[2]; + secondary_uniform_set[0] = rd->uniform_set_create(uniforms, compute_shader_secondary, 1); + uniforms.write[0].ids.write[0] = light_source_tex; + uniforms.write[1].ids.write[0] = light_dest_tex; + secondary_uniform_set[1] = rd->uniform_set_create(uniforms, compute_shader_secondary, 1); + + switch (p_quality) { + case BAKE_QUALITY_LOW: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/low_quality_ray_count"); + } break; + case BAKE_QUALITY_MEDIUM: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/medium_quality_ray_count"); + } break; + case BAKE_QUALITY_HIGH: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/high_quality_ray_count"); + } break; + case BAKE_QUALITY_ULTRA: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/ultra_quality_ray_count"); + } break; + } + + push_constant.ray_count = CLAMP(push_constant.ray_count, 16, 8192); + + int max_region_size = nearest_power_of_2_templated(int(GLOBAL_GET("rendering/gpu_lightmapper/performance/region_size"))); + int max_rays = GLOBAL_GET("rendering/gpu_lightmapper/performance/max_rays_per_pass"); + + int x_regions = (atlas_size.width - 1) / max_region_size + 1; + int y_regions = (atlas_size.height - 1) / max_region_size + 1; + int ray_iterations = (push_constant.ray_count - 1) / max_rays + 1; + + rd->submit(); + rd->sync(); + + for (int b = 0; b < p_bounces; b++) { + int count = 0; + if (b > 0) { + SWAP(light_source_tex, light_dest_tex); + SWAP(secondary_uniform_set[0], secondary_uniform_set[1]); + } + + for (int s = 0; s < atlas_slices; s++) { + push_constant.atlas_slice = s; + + for (int i = 0; i < x_regions; i++) { + for (int j = 0; j < y_regions; j++) { + + int x = i * max_region_size; + int y = j * max_region_size; + int w = MIN((i + 1) * max_region_size, atlas_size.width) - x; + int h = MIN((j + 1) * max_region_size, atlas_size.height) - y; + + push_constant.region_ofs[0] = x; + push_constant.region_ofs[1] = y; + + group_size = Vector3i((w - 1) / 8 + 1, (h - 1) / 8 + 1, 1); + + for (int k = 0; k < ray_iterations; k++) { + + RD::ComputeListID compute_list = rd->compute_list_begin(); + rd->compute_list_bind_compute_pipeline(compute_list, compute_shader_secondary_pipeline); + rd->compute_list_bind_uniform_set(compute_list, compute_base_uniform_set, 0); + rd->compute_list_bind_uniform_set(compute_list, secondary_uniform_set[0], 1); + + push_constant.ray_from = k * max_rays; + push_constant.ray_to = MIN((k + 1) * max_rays, int32_t(push_constant.ray_count)); + rd->compute_list_set_push_constant(compute_list, &push_constant, sizeof(PushConstant)); + rd->compute_list_dispatch(compute_list, group_size.x, group_size.y, group_size.z); + + rd->compute_list_end(); //done + rd->submit(); + rd->sync(); + + count++; + if (p_step_function) { + int total = (atlas_slices * x_regions * y_regions * ray_iterations); + int percent = count * 100 / total; + float p = float(count) / total * 0.1; + p_step_function(0.6 + p, vformat(TTR("Bounce %d/%d: Integrate indirect lighting %d%%"), b + 1, p_bounces, percent), p_bake_userdata, false); + } + } + } + } + } + } + } + + /* LIGHPROBES */ + + RID light_probe_buffer; + + if (probe_positions.size()) { + + light_probe_buffer = rd->storage_buffer_create(sizeof(float) * 4 * 9 * probe_positions.size()); + + if (p_step_function) { + p_step_function(0.7, TTR("Baking lightprobes"), p_bake_userdata, true); + } + + Vector uniforms; + { + + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 0; + u.ids.push_back(light_probe_buffer); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 1; + u.ids.push_back(light_dest_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 2; + u.ids.push_back(light_primary_dynamic_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 3; + u.ids.push_back(light_environment_tex); + uniforms.push_back(u); + } + } + RID light_probe_uniform_set = rd->uniform_set_create(uniforms, compute_shader_light_probes, 1); + + switch (p_quality) { + case BAKE_QUALITY_LOW: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/low_quality_probe_ray_count"); + } break; + case BAKE_QUALITY_MEDIUM: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/medium_quality_probe_ray_count"); + } break; + case BAKE_QUALITY_HIGH: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/high_quality_probe_ray_count"); + } break; + case BAKE_QUALITY_ULTRA: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/ultra_quality_probe_ray_count"); + } break; + } + + push_constant.atlas_size[0] = probe_positions.size(); + push_constant.ray_count = CLAMP(push_constant.ray_count, 16, 8192); + + int max_rays = GLOBAL_GET("rendering/gpu_lightmapper/performance/max_rays_per_probe_pass"); + int ray_iterations = (push_constant.ray_count - 1) / max_rays + 1; + + for (int i = 0; i < ray_iterations; i++) { + + RD::ComputeListID compute_list = rd->compute_list_begin(); + rd->compute_list_bind_compute_pipeline(compute_list, compute_shader_light_probes_pipeline); + rd->compute_list_bind_uniform_set(compute_list, compute_base_uniform_set, 0); + rd->compute_list_bind_uniform_set(compute_list, light_probe_uniform_set, 1); + + push_constant.ray_from = i * max_rays; + push_constant.ray_to = MIN((i + 1) * max_rays, int32_t(push_constant.ray_count)); + rd->compute_list_set_push_constant(compute_list, &push_constant, sizeof(PushConstant)); + rd->compute_list_dispatch(compute_list, (probe_positions.size() - 1) / 64 + 1, 1, 1); + + rd->compute_list_end(); //done + rd->submit(); + rd->sync(); + + if (p_step_function) { + int percent = i * 100 / ray_iterations; + float p = float(i) / ray_iterations * 0.1; + p_step_function(0.7 + p, vformat(TTR("Integrating light probes %d%%"), percent), p_bake_userdata, false); + } + } + + push_constant.atlas_size[0] = atlas_size.x; //restore + } + +#if 0 + for (int i = 0; i < probe_positions.size(); i++) { + Ref img; + img.instance(); + img->create(6, 4, false, Image::FORMAT_RGB8); + for (int j = 0; j < 6; j++) { + Vector s = rd->texture_get_data(lightprobe_tex, i * 6 + j); + Ref img2; + img2.instance(); + img2->create(2, 2, false, Image::FORMAT_RGBAF, s); + img2->convert(Image::FORMAT_RGB8); + img->blit_rect(img2, Rect2(0, 0, 2, 2), Point2((j % 3) * 2, (j / 3) * 2)); + } + img->save_png("res://3_light_probe_" + itos(i) + ".png"); + } +#endif + + /* DENOISE */ + + if (p_use_denoiser) { + if (p_step_function) { + p_step_function(0.8, TTR("Denoising"), p_bake_userdata, true); + } + + Ref denoiser = LightmapDenoiser::create(); + if (denoiser.is_valid()) { + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + Vector s = rd->texture_get_data(light_accum_tex, i); + Ref img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + + Ref denoised = denoiser->denoise_image(img); + if (denoised != img) { + denoised->convert(Image::FORMAT_RGBAH); + Vector ds = denoised->get_data(); + denoised.unref(); //avoid copy on write + { //restore alpha + uint32_t count = s.size() / 2; //uint16s + const uint16_t *src = (const uint16_t *)s.ptr(); + uint16_t *dst = (uint16_t *)ds.ptrw(); + for (uint32_t j = 0; j < count; j += 4) { + dst[j + 3] = src[j + 3]; + } + } + rd->texture_update(light_accum_tex, i, ds, true); + } + } + } + } + +#ifdef DEBUG_TEXTURES + + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + Vector s = rd->texture_get_data(light_accum_tex, i); + Ref img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://4_light_secondary_" + itos(i) + ".png"); + } +#endif + + /* DILATE LIGHTMAP */ + { + + SWAP(light_accum_tex, light_accum_tex2); + + Vector uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 0; + u.ids.push_back(light_accum_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 1; + u.ids.push_back(light_accum_tex2); + uniforms.push_back(u); + } + } + + RID dilate_uniform_set = rd->uniform_set_create(uniforms, compute_shader_dilate, 1); + + RD::ComputeListID compute_list = rd->compute_list_begin(); + rd->compute_list_bind_compute_pipeline(compute_list, compute_shader_dilate_pipeline); + rd->compute_list_bind_uniform_set(compute_list, compute_base_uniform_set, 0); + rd->compute_list_bind_uniform_set(compute_list, dilate_uniform_set, 1); + push_constant.region_ofs[0] = 0; + push_constant.region_ofs[1] = 0; + group_size = Vector3i((atlas_size.x - 1) / 8 + 1, (atlas_size.y - 1) / 8 + 1, 1); //restore group size + + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + push_constant.atlas_slice = i; + rd->compute_list_set_push_constant(compute_list, &push_constant, sizeof(PushConstant)); + rd->compute_list_dispatch(compute_list, group_size.x, group_size.y, group_size.z); + //no barrier, let them run all together + } + rd->compute_list_end(); + } + +#ifdef DEBUG_TEXTURES + + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + Vector s = rd->texture_get_data(light_accum_tex, i); + Ref img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://5_dilated_" + itos(i) + ".png"); + } +#endif + + /* BLEND SEAMS */ + //shaders + Ref blendseams_shader; + blendseams_shader.instance(); + err = blendseams_shader->parse_versions_from_text(lm_blendseams_shader_glsl); + if (err != OK) { + FREE_TEXTURES + FREE_BUFFERS + FREE_RASTER_RESOURCES + FREE_COMPUTE_RESOURCES + memdelete(rd); + blendseams_shader->print_errors("blendseams_shader"); + } + ERR_FAIL_COND_V(err != OK, BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); + + RID blendseams_line_raster_shader = rd->shader_create_from_bytecode(blendseams_shader->get_bytecode("lines")); + + ERR_FAIL_COND_V(blendseams_line_raster_shader.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); + + RID blendseams_triangle_raster_shader = rd->shader_create_from_bytecode(blendseams_shader->get_bytecode("triangles")); + + ERR_FAIL_COND_V(blendseams_triangle_raster_shader.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); + +#define FREE_BLENDSEAMS_RESOURCES \ + rd->free(blendseams_line_raster_shader); \ + rd->free(blendseams_triangle_raster_shader); + + { + + //pre copy + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + rd->texture_copy(light_accum_tex, light_accum_tex2, Vector3(), Vector3(), Vector3(atlas_size.width, atlas_size.height, 1), 0, 0, i, i, true); + } + + Vector framebuffers; + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + RID slice_tex = rd->texture_create_shared_from_slice(RD::TextureView(), light_accum_tex, i, 0); + Vector fb; + fb.push_back(slice_tex); + fb.push_back(raster_depth_buffer); + framebuffers.push_back(rd->framebuffer_create(fb)); + } + + Vector uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 0; + u.ids.push_back(light_accum_tex2); + uniforms.push_back(u); + } + } + + RID blendseams_raster_uniform = rd->uniform_set_create(uniforms, blendseams_line_raster_shader, 1); + + bool debug = false; + RD::PipelineColorBlendState bs = RD::PipelineColorBlendState::create_blend(1); + bs.attachments.write[0].src_alpha_blend_factor = RD::BLEND_FACTOR_ZERO; + bs.attachments.write[0].dst_alpha_blend_factor = RD::BLEND_FACTOR_ONE; + + RD::PipelineDepthStencilState ds; + ds.enable_depth_test = true; + ds.enable_depth_write = true; + ds.depth_compare_operator = RD::COMPARE_OP_LESS; //so it does not render same pixel twice, this avoids wrong blending + + RID blendseams_line_raster_pipeline = rd->render_pipeline_create(blendseams_line_raster_shader, rd->framebuffer_get_format(framebuffers[0]), RD::INVALID_FORMAT_ID, RD::RENDER_PRIMITIVE_LINES, RD::PipelineRasterizationState(), RD::PipelineMultisampleState(), ds, bs, 0); + RID blendseams_triangle_raster_pipeline = rd->render_pipeline_create(blendseams_triangle_raster_shader, rd->framebuffer_get_format(framebuffers[0]), RD::INVALID_FORMAT_ID, RD::RENDER_PRIMITIVE_TRIANGLES, RD::PipelineRasterizationState(), RD::PipelineMultisampleState(), ds, bs, 0); + + uint32_t seam_offset = 0; + uint32_t triangle_offset = 0; + + Vector clear_colors; + clear_colors.push_back(Color(0, 0, 0, 1)); + for (int i = 0; i < atlas_slices; i++) { + + int subslices = (p_bake_sh ? 4 : 1); + for (int k = 0; k < subslices; k++) { + + RasterSeamsPushConstant seams_push_constant; + seams_push_constant.slice = uint32_t(i * subslices + k); + seams_push_constant.debug = debug; + + RD::DrawListID draw_list = rd->draw_list_begin(framebuffers[i], RD::INITIAL_ACTION_KEEP, RD::FINAL_ACTION_READ, RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_DISCARD, clear_colors); + + rd->draw_list_bind_uniform_set(draw_list, raster_base_uniform, 0); + rd->draw_list_bind_uniform_set(draw_list, blendseams_raster_uniform, 1); + + const int uv_offset_count = 9; + static const Vector3 uv_offsets[uv_offset_count] = { + Vector3(0, 0, 0.5), //using zbuffer, so go inwards-outwards + Vector3(0, 1, 0.2), + Vector3(0, -1, 0.2), + Vector3(1, 0, 0.2), + Vector3(-1, 0, 0.2), + Vector3(-1, -1, 0.1), + Vector3(1, -1, 0.1), + Vector3(1, 1, 0.1), + Vector3(-1, 1, 0.1), + }; + + /* step 1 use lines to blend the edges */ + { + seams_push_constant.base_index = seam_offset; + rd->draw_list_bind_render_pipeline(draw_list, blendseams_line_raster_pipeline); + seams_push_constant.uv_offset[0] = uv_offsets[0].x / float(atlas_size.width); + seams_push_constant.uv_offset[1] = uv_offsets[0].y / float(atlas_size.height); + seams_push_constant.blend = uv_offsets[0].z; + + rd->draw_list_set_push_constant(draw_list, &seams_push_constant, sizeof(RasterSeamsPushConstant)); + rd->draw_list_draw(draw_list, false, 1, slice_seam_count[i] * 4); + } + + /* step 2 use triangles to mask the interior */ + + { + seams_push_constant.base_index = triangle_offset; + rd->draw_list_bind_render_pipeline(draw_list, blendseams_triangle_raster_pipeline); + seams_push_constant.blend = 0; //do not draw them, just fill the z-buffer so its used as a mask + + rd->draw_list_set_push_constant(draw_list, &seams_push_constant, sizeof(RasterSeamsPushConstant)); + rd->draw_list_draw(draw_list, false, 1, slice_triangle_count[i] * 3); + } + /* step 3 blend around the triangle */ + + rd->draw_list_bind_render_pipeline(draw_list, blendseams_line_raster_pipeline); + + for (int j = 1; j < uv_offset_count; j++) { + + seams_push_constant.base_index = seam_offset; + seams_push_constant.uv_offset[0] = uv_offsets[j].x / float(atlas_size.width); + seams_push_constant.uv_offset[1] = uv_offsets[j].y / float(atlas_size.height); + seams_push_constant.blend = uv_offsets[0].z; + + rd->draw_list_set_push_constant(draw_list, &seams_push_constant, sizeof(RasterSeamsPushConstant)); + rd->draw_list_draw(draw_list, false, 1, slice_seam_count[i] * 4); + } + rd->draw_list_end(); + } + seam_offset += slice_seam_count[i]; + triangle_offset += slice_triangle_count[i]; + } + } + +#ifdef DEBUG_TEXTURES + + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + Vector s = rd->texture_get_data(light_accum_tex, i); + Ref img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://5_blendseams" + itos(i) + ".png"); + } +#endif + if (p_step_function) { + p_step_function(0.9, TTR("Retrieving textures"), p_bake_userdata, true); + } + + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + Vector s = rd->texture_get_data(light_accum_tex, i); + Ref img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBH); //remove alpha + bake_textures.push_back(img); + } + + if (probe_positions.size() > 0) { + probe_values.resize(probe_positions.size() * 9); + Vector probe_data = rd->buffer_get_data(light_probe_buffer); + copymem(probe_values.ptrw(), probe_data.ptr(), probe_data.size()); + rd->free(light_probe_buffer); + +#ifdef DEBUG_TEXTURES + { + Ref img2; + img2.instance(); + img2->create(probe_values.size(), 1, false, Image::FORMAT_RGBAF, probe_data); + img2->save_png("res://6_lightprobes.png"); + } +#endif + } + + FREE_TEXTURES + FREE_BUFFERS + FREE_RASTER_RESOURCES + FREE_COMPUTE_RESOURCES + FREE_BLENDSEAMS_RESOURCES + + memdelete(rd); + + return BAKE_OK; +} + +int LightmapperRD::get_bake_texture_count() const { + return bake_textures.size(); +} +Ref LightmapperRD::get_bake_texture(int p_index) const { + ERR_FAIL_INDEX_V(p_index, bake_textures.size(), Ref()); + return bake_textures[p_index]; +} +int LightmapperRD::get_bake_mesh_count() const { + return mesh_instances.size(); +} +Variant LightmapperRD::get_bake_mesh_userdata(int p_index) const { + ERR_FAIL_INDEX_V(p_index, mesh_instances.size(), Variant()); + return mesh_instances[p_index].data.userdata; +} +Rect2 LightmapperRD::get_bake_mesh_uv_scale(int p_index) const { + + ERR_FAIL_COND_V(bake_textures.size() == 0, Rect2()); + Rect2 uv_ofs; + Vector2 atlas_size = Vector2(bake_textures[0]->get_width(), bake_textures[0]->get_height()); + uv_ofs.position = Vector2(mesh_instances[p_index].offset) / atlas_size; + uv_ofs.size = Vector2(mesh_instances[p_index].data.albedo_on_uv2->get_width(), mesh_instances[p_index].data.albedo_on_uv2->get_height()) / atlas_size; + return uv_ofs; +} +int LightmapperRD::get_bake_mesh_texture_slice(int p_index) const { + ERR_FAIL_INDEX_V(p_index, mesh_instances.size(), Variant()); + return mesh_instances[p_index].slice; +} + +int LightmapperRD::get_bake_probe_count() const { + return probe_positions.size(); +} + +Vector3 LightmapperRD::get_bake_probe_point(int p_probe) const { + ERR_FAIL_INDEX_V(p_probe, probe_positions.size(), Variant()); + return Vector3(probe_positions[p_probe].position[0], probe_positions[p_probe].position[1], probe_positions[p_probe].position[2]); +} + +Vector LightmapperRD::get_bake_probe_sh(int p_probe) const { + ERR_FAIL_INDEX_V(p_probe, probe_positions.size(), Vector()); + Vector ret; + ret.resize(9); + copymem(ret.ptrw(), &probe_values[p_probe * 9], sizeof(Color) * 9); + return ret; +} + +LightmapperRD::LightmapperRD() { +} diff --git a/modules/lightmapper_rd/lightmapper_rd.h b/modules/lightmapper_rd/lightmapper_rd.h new file mode 100644 index 0000000000..cb98efbeaa --- /dev/null +++ b/modules/lightmapper_rd/lightmapper_rd.h @@ -0,0 +1,229 @@ +#ifndef LIGHTMAPPER_RD_H +#define LIGHTMAPPER_RD_H + +#include "core/local_vector.h" +#include "scene/3d/lightmapper.h" +#include "scene/resources/mesh.h" +#include "servers/rendering/rendering_device.h" + +class LightmapperRD : public Lightmapper { + GDCLASS(LightmapperRD, Lightmapper) + + struct MeshInstance { + MeshData data; + int slice = 0; + Vector2i offset; + }; + + struct Light { + float position[3]; + uint32_t type = LIGHT_TYPE_DIRECTIONAL; + float direction[3]; + float energy; + float color[3]; + float size; + float range; + float attenuation; + float spot_angle; + float spot_attenuation; + uint32_t static_bake; + uint32_t pad[3]; + + bool operator<(const Light &p_light) const { + return type < p_light.type; + } + }; + + struct Vertex { + float position[3]; + float normal_z; + float uv[2]; + float normal_xy[2]; + + bool operator==(const Vertex &p_vtx) const { + return (position[0] == p_vtx.position[0]) && + (position[1] == p_vtx.position[1]) && + (position[2] == p_vtx.position[2]) && + (uv[0] == p_vtx.uv[0]) && + (uv[1] == p_vtx.uv[1]) && + (normal_xy[0] == p_vtx.normal_xy[0]) && + (normal_xy[1] == p_vtx.normal_xy[1]) && + (normal_z == p_vtx.normal_z); + } + }; + + struct Edge { + Vector3 a; + Vector3 b; + Vector3 na; + Vector3 nb; + bool operator==(const Edge &p_seam) const { + return a == p_seam.a && b == p_seam.b && na == p_seam.na && nb == p_seam.nb; + } + Edge() { + } + + Edge(const Vector3 &p_a, const Vector3 &p_b, const Vector3 &p_na, const Vector3 &p_nb) { + a = p_a; + b = p_b; + na = p_na; + nb = p_nb; + } + }; + + struct Probe { + float position[4]; + }; + + Vector probe_positions; + + struct EdgeHash { + _FORCE_INLINE_ static uint32_t hash(const Edge &p_edge) { + uint32_t h = hash_djb2_one_float(p_edge.a.x); + h = hash_djb2_one_float(p_edge.a.y, h); + h = hash_djb2_one_float(p_edge.a.z, h); + h = hash_djb2_one_float(p_edge.b.x, h); + h = hash_djb2_one_float(p_edge.b.y, h); + h = hash_djb2_one_float(p_edge.b.z, h); + return h; + } + }; + struct EdgeUV2 { + Vector2 a; + Vector2 b; + Vector2i indices; + bool operator==(const EdgeUV2 &p_uv2) const { + return a == p_uv2.a && b == p_uv2.b; + } + bool seam_found = false; + EdgeUV2(Vector2 p_a, Vector2 p_b, Vector2i p_indices) { + a = p_a; + b = p_b; + indices = p_indices; + } + EdgeUV2() {} + }; + + struct Seam { + Vector2i a; + Vector2i b; + uint32_t slice; + bool operator<(const Seam &p_seam) const { + return slice < p_seam.slice; + } + }; + + struct VertexHash { + _FORCE_INLINE_ static uint32_t hash(const Vertex &p_vtx) { + uint32_t h = hash_djb2_one_float(p_vtx.position[0]); + h = hash_djb2_one_float(p_vtx.position[1], h); + h = hash_djb2_one_float(p_vtx.position[2], h); + h = hash_djb2_one_float(p_vtx.uv[0], h); + h = hash_djb2_one_float(p_vtx.uv[1], h); + h = hash_djb2_one_float(p_vtx.normal_xy[0], h); + h = hash_djb2_one_float(p_vtx.normal_xy[1], h); + h = hash_djb2_one_float(p_vtx.normal_z, h); + return h; + } + }; + + struct Box { + float min_bounds[3]; + float pad0; + float max_bounds[3]; + float pad1; + }; + + struct Triangle { + uint32_t indices[3]; + uint32_t slice; + bool operator<(const Triangle &p_triangle) const { + return slice < p_triangle.slice; + } + }; + + Vector mesh_instances; + + Vector lights; + + struct TriangleSort { + uint32_t cell_index; + uint32_t triangle_index; + bool operator<(const TriangleSort &p_triangle_sort) const { + return cell_index < p_triangle_sort.cell_index; //sorting by triangle index in this case makes no sense + } + }; + + void _plot_triangle_into_triangle_index_list(int p_size, const Vector3i &p_ofs, const AABB &p_bounds, const Vector3 p_points[], uint32_t p_triangle_index, LocalVector &triangles, uint32_t p_grid_size); + + struct RasterPushConstant { + float atlas_size[2]; + float uv_offset[2]; + float to_cell_size[3]; + uint32_t base_triangle; + float to_cell_offset[3]; + float bias; + int32_t grid_size[3]; + uint32_t pad2; + }; + + struct RasterSeamsPushConstant { + + uint32_t base_index; + uint32_t slice; + float uv_offset[2]; + uint32_t debug; + float blend; + uint32_t pad[2]; + }; + + struct PushConstant { + int32_t atlas_size[2]; + uint32_t ray_count; + uint32_t ray_to; + + float world_size[3]; + float bias; + + float to_cell_offset[3]; + uint32_t ray_from; + + float to_cell_size[3]; + uint32_t light_count; + + int32_t grid_size; + int32_t atlas_slice; + int32_t region_ofs[2]; + + float environment_xform[12]; + }; + + Vector> bake_textures; + Vector probe_values; + + BakeError _blit_meshes_into_atlas(int p_max_texture_size, Vector> &albedo_images, Vector> &emission_images, AABB &bounds, Size2i &atlas_size, int &atlas_slices, BakeStepFunc p_step_function, void *p_bake_userdata); + void _create_acceleration_structures(RenderingDevice *rd, Size2i atlas_size, int atlas_slices, AABB &bounds, int grid_size, Vector &probe_positions, GenerateProbes p_generate_probes, Vector &slice_triangle_count, Vector &slice_seam_count, RID &vertex_buffer, RID &triangle_buffer, RID &box_buffer, RID &lights_buffer, RID &triangle_cell_indices_buffer, RID &probe_positions_buffer, RID &grid_texture, RID &grid_texture_sdf, RID &seams_buffer, BakeStepFunc p_step_function, void *p_bake_userdata); + void _raster_geometry(RenderingDevice *rd, Size2i atlas_size, int atlas_slices, int grid_size, AABB bounds, float p_bias, Vector slice_triangle_count, RID position_tex, RID unocclude_tex, RID normal_tex, RID raster_depth_buffer, RID rasterize_shader, RID raster_base_uniform); + +public: + virtual void add_mesh(const MeshData &p_mesh); + virtual void add_directional_light(bool p_static, const Vector3 &p_direction, const Color &p_color, float p_energy, float p_angular_distance); + virtual void add_omni_light(bool p_static, const Vector3 &p_position, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_size); + virtual void add_spot_light(bool p_static, const Vector3 &p_position, const Vector3 p_direction, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_spot_angle, float p_spot_attenuation, float p_size); + virtual void add_probe(const Vector3 &p_position); + virtual BakeError bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function = nullptr, void *p_bake_userdata = nullptr); + + int get_bake_texture_count() const; + Ref get_bake_texture(int p_index) const; + int get_bake_mesh_count() const; + Variant get_bake_mesh_userdata(int p_index) const; + Rect2 get_bake_mesh_uv_scale(int p_index) const; + int get_bake_mesh_texture_slice(int p_index) const; + int get_bake_probe_count() const; + Vector3 get_bake_probe_point(int p_probe) const; + Vector get_bake_probe_sh(int p_probe) const; + + LightmapperRD(); +}; + +#endif // LIGHTMAPPER_H diff --git a/modules/lightmapper_rd/lm_blendseams.glsl b/modules/lightmapper_rd/lm_blendseams.glsl new file mode 100644 index 0000000000..ef1ece8ea1 --- /dev/null +++ b/modules/lightmapper_rd/lm_blendseams.glsl @@ -0,0 +1,117 @@ +/* clang-format off */ +[versions] + +lines = "#define MODE_LINES" +triangles = "#define MODE_TRIANGLES" + +[vertex] + +#version 450 + +VERSION_DEFINES + +#include "lm_common_inc.glsl" + + /* clang-format on */ + + layout(push_constant, binding = 0, std430) uniform Params { + uint base_index; + uint slice; + vec2 uv_offset; + bool debug; + float blend; + uint pad[2]; + } params; + +layout(location = 0) out vec3 uv_interp; + +void main() { + +#ifdef MODE_TRIANGLES + + uint triangle_idx = params.base_index + gl_VertexIndex / 3; + uint triangle_subidx = gl_VertexIndex % 3; + + vec2 uv; + if (triangle_subidx == 0) { + uv = vertices.data[triangles.data[triangle_idx].indices.x].uv; + } else if (triangle_subidx == 1) { + uv = vertices.data[triangles.data[triangle_idx].indices.y].uv; + } else { + uv = vertices.data[triangles.data[triangle_idx].indices.z].uv; + } + + uv_interp = vec3(uv, float(params.slice)); + gl_Position = vec4((uv + params.uv_offset) * 2.0 - 1.0, 0.0001, 1.0); + +#endif + +#ifdef MODE_LINES + uint seam_idx = params.base_index + gl_VertexIndex / 4; + uint seam_subidx = gl_VertexIndex % 4; + + uint src_idx; + uint dst_idx; + + if (seam_subidx == 0) { + src_idx = seams.data[seam_idx].b.x; + dst_idx = seams.data[seam_idx].a.x; + } else if (seam_subidx == 1) { + src_idx = seams.data[seam_idx].b.y; + dst_idx = seams.data[seam_idx].a.y; + } else if (seam_subidx == 2) { + src_idx = seams.data[seam_idx].a.x; + dst_idx = seams.data[seam_idx].b.x; + } else if (seam_subidx == 3) { + src_idx = seams.data[seam_idx].a.y; + dst_idx = seams.data[seam_idx].b.y; + } + + vec2 src_uv = vertices.data[src_idx].uv; + vec2 dst_uv = vertices.data[dst_idx].uv + params.uv_offset; + + uv_interp = vec3(src_uv, float(params.slice)); + gl_Position = vec4(dst_uv * 2.0 - 1.0, 0.0001, 1.0); + ; +#endif +} + +/* clang-format off */ +[fragment] + +#version 450 + +VERSION_DEFINES + +#include "lm_common_inc.glsl" + + /* clang-format on */ + + layout(push_constant, binding = 0, std430) uniform Params { + uint base_index; + uint slice; + vec2 uv_offset; + bool debug; + float blend; + uint pad[2]; + } params; + +layout(location = 0) in vec3 uv_interp; + +layout(location = 0) out vec4 dst_color; + +layout(set = 1, binding = 0) uniform texture2DArray src_color_tex; + +void main() { + + if (params.debug) { +#ifdef MODE_TRIANGLES + dst_color = vec4(1, 0, 1, 1); +#else + dst_color = vec4(1, 1, 0, 1); +#endif + } else { + vec4 src_color = textureLod(sampler2DArray(src_color_tex, linear_sampler), uv_interp, 0.0); + dst_color = vec4(src_color.rgb, params.blend); //mix + } +} diff --git a/modules/lightmapper_rd/lm_common_inc.glsl b/modules/lightmapper_rd/lm_common_inc.glsl new file mode 100644 index 0000000000..0ff455936e --- /dev/null +++ b/modules/lightmapper_rd/lm_common_inc.glsl @@ -0,0 +1,92 @@ + +/* SET 0, static data that does not change between any call */ + +struct Vertex { + vec3 position; + float normal_z; + vec2 uv; + vec2 normal_xy; +}; + +layout(set = 0, binding = 1, std430) restrict readonly buffer Vertices { + Vertex data[]; +} +vertices; + +struct Triangle { + uvec3 indices; + uint slice; +}; + +layout(set = 0, binding = 2, std430) restrict readonly buffer Triangles { + Triangle data[]; +} +triangles; + +struct Box { + vec3 min_bounds; + uint pad0; + vec3 max_bounds; + uint pad1; +}; + +layout(set = 0, binding = 3, std430) restrict readonly buffer Boxes { + Box data[]; +} +boxes; + +layout(set = 0, binding = 4, std430) restrict readonly buffer GridIndices { + uint data[]; +} +grid_indices; + +#define LIGHT_TYPE_DIRECTIONAL 0 +#define LIGHT_TYPE_OMNI 1 +#define LIGHT_TYPE_SPOT 2 + +struct Light { + vec3 position; + uint type; + + vec3 direction; + float energy; + + vec3 color; + float size; + + float range; + float attenuation; + float spot_angle; + float spot_attenuation; + + bool static_bake; + uint pad[3]; +}; + +layout(set = 0, binding = 5, std430) restrict readonly buffer Lights { + Light data[]; +} +lights; + +struct Seam { + uvec2 a; + uvec2 b; +}; + +layout(set = 0, binding = 6, std430) restrict readonly buffer Seams { + Seam data[]; +} +seams; + +layout(set = 0, binding = 7, std430) restrict readonly buffer Probes { + vec4 data[]; +} +probe_positions; + +layout(set = 0, binding = 8) uniform utexture3D grid; +layout(set = 0, binding = 9) uniform texture3D grid_sdf; + +layout(set = 0, binding = 10) uniform texture2DArray albedo_tex; +layout(set = 0, binding = 11) uniform texture2DArray emission_tex; + +layout(set = 0, binding = 12) uniform sampler linear_sampler; diff --git a/modules/lightmapper_rd/lm_compute.glsl b/modules/lightmapper_rd/lm_compute.glsl new file mode 100644 index 0000000000..a178bd9b2e --- /dev/null +++ b/modules/lightmapper_rd/lm_compute.glsl @@ -0,0 +1,657 @@ +/* clang-format off */ +[versions] + +primary = "#define MODE_DIRECT_LIGHT" +secondary = "#define MODE_BOUNCE_LIGHT" +dilate = "#define MODE_DILATE" +unocclude = "#define MODE_UNOCCLUDE" +light_probes = "#define MODE_LIGHT_PROBES" + +[compute] + +#version 450 + +VERSION_DEFINES + +// One 2D local group focusing in one layer at a time, though all +// in parallel (no barriers) makes more sense than a 3D local group +// as this can take more advantage of the cache for each group. + +#ifdef MODE_LIGHT_PROBES + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +#else + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +#endif + +#include "lm_common_inc.glsl" + +/* clang-format on */ + +#ifdef MODE_LIGHT_PROBES + +layout(set = 1, binding = 0, std430) restrict buffer LightProbeData { + vec4 data[]; +} +light_probes; + +layout(set = 1, binding = 1) uniform texture2DArray source_light; +layout(set = 1, binding = 2) uniform texture2DArray source_direct_light; //also need the direct light, which was omitted +layout(set = 1, binding = 3) uniform texture2D environment; +#endif + +#ifdef MODE_UNOCCLUDE + +layout(rgba32f, set = 1, binding = 0) uniform restrict image2DArray position; +layout(rgba32f, set = 1, binding = 1) uniform restrict readonly image2DArray unocclude; + +#endif + +#if defined(MODE_DIRECT_LIGHT) || defined(MODE_BOUNCE_LIGHT) + +layout(rgba16f, set = 1, binding = 0) uniform restrict writeonly image2DArray dest_light; +layout(set = 1, binding = 1) uniform texture2DArray source_light; +layout(set = 1, binding = 2) uniform texture2DArray source_position; +layout(set = 1, binding = 3) uniform texture2DArray source_normal; +layout(rgba16f, set = 1, binding = 4) uniform restrict image2DArray accum_light; + +#endif + +#ifdef MODE_BOUNCE_LIGHT +layout(rgba32f, set = 1, binding = 5) uniform restrict image2DArray bounce_accum; +layout(set = 1, binding = 6) uniform texture2D environment; +#endif +#ifdef MODE_DIRECT_LIGHT +layout(rgba32f, set = 1, binding = 5) uniform restrict writeonly image2DArray primary_dynamic; +#endif + +#ifdef MODE_DILATE +layout(rgba16f, set = 1, binding = 0) uniform restrict writeonly image2DArray dest_light; +layout(set = 1, binding = 1) uniform texture2DArray source_light; +#endif + +layout(push_constant, binding = 0, std430) uniform Params { + ivec2 atlas_size; // x used for light probe mode total probes + uint ray_count; + uint ray_to; + + vec3 world_size; + float bias; + + vec3 to_cell_offset; + uint ray_from; + + vec3 to_cell_size; + uint light_count; + + int grid_size; + int atlas_slice; + ivec2 region_ofs; + + mat3x4 env_transform; +} +params; + +//check it, but also return distance and barycentric coords (for uv lookup) +bool ray_hits_triangle(vec3 from, vec3 dir, float max_dist, vec3 p0, vec3 p1, vec3 p2, out float r_distance, out vec3 r_barycentric) { + + const vec3 e0 = p1 - p0; + const vec3 e1 = p0 - p2; + vec3 triangleNormal = cross(e1, e0); + + const vec3 e2 = (1.0 / dot(triangleNormal, dir)) * (p0 - from); + const vec3 i = cross(dir, e2); + + r_barycentric.y = dot(i, e1); + r_barycentric.z = dot(i, e0); + r_barycentric.x = 1.0 - (r_barycentric.z + r_barycentric.y); + r_distance = dot(triangleNormal, e2); + return (r_distance > params.bias) && (r_distance < max_dist) && all(greaterThanEqual(r_barycentric, vec3(0.0))); +} + +bool trace_ray(vec3 p_from, vec3 p_to +#if defined(MODE_BOUNCE_LIGHT) || defined(MODE_LIGHT_PROBES) + , + out uint r_triangle, out vec3 r_barycentric +#endif +#if defined(MODE_UNOCCLUDE) + , + out float r_distance, out vec3 r_normal +#endif +) { + + /* world coords */ + + vec3 rel = p_to - p_from; + float rel_len = length(rel); + vec3 dir = normalize(rel); + vec3 inv_dir = 1.0 / dir; + + /* cell coords */ + + vec3 from_cell = (p_from - params.to_cell_offset) * params.to_cell_size; + vec3 to_cell = (p_to - params.to_cell_offset) * params.to_cell_size; + + //prepare DDA + vec3 rel_cell = to_cell - from_cell; + ivec3 icell = ivec3(from_cell); + ivec3 iendcell = ivec3(to_cell); + vec3 dir_cell = normalize(rel_cell); + vec3 delta = abs(1.0 / dir_cell); //vec3(length(rel_cell)) / rel_cell); + ivec3 step = ivec3(sign(rel_cell)); + vec3 side = (sign(rel_cell) * (vec3(icell) - from_cell) + (sign(rel_cell) * 0.5) + 0.5) * delta; + + uint iters = 0; + while (all(greaterThanEqual(icell, ivec3(0))) && all(lessThan(icell, ivec3(params.grid_size))) && iters < 1000) { + + uvec2 cell_data = texelFetch(usampler3D(grid, linear_sampler), icell, 0).xy; + if (cell_data.x > 0) { //triangles here + + bool hit = false; +#if defined(MODE_UNOCCLUDE) + bool hit_backface = false; +#endif + float best_distance = 1e20; + + for (uint i = 0; i < cell_data.x; i++) { + uint tidx = grid_indices.data[cell_data.y + i]; + + //Ray-Box test + vec3 t0 = (boxes.data[tidx].min_bounds - p_from) * inv_dir; + vec3 t1 = (boxes.data[tidx].max_bounds - p_from) * inv_dir; + vec3 tmin = min(t0, t1), tmax = max(t0, t1); + + if (max(tmin.x, max(tmin.y, tmin.z)) <= min(tmax.x, min(tmax.y, tmax.z))) { + continue; //ray box failed + } + + //prepare triangle vertices + vec3 vtx0 = vertices.data[triangles.data[tidx].indices.x].position; + vec3 vtx1 = vertices.data[triangles.data[tidx].indices.y].position; + vec3 vtx2 = vertices.data[triangles.data[tidx].indices.z].position; +#if defined(MODE_UNOCCLUDE) + vec3 normal = -normalize(cross((vtx0 - vtx1), (vtx0 - vtx2))); + + bool backface = dot(normal, dir) >= 0.0; +#endif + float distance; + vec3 barycentric; + + if (ray_hits_triangle(p_from, dir, rel_len, vtx0, vtx1, vtx2, distance, barycentric)) { +#ifdef MODE_DIRECT_LIGHT + return true; //any hit good +#endif + +#if defined(MODE_UNOCCLUDE) + if (!backface) { + // the case of meshes having both a front and back face in the same plane is more common than + // expected, so if this is a front-face, bias it closer to the ray origin, so it always wins over the back-face + distance = max(params.bias, distance - params.bias); + } + + hit = true; + + if (distance < best_distance) { + hit_backface = backface; + best_distance = distance; + r_distance = distance; + r_normal = normal; + } + +#endif + +#if defined(MODE_BOUNCE_LIGHT) || defined(MODE_LIGHT_PROBES) + + hit = true; + if (distance < best_distance) { + best_distance = distance; + r_triangle = tidx; + r_barycentric = barycentric; + } + +#endif + } + } +#if defined(MODE_UNOCCLUDE) + + if (hit) { + return hit_backface; + } +#endif +#if defined(MODE_BOUNCE_LIGHT) || defined(MODE_LIGHT_PROBES) + if (hit) { + return true; + } +#endif + } + + if (icell == iendcell) { + break; + } + + bvec3 mask = lessThanEqual(side.xyz, min(side.yzx, side.zxy)); + side += vec3(mask) * delta; + icell += ivec3(vec3(mask)) * step; + + iters++; + } + + return false; +} + +const float PI = 3.14159265f; +const float GOLDEN_ANGLE = PI * (3.0 - sqrt(5.0)); + +vec3 vogel_hemisphere(uint p_index, uint p_count, float p_offset) { + float r = sqrt(float(p_index) + 0.5f) / sqrt(float(p_count)); + float theta = float(p_index) * GOLDEN_ANGLE + p_offset; + float y = cos(r * PI * 0.5); + float l = sin(r * PI * 0.5); + return vec3(l * cos(theta), l * sin(theta), y); +} + +float quick_hash(vec2 pos) { + return fract(sin(dot(pos * 19.19, vec2(49.5791, 97.413))) * 49831.189237); +} + +void main() { + +#ifdef MODE_LIGHT_PROBES + int probe_index = int(gl_GlobalInvocationID.x); + if (probe_index >= params.atlas_size.x) { //too large, do nothing + return; + } + +#else + ivec2 atlas_pos = ivec2(gl_GlobalInvocationID.xy) + params.region_ofs; + if (any(greaterThanEqual(atlas_pos, params.atlas_size))) { //too large, do nothing + return; + } +#endif + +#ifdef MODE_DIRECT_LIGHT + + vec3 normal = texelFetch(sampler2DArray(source_normal, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).xyz; + if (length(normal) < 0.5) { + return; //empty texel, no process + } + vec3 position = texelFetch(sampler2DArray(source_position, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).xyz; + + //go through all lights + //start by own light (emissive) + vec3 static_light = vec3(0.0); + vec3 dynamic_light = vec3(0.0); + +#ifdef USE_SH_LIGHTMAPS + vec4 sh_accum[4] = vec4[]( + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0)); +#endif + + for (uint i = 0; i < params.light_count; i++) { + + vec3 light_pos; + float attenuation; + if (lights.data[i].type == LIGHT_TYPE_DIRECTIONAL) { + vec3 light_vec = lights.data[i].direction; + light_pos = position - light_vec * length(params.world_size); + attenuation = 1.0; + } else { + light_pos = lights.data[i].position; + float d = distance(position, light_pos); + if (d > lights.data[i].range) { + continue; + } + + d /= lights.data[i].range; + + attenuation = pow(max(1.0 - d, 0.0), lights.data[i].attenuation); + + if (lights.data[i].type == LIGHT_TYPE_SPOT) { + + vec3 rel = normalize(position - light_pos); + float angle = acos(dot(rel, lights.data[i].direction)); + if (angle > lights.data[i].spot_angle) { + continue; //invisible, dont try + } + + float d = clamp(angle / lights.data[i].spot_angle, 0, 1); + attenuation *= pow(1.0 - d, lights.data[i].spot_attenuation); + } + } + + vec3 light_dir = normalize(light_pos - position); + attenuation *= max(0.0, dot(normal, light_dir)); + + if (attenuation <= 0.0001) { + continue; //no need to do anything + } + + if (!trace_ray(position + light_dir * params.bias, light_pos)) { + vec3 light = lights.data[i].color * lights.data[i].energy * attenuation; + if (lights.data[i].static_bake) { + static_light += light; +#ifdef USE_SH_LIGHTMAPS + + float c[4] = float[]( + 0.282095, //l0 + 0.488603 * light_dir.y, //l1n1 + 0.488603 * light_dir.z, //l1n0 + 0.488603 * light_dir.x //l1p1 + ); + + for (uint j = 0; j < 4; j++) { + sh_accum[j].rgb += light * c[j] * (1.0 / 3.0); + } +#endif + + } else { + dynamic_light += light; + } + } + } + + vec3 albedo = texelFetch(sampler2DArray(albedo_tex, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).rgb; + vec3 emissive = texelFetch(sampler2DArray(emission_tex, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).rgb; + + dynamic_light *= albedo; //if it will bounce, must multiply by albedo + dynamic_light += emissive; + + //keep for lightprobes + imageStore(primary_dynamic, ivec3(atlas_pos, params.atlas_slice), vec4(dynamic_light, 1.0)); + + dynamic_light += static_light * albedo; //send for bounces + imageStore(dest_light, ivec3(atlas_pos, params.atlas_slice), vec4(dynamic_light, 1.0)); + +#ifdef USE_SH_LIGHTMAPS + //keep for adding at the end + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + 0), sh_accum[0]); + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + 1), sh_accum[1]); + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + 2), sh_accum[2]); + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + 3), sh_accum[3]); + +#else + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice), vec4(static_light, 1.0)); +#endif + +#endif + +#ifdef MODE_BOUNCE_LIGHT + + vec3 normal = texelFetch(sampler2DArray(source_normal, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).xyz; + if (length(normal) < 0.5) { + return; //empty texel, no process + } + + vec3 position = texelFetch(sampler2DArray(source_position, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).xyz; + + vec3 v0 = abs(normal.z) < 0.999 ? vec3(0.0, 0.0, 1.0) : vec3(0.0, 1.0, 0.0); + vec3 tangent = normalize(cross(v0, normal)); + vec3 bitangent = normalize(cross(tangent, normal)); + mat3 normal_mat = mat3(tangent, bitangent, normal); + +#ifdef USE_SH_LIGHTMAPS + vec4 sh_accum[4] = vec4[]( + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0)); +#endif + vec3 light_average = vec3(0.0); + for (uint i = params.ray_from; i < params.ray_to; i++) { + vec3 ray_dir = normal_mat * vogel_hemisphere(i, params.ray_count, quick_hash(vec2(atlas_pos))); + + uint tidx; + vec3 barycentric; + + vec3 light; + if (trace_ray(position + ray_dir * params.bias, position + ray_dir * length(params.world_size), tidx, barycentric)) { + //hit a triangle + vec2 uv0 = vertices.data[triangles.data[tidx].indices.x].uv; + vec2 uv1 = vertices.data[triangles.data[tidx].indices.y].uv; + vec2 uv2 = vertices.data[triangles.data[tidx].indices.z].uv; + vec3 uvw = vec3(barycentric.x * uv0 + barycentric.y * uv1 + barycentric.z * uv2, float(triangles.data[tidx].slice)); + + light = textureLod(sampler2DArray(source_light, linear_sampler), uvw, 0.0).rgb; + } else { + //did not hit a triangle, reach out for the sky + vec3 sky_dir = normalize(mat3(params.env_transform) * ray_dir); + + vec2 st = vec2( + atan(sky_dir.x, sky_dir.z), + acos(sky_dir.y)); + + if (st.x < 0.0) + st.x += PI * 2.0; + + st /= vec2(PI * 2.0, PI); + + light = textureLod(sampler2D(environment, linear_sampler), st, 0.0).rgb; + } + + light_average += light; + +#ifdef USE_SH_LIGHTMAPS + + float c[4] = float[]( + 0.282095, //l0 + 0.488603 * ray_dir.y, //l1n1 + 0.488603 * ray_dir.z, //l1n0 + 0.488603 * ray_dir.x //l1p1 + ); + + for (uint j = 0; j < 4; j++) { + sh_accum[j].rgb += light * c[j] * (8.0 / float(params.ray_count)); + } +#endif + } + + vec3 light_total; + if (params.ray_from == 0) { + light_total = vec3(0.0); + } else { + light_total = imageLoad(bounce_accum, ivec3(atlas_pos, params.atlas_slice)).rgb; + } + + light_total += light_average; + +#ifdef USE_SH_LIGHTMAPS + + for (int i = 0; i < 4; i++) { + vec4 accum = imageLoad(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + i)); + accum.rgb += sh_accum[i].rgb; + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + i), accum); + } + +#endif + if (params.ray_to == params.ray_count) { + light_total /= float(params.ray_count); + imageStore(dest_light, ivec3(atlas_pos, params.atlas_slice), vec4(light_total, 1.0)); +#ifndef USE_SH_LIGHTMAPS + vec4 accum = imageLoad(accum_light, ivec3(atlas_pos, params.atlas_slice)); + accum.rgb += light_total; + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice), accum); +#endif + } else { + imageStore(bounce_accum, ivec3(atlas_pos, params.atlas_slice), vec4(light_total, 1.0)); + } + +#endif + +#ifdef MODE_UNOCCLUDE + + //texel_size = 0.5; + //compute tangents + + vec4 position_alpha = imageLoad(position, ivec3(atlas_pos, params.atlas_slice)); + if (position_alpha.a < 0.5) { + return; + } + + vec3 vertex_pos = position_alpha.xyz; + vec4 normal_tsize = imageLoad(unocclude, ivec3(atlas_pos, params.atlas_slice)); + + vec3 face_normal = normal_tsize.xyz; + float texel_size = normal_tsize.w; + + vec3 v0 = abs(face_normal.z) < 0.999 ? vec3(0.0, 0.0, 1.0) : vec3(0.0, 1.0, 0.0); + vec3 tangent = normalize(cross(v0, face_normal)); + vec3 bitangent = normalize(cross(tangent, face_normal)); + vec3 base_pos = vertex_pos + face_normal * params.bias; //raise a bit + + vec3 rays[4] = vec3[](tangent, bitangent, -tangent, -bitangent); + float min_d = 1e20; + for (int i = 0; i < 4; i++) { + vec3 ray_to = base_pos + rays[i] * texel_size; + float d; + vec3 norm; + + if (trace_ray(base_pos, ray_to, d, norm)) { + + if (d < min_d) { + vertex_pos = base_pos + rays[i] * d + norm * params.bias * 10.0; //this bias needs to be greater than the regular bias, because otherwise later, rays will go the other side when pointing back. + min_d = d; + } + } + } + + position_alpha.xyz = vertex_pos; + + imageStore(position, ivec3(atlas_pos, params.atlas_slice), position_alpha); + +#endif + +#ifdef MODE_LIGHT_PROBES + + vec3 position = probe_positions.data[probe_index].xyz; + + vec4 probe_sh_accum[9] = vec4[]( + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0)); + + for (uint i = params.ray_from; i < params.ray_to; i++) { + vec3 ray_dir = vogel_hemisphere(i, params.ray_count, quick_hash(vec2(float(probe_index), 0.0))); + if (bool(i & 1)) { + //throw to both sides, so alternate them + ray_dir.z *= -1.0; + } + + uint tidx; + vec3 barycentric; + vec3 light; + + if (trace_ray(position + ray_dir * params.bias, position + ray_dir * length(params.world_size), tidx, barycentric)) { + vec2 uv0 = vertices.data[triangles.data[tidx].indices.x].uv; + vec2 uv1 = vertices.data[triangles.data[tidx].indices.y].uv; + vec2 uv2 = vertices.data[triangles.data[tidx].indices.z].uv; + vec3 uvw = vec3(barycentric.x * uv0 + barycentric.y * uv1 + barycentric.z * uv2, float(triangles.data[tidx].slice)); + + light = textureLod(sampler2DArray(source_light, linear_sampler), uvw, 0.0).rgb; + light += textureLod(sampler2DArray(source_direct_light, linear_sampler), uvw, 0.0).rgb; + } else { + + //did not hit a triangle, reach out for the sky + vec3 sky_dir = normalize(mat3(params.env_transform) * ray_dir); + + vec2 st = vec2( + atan(sky_dir.x, sky_dir.z), + acos(sky_dir.y)); + + if (st.x < 0.0) + st.x += PI * 2.0; + + st /= vec2(PI * 2.0, PI); + + light = textureLod(sampler2D(environment, linear_sampler), st, 0.0).rgb; + } + + { + float c[9] = float[]( + 0.282095, //l0 + 0.488603 * ray_dir.y, //l1n1 + 0.488603 * ray_dir.z, //l1n0 + 0.488603 * ray_dir.x, //l1p1 + 1.092548 * ray_dir.x * ray_dir.y, //l2n2 + 1.092548 * ray_dir.y * ray_dir.z, //l2n1 + //0.315392 * (ray_dir.x * ray_dir.x + ray_dir.y * ray_dir.y + 2.0 * ray_dir.z * ray_dir.z), //l20 + 0.315392 * (3.0 * ray_dir.z * ray_dir.z - 1.0), //l20 + 1.092548 * ray_dir.x * ray_dir.z, //l2p1 + 0.546274 * (ray_dir.x * ray_dir.x - ray_dir.y * ray_dir.y) //l2p2 + ); + + for (uint j = 0; j < 9; j++) { + probe_sh_accum[j].rgb += light * c[j]; + } + } + } + + if (params.ray_from > 0) { + for (uint j = 0; j < 9; j++) { //accum from existing + probe_sh_accum[j] += light_probes.data[probe_index * 9 + j]; + } + } + + if (params.ray_to == params.ray_count) { + for (uint j = 0; j < 9; j++) { //accum from existing + probe_sh_accum[j] *= 4.0 / float(params.ray_count); + } + } + + for (uint j = 0; j < 9; j++) { //accum from existing + light_probes.data[probe_index * 9 + j] = probe_sh_accum[j]; + } + +#endif + +#ifdef MODE_DILATE + + vec4 c = texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0); + //sides first, as they are closer + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-1, 0), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(0, 1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(1, 0), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(0, -1), params.atlas_slice), 0); + //endpoints second + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-1, -1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-1, 1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(1, -1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(1, 1), params.atlas_slice), 0); + + //far sides third + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-2, 0), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(0, 2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(2, 0), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(0, -2), params.atlas_slice), 0); + + //far-mid endpoints + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-2, -1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-2, 1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(2, -1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(2, 1), params.atlas_slice), 0); + + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-1, -2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-1, 2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(1, -2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(1, 2), params.atlas_slice), 0); + //far endpoints + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-2, -2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-2, 2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(2, -2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(2, 2), params.atlas_slice), 0); + + imageStore(dest_light, ivec3(atlas_pos, params.atlas_slice), c); + +#endif +} diff --git a/modules/lightmapper_rd/lm_raster.glsl b/modules/lightmapper_rd/lm_raster.glsl new file mode 100644 index 0000000000..ae3038aead --- /dev/null +++ b/modules/lightmapper_rd/lm_raster.glsl @@ -0,0 +1,170 @@ +/* clang-format off */ +[vertex] + +#version 450 + +VERSION_DEFINES + +#include "lm_common_inc.glsl" + + /* clang-format on */ + + layout(location = 0) out vec3 vertex_interp; +layout(location = 1) out vec3 normal_interp; +layout(location = 2) out vec2 uv_interp; +layout(location = 3) out vec3 barycentric; +layout(location = 4) flat out uvec3 vertex_indices; +layout(location = 5) flat out vec3 face_normal; + +layout(push_constant, binding = 0, std430) uniform Params { + vec2 atlas_size; + vec2 uv_offset; + vec3 to_cell_size; + uint base_triangle; + vec3 to_cell_offset; + float bias; + ivec3 grid_size; + uint pad2; +} +params; + +/* clang-format on */ + +void main() { + + uint triangle_idx = params.base_triangle + gl_VertexIndex / 3; + uint triangle_subidx = gl_VertexIndex % 3; + + vertex_indices = triangles.data[triangle_idx].indices; + + uint vertex_idx; + if (triangle_subidx == 0) { + vertex_idx = vertex_indices.x; + barycentric = vec3(1, 0, 0); + } else if (triangle_subidx == 1) { + vertex_idx = vertex_indices.y; + barycentric = vec3(0, 1, 0); + } else { + vertex_idx = vertex_indices.z; + barycentric = vec3(0, 0, 1); + } + + vertex_interp = vertices.data[vertex_idx].position; + uv_interp = vertices.data[vertex_idx].uv; + normal_interp = vec3(vertices.data[vertex_idx].normal_xy, vertices.data[vertex_idx].normal_z); + + face_normal = -normalize(cross((vertices.data[vertex_indices.x].position - vertices.data[vertex_indices.y].position), (vertices.data[vertex_indices.x].position - vertices.data[vertex_indices.z].position))); + + gl_Position = vec4((uv_interp + params.uv_offset) * 2.0 - 1.0, 0.0001, 1.0); + ; +} + +/* clang-format off */ + +[fragment] + +#version 450 + +VERSION_DEFINES + +#include "lm_common_inc.glsl" + + +layout(push_constant, binding = 0, std430) uniform Params { + vec2 atlas_size; + vec2 uv_offset; + vec3 to_cell_size; + uint base_triangle; + vec3 to_cell_offset; + float bias; + ivec3 grid_size; + uint pad2; +} params; + +/* clang-format on */ + +layout(location = 0) in vec3 vertex_interp; +layout(location = 1) in vec3 normal_interp; +layout(location = 2) in vec2 uv_interp; +layout(location = 3) in vec3 barycentric; +layout(location = 4) in flat uvec3 vertex_indices; +layout(location = 5) in flat vec3 face_normal; + +layout(location = 0) out vec4 position; +layout(location = 1) out vec4 normal; +layout(location = 2) out vec4 unocclude; + +void main() { + + vec3 vertex_pos = vertex_interp; + + { + // smooth out vertex position by interpolating its projection in the 3 normal planes (normal plane is created by vertex pos and normal) + // because we don't want to interpolate inwards, normals found pointing inwards are pushed out. + + vec3 pos_a = vertices.data[vertex_indices.x].position; + vec3 pos_b = vertices.data[vertex_indices.y].position; + vec3 pos_c = vertices.data[vertex_indices.z].position; + vec3 center = (pos_a + pos_b + pos_c) * 0.3333333; + vec3 norm_a = vec3(vertices.data[vertex_indices.x].normal_xy, vertices.data[vertex_indices.x].normal_z); + vec3 norm_b = vec3(vertices.data[vertex_indices.y].normal_xy, vertices.data[vertex_indices.y].normal_z); + vec3 norm_c = vec3(vertices.data[vertex_indices.z].normal_xy, vertices.data[vertex_indices.z].normal_z); + + { + vec3 dir_a = normalize(pos_a - center); + float d_a = dot(dir_a, norm_a); + if (d_a < 0) { + //pointing inwards + norm_a = normalize(norm_a - dir_a * d_a); + } + } + { + vec3 dir_b = normalize(pos_b - center); + float d_b = dot(dir_b, norm_b); + if (d_b < 0) { + //pointing inwards + norm_b = normalize(norm_b - dir_b * d_b); + } + } + { + vec3 dir_c = normalize(pos_c - center); + float d_c = dot(dir_c, norm_c); + if (d_c < 0) { + //pointing inwards + norm_c = normalize(norm_c - dir_c * d_c); + } + } + + float d_a = dot(norm_a, pos_a); + float d_b = dot(norm_b, pos_b); + float d_c = dot(norm_c, pos_c); + + vec3 proj_a = vertex_pos - norm_a * (dot(norm_a, vertex_pos) - d_a); + vec3 proj_b = vertex_pos - norm_b * (dot(norm_b, vertex_pos) - d_b); + vec3 proj_c = vertex_pos - norm_c * (dot(norm_c, vertex_pos) - d_c); + + vec3 smooth_position = proj_a * barycentric.x + proj_b * barycentric.y + proj_c * barycentric.z; + + if (dot(face_normal, smooth_position) > dot(face_normal, vertex_pos)) { //only project outwards + vertex_pos = smooth_position; + } + } + + { + // unocclusion technique based on: + // https://ndotl.wordpress.com/2018/08/29/baking-artifact-free-lightmaps/ + + /* compute texel size */ + vec3 delta_uv = max(abs(dFdx(vertex_interp)), abs(dFdy(vertex_interp))); + float texel_size = max(delta_uv.x, max(delta_uv.y, delta_uv.z)); + texel_size *= sqrt(2.0); //expand to unit box edge length (again, worst case) + + unocclude.xyz = face_normal; + unocclude.w = texel_size; + + //continued on lm_compute.glsl + } + + position = vec4(vertex_pos, 1.0); + normal = vec4(normalize(normal_interp), 1.0); +} diff --git a/modules/lightmapper_rd/register_types.cpp b/modules/lightmapper_rd/register_types.cpp new file mode 100644 index 0000000000..f3938f3190 --- /dev/null +++ b/modules/lightmapper_rd/register_types.cpp @@ -0,0 +1,64 @@ +/*************************************************************************/ +/* register_types.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "register_types.h" + +#include "core/project_settings.h" +#include "lightmapper_rd.h" +#include "scene/3d/lightmapper.h" + +#ifndef _3D_DISABLED +static Lightmapper *create_lightmapper_rd() { + return memnew(LightmapperRD); +} +#endif + +void register_lightmapper_rd_types() { + + GLOBAL_DEF("rendering/gpu_lightmapper/quality/low_quality_ray_count", 16); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/medium_quality_ray_count", 64); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/high_quality_ray_count", 256); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/ultra_quality_ray_count", 1024); + GLOBAL_DEF("rendering/gpu_lightmapper/performance/max_rays_per_pass", 32); + GLOBAL_DEF("rendering/gpu_lightmapper/performance/region_size", 512); + + GLOBAL_DEF("rendering/gpu_lightmapper/quality/low_quality_probe_ray_count", 64); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/medium_quality_probe_ray_count", 256); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/high_quality_probe_ray_count", 512); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/ultra_quality_probe_ray_count", 2048); + GLOBAL_DEF("rendering/gpu_lightmapper/performance/max_rays_per_probe_pass", 64); +#ifndef _3D_DISABLED + ClassDB::register_class(); + Lightmapper::create_gpu = create_lightmapper_rd; +#endif +} + +void unregister_lightmapper_rd_types() { +} diff --git a/modules/lightmapper_rd/register_types.h b/modules/lightmapper_rd/register_types.h new file mode 100644 index 0000000000..b0e15a927f --- /dev/null +++ b/modules/lightmapper_rd/register_types.h @@ -0,0 +1,37 @@ +/*************************************************************************/ +/* register_types.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#ifndef LIGHTMAPPER_RD_REGISTER_TYPES_H +#define LIGHTMAPPER_RD_REGISTER_TYPES_H + +void register_lightmapper_rd_types(); +void unregister_lightmapper_rd_types(); + +#endif // XATLAS_UNWRAP_REGISTER_TYPES_H diff --git a/modules/tinyexr/image_saver_tinyexr.cpp b/modules/tinyexr/image_saver_tinyexr.cpp index 05080289bd..bc30f4e4fd 100644 --- a/modules/tinyexr/image_saver_tinyexr.cpp +++ b/modules/tinyexr/image_saver_tinyexr.cpp @@ -267,13 +267,21 @@ Error save_exr(const String &p_path, const Ref &p_img, bool p_grayscale) header.channels = channel_infos; header.pixel_types = pixel_types; header.requested_pixel_types = requested_pixel_types; + header.compression_type = TINYEXR_COMPRESSIONTYPE_PIZ; - CharString utf8_filename = p_path.utf8(); - const char *err; - int ret = SaveEXRImageToFile(&image, &header, utf8_filename.ptr(), &err); - if (ret != TINYEXR_SUCCESS) { + unsigned char *mem = nullptr; + const char *err = nullptr; + + size_t bytes = SaveEXRImageToMemory(&image, &header, &mem, &err); + + if (bytes == 0) { print_error(String("Saving EXR failed. Error: {0}").format(varray(err))); return ERR_FILE_CANT_WRITE; + } else { + FileAccessRef ref = FileAccess::open(p_path, FileAccess::WRITE); + ERR_FAIL_COND_V(!ref, ERR_FILE_CANT_WRITE); + ref->store_buffer(mem, bytes); + free(mem); } return OK; diff --git a/modules/xatlas_unwrap/register_types.cpp b/modules/xatlas_unwrap/register_types.cpp index 8c5525bed3..f77646ce28 100644 --- a/modules/xatlas_unwrap/register_types.cpp +++ b/modules/xatlas_unwrap/register_types.cpp @@ -137,6 +137,7 @@ bool xatlas_mesh_lightmap_unwrap_callback(float p_texel_size, const float *p_ver pack_options.maxChartSize = 4096; pack_options.blockAlign = true; + pack_options.padding = 1; pack_options.texelsPerUnit = 1.0 / p_texel_size; xatlas::Atlas *atlas = xatlas::Create(); diff --git a/platform/uwp/export/export.cpp b/platform/uwp/export/export.cpp index f6618bfb47..bee1ddfc99 100644 --- a/platform/uwp/export/export.cpp +++ b/platform/uwp/export/export.cpp @@ -750,7 +750,7 @@ class EditorExportPlatformUWP : public EditorExportPlatform { return false; } - bool _valid_image(const StreamTexture *p_image, int p_width, int p_height) const { + bool _valid_image(const StreamTexture2D *p_image, int p_width, int p_height) const { if (!p_image) { return false; @@ -887,22 +887,22 @@ class EditorExportPlatformUWP : public EditorExportPlatform { Vector _get_image_data(const Ref &p_preset, const String &p_path) { Vector data; - StreamTexture *image = nullptr; + StreamTexture2D *image = nullptr; if (p_path.find("StoreLogo") != -1) { - image = p_preset->get("images/store_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/store_logo"))); + image = p_preset->get("images/store_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/store_logo"))); } else if (p_path.find("Square44x44Logo") != -1) { - image = p_preset->get("images/square44x44_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/square44x44_logo"))); + image = p_preset->get("images/square44x44_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/square44x44_logo"))); } else if (p_path.find("Square71x71Logo") != -1) { - image = p_preset->get("images/square71x71_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/square71x71_logo"))); + image = p_preset->get("images/square71x71_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/square71x71_logo"))); } else if (p_path.find("Square150x150Logo") != -1) { - image = p_preset->get("images/square150x150_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/square150x150_logo"))); + image = p_preset->get("images/square150x150_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/square150x150_logo"))); } else if (p_path.find("Square310x310Logo") != -1) { - image = p_preset->get("images/square310x310_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/square310x310_logo"))); + image = p_preset->get("images/square310x310_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/square310x310_logo"))); } else if (p_path.find("Wide310x150Logo") != -1) { - image = p_preset->get("images/wide310x150_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/wide310x150_logo"))); + image = p_preset->get("images/wide310x150_logo").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/wide310x150_logo"))); } else if (p_path.find("SplashScreen") != -1) { - image = p_preset->get("images/splash_screen").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/splash_screen"))); + image = p_preset->get("images/splash_screen").is_zero() ? nullptr : Object::cast_to(((Object *)p_preset->get("images/splash_screen"))); } else { ERR_PRINT("Unable to load logo"); } @@ -1066,13 +1066,13 @@ public: r_options->push_back(ExportOption(PropertyInfo(Variant::BOOL, "orientation/portrait_flipped"), true)); r_options->push_back(ExportOption(PropertyInfo(Variant::STRING, "images/background_color"), "transparent")); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/store_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square44x44_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square71x71_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square150x150_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square310x310_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/wide310x150_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/splash_screen", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/store_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square44x44_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square71x71_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square150x150_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square310x310_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/wide310x150_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/splash_screen", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); r_options->push_back(ExportOption(PropertyInfo(Variant::BOOL, "tiles/show_name_on_square150x150"), false)); r_options->push_back(ExportOption(PropertyInfo(Variant::BOOL, "tiles/show_name_on_wide310x150"), false)); @@ -1173,37 +1173,37 @@ public: err += TTR("Invalid background color.") + "\n"; } - if (!p_preset->get("images/store_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/store_logo"))), 50, 50)) { + if (!p_preset->get("images/store_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/store_logo"))), 50, 50)) { valid = false; err += TTR("Invalid Store Logo image dimensions (should be 50x50).") + "\n"; } - if (!p_preset->get("images/square44x44_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/square44x44_logo"))), 44, 44)) { + if (!p_preset->get("images/square44x44_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/square44x44_logo"))), 44, 44)) { valid = false; err += TTR("Invalid square 44x44 logo image dimensions (should be 44x44).") + "\n"; } - if (!p_preset->get("images/square71x71_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/square71x71_logo"))), 71, 71)) { + if (!p_preset->get("images/square71x71_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/square71x71_logo"))), 71, 71)) { valid = false; err += TTR("Invalid square 71x71 logo image dimensions (should be 71x71).") + "\n"; } - if (!p_preset->get("images/square150x150_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/square150x150_logo"))), 150, 150)) { + if (!p_preset->get("images/square150x150_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/square150x150_logo"))), 150, 150)) { valid = false; err += TTR("Invalid square 150x150 logo image dimensions (should be 150x150).") + "\n"; } - if (!p_preset->get("images/square310x310_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/square310x310_logo"))), 310, 310)) { + if (!p_preset->get("images/square310x310_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/square310x310_logo"))), 310, 310)) { valid = false; err += TTR("Invalid square 310x310 logo image dimensions (should be 310x310).") + "\n"; } - if (!p_preset->get("images/wide310x150_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/wide310x150_logo"))), 310, 150)) { + if (!p_preset->get("images/wide310x150_logo").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/wide310x150_logo"))), 310, 150)) { valid = false; err += TTR("Invalid wide 310x150 logo image dimensions (should be 310x150).") + "\n"; } - if (!p_preset->get("images/splash_screen").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/splash_screen"))), 620, 300)) { + if (!p_preset->get("images/splash_screen").is_zero() && !_valid_image((Object::cast_to((Object *)p_preset->get("images/splash_screen"))), 620, 300)) { valid = false; err += TTR("Invalid splash screen image dimensions (should be 620x300).") + "\n"; } diff --git a/scene/3d/baked_lightmap.cpp b/scene/3d/baked_lightmap.cpp index 6bde56104e..6efe7f60b2 100644 --- a/scene/3d/baked_lightmap.cpp +++ b/scene/3d/baked_lightmap.cpp @@ -28,72 +28,24 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#if 0 #include "baked_lightmap.h" #include "core/io/config_file.h" #include "core/io/resource_saver.h" +#include "core/math/camera_matrix.h" +#include "core/math/delaunay_3d.h" #include "core/os/dir_access.h" +#include "core/os/file_access.h" #include "core/os/os.h" -#include "voxel_light_baker.h" +#include "core/sort_array.h" +#include "lightmap_probe.h" -void BakedLightmapData::set_bounds(const AABB &p_bounds) { +void BakedLightmapData::add_user(const NodePath &p_path, const Rect2 &p_uv_scale, int p_slice_index, int32_t p_sub_instance) { - bounds = p_bounds; - RS::get_singleton()->lightmap_capture_set_bounds(baked_light, p_bounds); -} - -AABB BakedLightmapData::get_bounds() const { - - return bounds; -} - -void BakedLightmapData::set_octree(const Vector &p_octree) { - - RS::get_singleton()->lightmap_capture_set_octree(baked_light, p_octree); -} - -Vector BakedLightmapData::get_octree() const { - - return RS::get_singleton()->lightmap_capture_get_octree(baked_light); -} - -void BakedLightmapData::set_cell_space_transform(const Transform &p_xform) { - - cell_space_xform = p_xform; - RS::get_singleton()->lightmap_capture_set_octree_cell_transform(baked_light, p_xform); -} - -Transform BakedLightmapData::get_cell_space_transform() const { - return cell_space_xform; -} - -void BakedLightmapData::set_cell_subdiv(int p_cell_subdiv) { - cell_subdiv = p_cell_subdiv; - RS::get_singleton()->lightmap_capture_set_octree_cell_subdiv(baked_light, p_cell_subdiv); -} - -int BakedLightmapData::get_cell_subdiv() const { - return cell_subdiv; -} - -void BakedLightmapData::set_energy(float p_energy) { - - energy = p_energy; - RS::get_singleton()->lightmap_capture_set_energy(baked_light, energy); -} - -float BakedLightmapData::get_energy() const { - - return energy; -} - -void BakedLightmapData::add_user(const NodePath &p_path, const Ref &p_lightmap, int p_instance) { - - ERR_FAIL_COND_MSG(p_lightmap.is_null(), "It's not a reference to a valid Texture object."); User user; user.path = p_path; - user.lightmap = p_lightmap; - user.instance_index = p_instance; + user.uv_scale = p_uv_scale; + user.slice_index = p_slice_index; + user.sub_instance = p_sub_instance; users.push_back(user); } @@ -106,16 +58,23 @@ NodePath BakedLightmapData::get_user_path(int p_user) const { ERR_FAIL_INDEX_V(p_user, users.size(), NodePath()); return users[p_user].path; } -Ref BakedLightmapData::get_user_lightmap(int p_user) const { - ERR_FAIL_INDEX_V(p_user, users.size(), Ref()); - return users[p_user].lightmap; -} - -int BakedLightmapData::get_user_instance(int p_user) const { +int32_t BakedLightmapData::get_user_sub_instance(int p_user) const { ERR_FAIL_INDEX_V(p_user, users.size(), -1); - return users[p_user].instance_index; + return users[p_user].sub_instance; +} + +Rect2 BakedLightmapData::get_user_lightmap_uv_scale(int p_user) const { + + ERR_FAIL_INDEX_V(p_user, users.size(), Rect2()); + return users[p_user].uv_scale; +} + +int BakedLightmapData::get_user_lightmap_slice_index(int p_user) const { + + ERR_FAIL_INDEX_V(p_user, users.size(), -1); + return users[p_user].slice_index; } void BakedLightmapData::clear_users() { @@ -124,10 +83,10 @@ void BakedLightmapData::clear_users() { void BakedLightmapData::_set_user_data(const Array &p_data) { - ERR_FAIL_COND((p_data.size() % 3) != 0); + ERR_FAIL_COND((p_data.size() % 4) != 0); - for (int i = 0; i < p_data.size(); i += 3) { - add_user(p_data[i], p_data[i + 1], p_data[i + 2]); + for (int i = 0; i < p_data.size(); i += 4) { + add_user(p_data[i + 0], p_data[i + 1], p_data[i + 2], p_data[i + 3]); } } @@ -136,522 +95,1132 @@ Array BakedLightmapData::_get_user_data() const { Array ret; for (int i = 0; i < users.size(); i++) { ret.push_back(users[i].path); - ret.push_back(users[i].lightmap); - ret.push_back(users[i].instance_index); + ret.push_back(users[i].uv_scale); + ret.push_back(users[i].slice_index); + ret.push_back(users[i].sub_instance); } return ret; } RID BakedLightmapData::get_rid() const { - return baked_light; + return lightmap; +} + +void BakedLightmapData::clear() { + users.clear(); +} + +void BakedLightmapData::set_light_texture(const Ref &p_light_texture) { + light_texture = p_light_texture; + RS::get_singleton()->lightmap_set_textures(lightmap, light_texture.is_valid() ? light_texture->get_rid() : RID(), uses_spherical_harmonics); +} + +Ref BakedLightmapData::get_light_texture() const { + return light_texture; +} + +void BakedLightmapData::set_uses_spherical_harmonics(bool p_enable) { + uses_spherical_harmonics = p_enable; + RS::get_singleton()->lightmap_set_textures(lightmap, light_texture.is_valid() ? light_texture->get_rid() : RID(), uses_spherical_harmonics); +} + +bool BakedLightmapData::is_using_spherical_harmonics() const { + return uses_spherical_harmonics; +} + +void BakedLightmapData::set_capture_data(const AABB &p_bounds, bool p_interior, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree) { + if (p_points.size()) { + int pc = p_points.size(); + ERR_FAIL_COND(pc * 9 != p_point_sh.size()); + ERR_FAIL_COND((p_tetrahedra.size() % 4) != 0); + ERR_FAIL_COND((p_bsp_tree.size() % 6) != 0); + RS::get_singleton()->lightmap_set_probe_capture_data(lightmap, p_points, p_point_sh, p_tetrahedra, p_bsp_tree); + RS::get_singleton()->lightmap_set_probe_bounds(lightmap, p_bounds); + RS::get_singleton()->lightmap_set_probe_interior(lightmap, p_interior); + } else { + RS::get_singleton()->lightmap_set_probe_capture_data(lightmap, PackedVector3Array(), PackedColorArray(), PackedInt32Array(), PackedInt32Array()); + RS::get_singleton()->lightmap_set_probe_bounds(lightmap, AABB()); + RS::get_singleton()->lightmap_set_probe_interior(lightmap, false); + } + interior = p_interior; + bounds = p_bounds; +} + +PackedVector3Array BakedLightmapData::get_capture_points() const { + return RS::get_singleton()->lightmap_get_probe_capture_points(lightmap); +} +PackedColorArray BakedLightmapData::get_capture_sh() const { + return RS::get_singleton()->lightmap_get_probe_capture_sh(lightmap); +} +PackedInt32Array BakedLightmapData::get_capture_tetrahedra() const { + return RS::get_singleton()->lightmap_get_probe_capture_tetrahedra(lightmap); +} + +PackedInt32Array BakedLightmapData::get_capture_bsp_tree() const { + return RS::get_singleton()->lightmap_get_probe_capture_bsp_tree(lightmap); +} + +AABB BakedLightmapData::get_capture_bounds() const { + return bounds; +} + +bool BakedLightmapData::is_interior() const { + return interior; +} + +void BakedLightmapData::_set_probe_data(const Dictionary &p_data) { + ERR_FAIL_COND(!p_data.has("bounds")); + ERR_FAIL_COND(!p_data.has("points")); + ERR_FAIL_COND(!p_data.has("tetrahedra")); + ERR_FAIL_COND(!p_data.has("bsp")); + ERR_FAIL_COND(!p_data.has("sh")); + ERR_FAIL_COND(!p_data.has("interior")); + set_capture_data(p_data["bounds"], p_data["interior"], p_data["points"], p_data["sh"], p_data["tetrahedra"], p_data["bsp"]); +} + +Dictionary BakedLightmapData::_get_probe_data() const { + Dictionary d; + d["bounds"] = get_capture_bounds(); + d["points"] = get_capture_points(); + d["tetrahedra"] = get_capture_tetrahedra(); + d["bsp"] = get_capture_bsp_tree(); + d["sh"] = get_capture_sh(); + d["interior"] = is_interior(); + return d; } void BakedLightmapData::_bind_methods() { ClassDB::bind_method(D_METHOD("_set_user_data", "data"), &BakedLightmapData::_set_user_data); ClassDB::bind_method(D_METHOD("_get_user_data"), &BakedLightmapData::_get_user_data); - ClassDB::bind_method(D_METHOD("set_bounds", "bounds"), &BakedLightmapData::set_bounds); - ClassDB::bind_method(D_METHOD("get_bounds"), &BakedLightmapData::get_bounds); + ClassDB::bind_method(D_METHOD("set_light_texture", "light_texture"), &BakedLightmapData::set_light_texture); + ClassDB::bind_method(D_METHOD("get_light_texture"), &BakedLightmapData::get_light_texture); - ClassDB::bind_method(D_METHOD("set_cell_space_transform", "xform"), &BakedLightmapData::set_cell_space_transform); - ClassDB::bind_method(D_METHOD("get_cell_space_transform"), &BakedLightmapData::get_cell_space_transform); + ClassDB::bind_method(D_METHOD("set_uses_spherical_harmonics", "uses_spherical_harmonics"), &BakedLightmapData::set_uses_spherical_harmonics); + ClassDB::bind_method(D_METHOD("is_using_spherical_harmonics"), &BakedLightmapData::is_using_spherical_harmonics); - ClassDB::bind_method(D_METHOD("set_cell_subdiv", "cell_subdiv"), &BakedLightmapData::set_cell_subdiv); - ClassDB::bind_method(D_METHOD("get_cell_subdiv"), &BakedLightmapData::get_cell_subdiv); - - ClassDB::bind_method(D_METHOD("set_octree", "octree"), &BakedLightmapData::set_octree); - ClassDB::bind_method(D_METHOD("get_octree"), &BakedLightmapData::get_octree); - - ClassDB::bind_method(D_METHOD("set_energy", "energy"), &BakedLightmapData::set_energy); - ClassDB::bind_method(D_METHOD("get_energy"), &BakedLightmapData::get_energy); - - ClassDB::bind_method(D_METHOD("add_user", "path", "lightmap", "instance"), &BakedLightmapData::add_user); + ClassDB::bind_method(D_METHOD("add_user", "path", "lightmap", "offset"), &BakedLightmapData::add_user); ClassDB::bind_method(D_METHOD("get_user_count"), &BakedLightmapData::get_user_count); ClassDB::bind_method(D_METHOD("get_user_path", "user_idx"), &BakedLightmapData::get_user_path); - ClassDB::bind_method(D_METHOD("get_user_lightmap", "user_idx"), &BakedLightmapData::get_user_lightmap); ClassDB::bind_method(D_METHOD("clear_users"), &BakedLightmapData::clear_users); - ADD_PROPERTY(PropertyInfo(Variant::AABB, "bounds", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR), "set_bounds", "get_bounds"); - ADD_PROPERTY(PropertyInfo(Variant::TRANSFORM, "cell_space_transform", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR), "set_cell_space_transform", "get_cell_space_transform"); - ADD_PROPERTY(PropertyInfo(Variant::INT, "cell_subdiv", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR), "set_cell_subdiv", "get_cell_subdiv"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "energy", PROPERTY_HINT_RANGE, "0,16,0.01,or_greater"), "set_energy", "get_energy"); - ADD_PROPERTY(PropertyInfo(Variant::PACKED_BYTE_ARRAY, "octree", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR), "set_octree", "get_octree"); + ClassDB::bind_method(D_METHOD("_set_probe_data", "data"), &BakedLightmapData::_set_probe_data); + ClassDB::bind_method(D_METHOD("_get_probe_data"), &BakedLightmapData::_get_probe_data); + + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "light_texture", PROPERTY_HINT_RESOURCE_TYPE, "TextureLayered"), "set_light_texture", "get_light_texture"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "uses_spherical_harmonics", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR | PROPERTY_USAGE_INTERNAL), "set_uses_spherical_harmonics", "is_using_spherical_harmonics"); ADD_PROPERTY(PropertyInfo(Variant::ARRAY, "user_data", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR | PROPERTY_USAGE_INTERNAL), "_set_user_data", "_get_user_data"); + ADD_PROPERTY(PropertyInfo(Variant::DICTIONARY, "probe_data", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR | PROPERTY_USAGE_INTERNAL), "_set_probe_data", "_get_probe_data"); } BakedLightmapData::BakedLightmapData() { - baked_light = RS::get_singleton()->lightmap_capture_create(); - energy = 1; - cell_subdiv = 1; + lightmap = RS::get_singleton()->lightmap_create(); } BakedLightmapData::~BakedLightmapData() { - RS::get_singleton()->free(baked_light); + RS::get_singleton()->free(lightmap); } /////////////////////////// -BakedLightmap::BakeBeginFunc BakedLightmap::bake_begin_function = nullptr; -BakedLightmap::BakeStepFunc BakedLightmap::bake_step_function = nullptr; -BakedLightmap::BakeEndFunc BakedLightmap::bake_end_function = nullptr; +void BakedLightmap::_find_meshes_and_lights(Node *p_at_node, Vector &meshes, Vector &lights, Vector &probes) { -void BakedLightmap::set_bake_cell_size(float p_cell_size) { - bake_cell_size = p_cell_size; -} - -float BakedLightmap::get_bake_cell_size() const { - return bake_cell_size; -} - -void BakedLightmap::set_capture_cell_size(float p_cell_size) { - capture_cell_size = p_cell_size; -} - -float BakedLightmap::get_capture_cell_size() const { - return capture_cell_size; -} - -void BakedLightmap::set_extents(const Vector3 &p_extents) { - extents = p_extents; - update_gizmo(); - _change_notify("bake_extents"); -} - -Vector3 BakedLightmap::get_extents() const { - return extents; -} - -void BakedLightmap::set_bake_default_texels_per_unit(const float &p_bake_texels_per_unit) { - bake_default_texels_per_unit = p_bake_texels_per_unit; - update_gizmo(); -} - -float BakedLightmap::get_bake_default_texels_per_unit() const { - return bake_default_texels_per_unit; -} - -void BakedLightmap::_find_meshes_and_lights(Node *p_at_node, List &plot_meshes, List &plot_lights) { - - MeshInstance *mi = Object::cast_to(p_at_node); - if (mi && mi->get_flag(GeometryInstance::FLAG_USE_BAKED_LIGHT) && mi->is_visible_in_tree()) { + MeshInstance3D *mi = Object::cast_to(p_at_node); + if (mi && mi->get_gi_mode() == GeometryInstance3D::GI_MODE_BAKED && mi->is_visible_in_tree()) { Ref mesh = mi->get_mesh(); if (mesh.is_valid()) { - bool all_have_uv2 = true; + bool all_have_uv2_and_normal = true; + bool surfaces_found = false; for (int i = 0; i < mesh->get_surface_count(); i++) { + + if (mesh->surface_get_primitive_type(i) != Mesh::PRIMITIVE_TRIANGLES) { + continue; + } if (!(mesh->surface_get_format(i) & Mesh::ARRAY_FORMAT_TEX_UV2)) { - all_have_uv2 = false; + all_have_uv2_and_normal = false; break; } + if (!(mesh->surface_get_format(i) & Mesh::ARRAY_FORMAT_NORMAL)) { + all_have_uv2_and_normal = false; + break; + } + surfaces_found = true; } - if (all_have_uv2) { + if (surfaces_found && all_have_uv2_and_normal) { //READY TO BAKE! size hint could be computed if not found, actually.. - AABB aabb = mesh->get_aabb(); + MeshesFound mf; + mf.xform = get_global_transform().affine_inverse() * mi->get_global_transform(); + mf.node_path = get_path_to(mi); + mf.subindex = -1; + mf.mesh = mesh; - Transform xf = get_global_transform().affine_inverse() * mi->get_global_transform(); + static const int lightmap_scale[GeometryInstance3D::LIGHTMAP_SCALE_MAX] = { 1, 2, 4, 8 }; + mf.lightmap_scale = lightmap_scale[mi->get_lightmap_scale()]; - if (AABB(-extents, extents * 2).intersects(xf.xform(aabb))) { - PlotMesh pm; - pm.local_xform = xf; - pm.mesh = mesh; - pm.path = get_path_to(mi); - pm.instance_idx = -1; - for (int i = 0; i < mesh->get_surface_count(); i++) { - pm.instance_materials.push_back(mi->get_surface_material(i)); + Ref all_override = mi->get_material_override(); + for (int i = 0; i < mesh->get_surface_count(); i++) { + if (all_override.is_valid()) { + mf.overrides.push_back(all_override); + } else { + mf.overrides.push_back(mi->get_surface_material(i)); } - pm.override_material = mi->get_material_override(); - plot_meshes.push_back(pm); } + + meshes.push_back(mf); } } } - Spatial *s = Object::cast_to(p_at_node); + Node3D *s = Object::cast_to(p_at_node); if (!mi && s) { - Array meshes = p_at_node->call("get_bake_meshes"); - if (meshes.size() && (meshes.size() & 1) == 0) { + Array bmeshes = p_at_node->call("get_bake_bmeshes"); + if (bmeshes.size() && (bmeshes.size() & 1) == 0) { Transform xf = get_global_transform().affine_inverse() * s->get_global_transform(); - for (int i = 0; i < meshes.size(); i += 2) { - PlotMesh pm; - Transform mesh_xf = meshes[i + 1]; - pm.local_xform = xf * mesh_xf; - pm.mesh = meshes[i]; - pm.instance_idx = i / 2; - if (!pm.mesh.is_valid()) + for (int i = 0; i < bmeshes.size(); i += 2) { + + Ref mesh = bmeshes[i]; + if (!mesh.is_valid()) continue; - pm.path = get_path_to(s); - plot_meshes.push_back(pm); + + MeshesFound mf; + + Transform mesh_xf = bmeshes[i + 1]; + mf.xform = xf * mesh_xf; + mf.node_path = get_path_to(s); + mf.subindex = i / 2; + mf.lightmap_scale = 1; + mf.mesh = mesh; + + meshes.push_back(mf); } } } - Light *light = Object::cast_to(p_at_node); + Light3D *light = Object::cast_to(p_at_node); - if (light && light->get_bake_mode() != Light::BAKE_DISABLED) { - PlotLight pl; - Transform xf = get_global_transform().affine_inverse() * light->get_global_transform(); + if (light && light->get_bake_mode() != Light3D::BAKE_DISABLED) { - pl.local_xform = xf; - pl.light = light; - plot_lights.push_back(pl); + LightsFound lf; + lf.xform = get_global_transform().affine_inverse() * light->get_global_transform(); + lf.light = light; + lights.push_back(lf); } + + LightmapProbe *probe = Object::cast_to(p_at_node); + + if (probe) { + Transform xf = get_global_transform().affine_inverse() * probe->get_global_transform(); + probes.push_back(xf.origin); + } + for (int i = 0; i < p_at_node->get_child_count(); i++) { Node *child = p_at_node->get_child(i); if (!child->get_owner()) continue; //maybe a helper - _find_meshes_and_lights(child, plot_meshes, plot_lights); + _find_meshes_and_lights(child, meshes, lights, probes); } } -void BakedLightmap::set_hdr(bool p_enable) { - hdr = p_enable; -} +int BakedLightmap::_bsp_get_simplex_side(const Vector &p_points, const LocalVector &p_simplices, const Plane &p_plane, uint32_t p_simplex) const { -bool BakedLightmap::is_hdr() const { - return hdr; -} - -bool BakedLightmap::_bake_time(void *ud, float p_secs, float p_progress) { - - uint64_t time = OS::get_singleton()->get_ticks_usec(); - BakeTimeData *btd = (BakeTimeData *)ud; - - if (time - btd->last_step > 1000000) { - - int mins_left = p_secs / 60; - int secs_left = Math::fmod(p_secs, 60.0f); - int percent = p_progress * 100; - bool abort = bake_step_function(btd->pass + percent, btd->text + " " + vformat(RTR("%d%%"), percent) + " " + vformat(RTR("(Time Left: %d:%02d s)"), mins_left, secs_left)); - btd->last_step = time; - if (abort) - return true; + int over = 0; + int under = 0; + int coplanar = 0; + const BSPSimplex &s = p_simplices[p_simplex]; + for (int i = 0; i < 4; i++) { + const Vector3 v = p_points[s.vertices[i]]; + if (p_plane.has_point(v)) { //coplanar + coplanar++; + } else if (p_plane.is_point_over(v)) { + over++; + } else { + under++; + } } - return false; -} - -BakedLightmap::BakeError BakedLightmap::bake(Node *p_from_node, bool p_create_visual_debug) { - - String save_path; - - if (image_path.begins_with("res://")) { - save_path = image_path; + ERR_FAIL_COND_V(under == 0 && over == 0, -2); //should never happen, we discarded flat simplices before, but in any case drop it from the bsp tree and throw an error + if (under == 0) { + return 1; // all over + } else if (over == 0) { + return -1; // all under } else { - if (get_filename() != "") { - save_path = get_filename().get_base_dir(); - } else if (get_owner() && get_owner()->get_filename() != "") { - save_path = get_owner()->get_filename().get_base_dir(); - } - - if (save_path == "") { - return BAKE_ERROR_NO_SAVE_PATH; - } - if (image_path != "") { - save_path.plus_file(image_path); - } - } - { - //check for valid save path - DirAccessRef d = DirAccess::open(save_path); - if (!d) { - ERR_PRINT("Invalid Save Path '" + save_path + "'."); - return BAKE_ERROR_NO_SAVE_PATH; - } + return 0; // crossing } +} - Ref new_light_data; - new_light_data.instance(); +//#define DEBUG_BSP - Voxelizer baker; +int32_t BakedLightmap::_compute_bsp_tree(const Vector &p_points, const LocalVector &p_planes, LocalVector &planes_tested, const LocalVector &p_simplices, const LocalVector &p_simplex_indices, LocalVector &bsp_nodes) { - int bake_subdiv; - int capture_subdiv; - AABB bake_bounds; - { - bake_bounds = AABB(-extents, extents * 2.0); - int subdiv = nearest_power_of_2_templated(int(bake_bounds.get_longest_axis_size() / bake_cell_size)); - bake_bounds.size[bake_bounds.get_longest_axis_index()] = subdiv * bake_cell_size; - bake_subdiv = nearest_shift(subdiv) + 1; + //if we reach here, it means there is more than one simplex + int32_t node_index = (int32_t)bsp_nodes.size(); + bsp_nodes.push_back(BSPNode()); - capture_subdiv = bake_subdiv; - float css = bake_cell_size; - while (css < capture_cell_size && capture_subdiv > 2) { - capture_subdiv--; - css *= 2.0; - } - } + //test with all the simplex planes + Plane best_plane; + float best_plane_score = -1.0; - baker.begin_bake(bake_subdiv, bake_bounds); - - List mesh_list; - List light_list; - - _find_meshes_and_lights(p_from_node ? p_from_node : get_parent(), mesh_list, light_list); - - if (bake_begin_function) { - bake_begin_function(mesh_list.size() + light_list.size() + 1 + mesh_list.size() * 100); - } - - int step = 0; - - int pmc = 0; - - for (List::Element *E = mesh_list.front(); E; E = E->next()) { - - if (bake_step_function) { - bake_step_function(step++, RTR("Plotting Meshes: ") + " (" + itos(pmc + 1) + "/" + itos(mesh_list.size()) + ")"); - } - - pmc++; - baker.plot_mesh(E->get().local_xform, E->get().mesh, E->get().instance_materials, E->get().override_material); - } - - pmc = 0; - baker.begin_bake_light(Voxelizer::BakeQuality(bake_quality), Voxelizer::BakeMode(bake_mode), propagation, energy); - - for (List::Element *E = light_list.front(); E; E = E->next()) { - - if (bake_step_function) { - bake_step_function(step++, RTR("Plotting Lights:") + " (" + itos(pmc + 1) + "/" + itos(light_list.size()) + ")"); - } - - pmc++; - PlotLight pl = E->get(); - switch (pl.light->get_light_type()) { - case RS::LIGHT_DIRECTIONAL: { - baker.plot_light_directional(-pl.local_xform.basis.get_axis(2), pl.light->get_color(), pl.light->get_param(Light::PARAM_ENERGY), pl.light->get_param(Light::PARAM_INDIRECT_ENERGY), pl.light->get_bake_mode() == Light::BAKE_ALL); - } break; - case RS::LIGHT_OMNI: { - baker.plot_light_omni(pl.local_xform.origin, pl.light->get_color(), pl.light->get_param(Light::PARAM_ENERGY), pl.light->get_param(Light::PARAM_INDIRECT_ENERGY), pl.light->get_param(Light::PARAM_RANGE), pl.light->get_param(Light::PARAM_ATTENUATION), pl.light->get_bake_mode() == Light::BAKE_ALL); - } break; - case RS::LIGHT_SPOT: { - baker.plot_light_spot(pl.local_xform.origin, pl.local_xform.basis.get_axis(2), pl.light->get_color(), pl.light->get_param(Light::PARAM_ENERGY), pl.light->get_param(Light::PARAM_INDIRECT_ENERGY), pl.light->get_param(Light::PARAM_RANGE), pl.light->get_param(Light::PARAM_ATTENUATION), pl.light->get_param(Light::PARAM_SPOT_ANGLE), pl.light->get_param(Light::PARAM_SPOT_ATTENUATION), pl.light->get_bake_mode() == Light::BAKE_ALL); - - } break; - } - } - /*if (bake_step_function) { - bake_step_function(pmc++, RTR("Finishing Plot")); - }*/ - - baker.end_bake(); - - Set used_mesh_names; - - pmc = 0; - for (List::Element *E = mesh_list.front(); E; E = E->next()) { - - String mesh_name = E->get().mesh->get_name(); - if (mesh_name == "" || mesh_name.find(":") != -1 || mesh_name.find("/") != -1) { - mesh_name = "LightMap"; - } - - if (used_mesh_names.has(mesh_name)) { - int idx = 2; - String base = mesh_name; - while (true) { - mesh_name = base + itos(idx); - if (!used_mesh_names.has(mesh_name)) - break; - idx++; - } - } - used_mesh_names.insert(mesh_name); - - pmc++; - Voxelizer::LightMapData lm; - - Error err; - if (bake_step_function) { - BakeTimeData btd; - btd.text = RTR("Lighting Meshes: ") + mesh_name + " (" + itos(pmc) + "/" + itos(mesh_list.size()) + ")"; - btd.pass = step; - btd.last_step = 0; - err = baker.make_lightmap(E->get().local_xform, E->get().mesh, bake_default_texels_per_unit, lm, _bake_time, &btd); - if (err != OK) { - bake_end_function(); - if (err == ERR_SKIP) - return BAKE_ERROR_USER_ABORTED; - return BAKE_ERROR_CANT_CREATE_IMAGE; - } - step += 100; - } else { - - err = baker.make_lightmap(E->get().local_xform, E->get().mesh, bake_default_texels_per_unit, lm); - } - - if (err == OK) { - - Ref image; - image.instance(); - - if (hdr) { - - //just save a regular image - Vector data; - int s = lm.light.size(); - data.resize(lm.light.size() * 2); - { - - uint8_t* w = data.ptrw(); - const float* r = lm.light.ptr(); - uint16_t *hfw = (uint16_t *)w.ptr(); - for (int i = 0; i < s; i++) { - hfw[i] = Math::make_half_float(r[i]); - } - } - - image->create(lm.width, lm.height, false, Image::FORMAT_RGBH, data); - - } else { - - //just save a regular image - Vector data; - int s = lm.light.size(); - data.resize(lm.light.size()); - { - - uint8_t* w = data.ptrw(); - const float* r = lm.light.ptr(); - for (int i = 0; i < s; i += 3) { - Color c(r[i + 0], r[i + 1], r[i + 2]); - c = c.to_srgb(); - w[i + 0] = CLAMP(c.r * 255, 0, 255); - w[i + 1] = CLAMP(c.g * 255, 0, 255); - w[i + 2] = CLAMP(c.b * 255, 0, 255); - } - } - - image->create(lm.width, lm.height, false, Image::FORMAT_RGB8, data); - - //This texture is saved to SRGB for two reasons: - // 1) first is so it looks better when doing the LINEAR->SRGB conversion (more accurate) - // 2) So it can be used in the GLES2 backend, which does not support linkear workflow + for (uint32_t i = 0; i < p_simplex_indices.size(); i++) { + const BSPSimplex &s = p_simplices[p_simplex_indices[i]]; + for (int j = 0; j < 4; j++) { + uint32_t plane_index = s.planes[j]; + if (planes_tested[plane_index] == node_index) { + continue; //tested this plane already } - String image_path = save_path.plus_file(mesh_name); - Ref texture; + planes_tested[plane_index] = node_index; - if (ResourceLoader::import) { + static const int face_order[4][3] = { + { 0, 1, 2 }, + { 0, 2, 3 }, + { 0, 1, 3 }, + { 1, 2, 3 } + }; - bool srgb = false; - if (false && hdr) { - //save hdr - } else { - image_path += ".png"; - print_line("image path saving png: " + image_path); - image->save_png(image_path); - srgb = true; + // despite getting rid of plane duplicates, we should still use here the actual plane to avoid numerical error + // from thinking this same simplex is intersecting rather than on a side + Vector3 v0 = p_points[s.vertices[face_order[j][0]]]; + Vector3 v1 = p_points[s.vertices[face_order[j][1]]]; + Vector3 v2 = p_points[s.vertices[face_order[j][2]]]; + + Plane plane(v0, v1, v2); + + //test with all the simplices + int over_count = 0; + int under_count = 0; + + for (uint32_t k = 0; k < p_simplex_indices.size(); k++) { + int side = _bsp_get_simplex_side(p_points, p_simplices, plane, p_simplex_indices[k]); + if (side == -2) { + continue; //this simplex is invalid, skip for now + } else if (side < 0) { + under_count++; + } else if (side > 0) { + over_count++; } - - if (!FileAccess::exists(image_path + ".import")) { - Ref config; - config.instance(); - config->set_value("remap", "importer", "texture"); - config->set_value("remap", "type", "StreamTexture"); - config->set_value("params", "compress/mode", 2); - config->set_value("params", "detect_3d", false); - config->set_value("params", "flags/repeat", false); - config->set_value("params", "flags/filter", true); - config->set_value("params", "flags/mipmaps", false); - config->set_value("params", "flags/srgb", srgb); - - config->save(image_path + ".import"); - } - - ResourceLoader::import(image_path); - texture = ResourceLoader::load(image_path); //if already loaded, it will be updated on refocus? - } else { - - image_path += ".text"; - Ref tex; - bool set_path = true; - if (ResourceCache::has(image_path)) { - tex = Ref((Resource *)ResourceCache::get(image_path)); - set_path = false; - } - - if (!tex.is_valid()) { - tex.instance(); - } - - tex->create_from_image(image); - - err = ResourceSaver::save(image_path, tex, ResourceSaver::FLAG_CHANGE_PATH); - if (set_path) { - tex->set_path(image_path); - } - texture = tex; - } - if (err != OK) { - if (bake_end_function) { - bake_end_function(); - } - ERR_FAIL_COND_V(err != OK, BAKE_ERROR_CANT_CREATE_IMAGE); } - new_light_data->add_user(E->get().path, texture, E->get().instance_idx); + if (under_count == 0 && over_count == 0) { + continue; //most likely precision issue with a flat simplex, do not try this plane + } + + if (under_count > over_count) { //make sure under is always less than over, so we can compute the same ratio + SWAP(under_count, over_count); + } + + float score = 0; //by default, score is 0 (worst) + if (over_count > 0) { + //give score mainly based on ratio (under / over), this means that this plane is splitting simplices a lot, but its balanced + score = float(under_count) / over_count; + } + + //adjusting priority over least splits, probably not a great idea + //score *= Math::sqrt(float(over_count + under_count) / p_simplex_indices.size()); //also multiply score + + if (score > best_plane_score) { + + best_plane = plane; + best_plane_score = score; + } } } - AABB bounds = AABB(-extents, extents * 2); - new_light_data->set_cell_subdiv(capture_subdiv); - new_light_data->set_bounds(bounds); - new_light_data->set_octree(baker.create_capture_octree(capture_subdiv)); - { + LocalVector indices_over; + LocalVector indices_under; - float bake_bound_size = bake_bounds.get_longest_axis_size(); - Transform to_bounds; - to_bounds.basis.scale(Vector3(bake_bound_size, bake_bound_size, bake_bound_size)); - to_bounds.origin = bounds.position; + //split again, but add to list + for (uint32_t i = 0; i < p_simplex_indices.size(); i++) { - Transform to_grid; - to_grid.basis.scale(Vector3(1 << (capture_subdiv - 1), 1 << (capture_subdiv - 1), 1 << (capture_subdiv - 1))); + uint32_t index = p_simplex_indices[i]; + int side = _bsp_get_simplex_side(p_points, p_simplices, best_plane, index); - Transform to_cell_space = to_grid * to_bounds.affine_inverse(); - new_light_data->set_cell_space_transform(to_cell_space); - } - - if (bake_end_function) { - bake_end_function(); - } - - //create the data for visual server - - if (p_create_visual_debug) { - MultiMeshInstance *mmi = memnew(MultiMeshInstance); - mmi->set_multimesh(baker.create_debug_multimesh(Voxelizer::DEBUG_LIGHT)); - add_child(mmi); -#ifdef TOOLS_ENABLED - if (get_tree()->get_edited_scene_root() == this) { - mmi->set_owner(this); - } else { - mmi->set_owner(get_owner()); + if (side == -2) { + continue; //simplex sits on the plane, does not make sense to use it } -#else - mmi->set_owner(get_owner()); + if (side <= 0) { + indices_under.push_back(index); + } + + if (side >= 0) { + indices_over.push_back(index); + } + } + +#ifdef DEBUG_BSP + print_line("node " + itos(node_index) + " found plane: " + best_plane + " score:" + rtos(best_plane_score) + " - over " + itos(indices_over.size()) + " under " + itos(indices_under.size()) + " intersecting " + itos(intersecting)); #endif + + if (best_plane_score < 0.0 || indices_over.size() == p_simplex_indices.size() || indices_under.size() == p_simplex_indices.size()) { + ERR_FAIL_COND_V(p_simplex_indices.size() <= 1, 0); //should not happen, this is a bug + + // Failed to separate the tetrahedrons using planes + // this means Delaunay borked at some point. + // Luckily, because we are using tetrahedrons, we can resort to + // less precise but still working ways to generate the separating plane + // this will most likely look bad when interpolating, but at least it will not crash. + // and the arctifact will most likely also be very small, so too difficult to notice. + + //find the longest axis + + WARN_PRINT("Inconsistency found in triangulation while building BSP, probe interpolation quality may degrade a bit."); + + LocalVector centers; + AABB bounds_all; + for (uint32_t i = 0; i < p_simplex_indices.size(); i++) { + AABB bounds; + for (uint32_t j = 0; j < 4; j++) { + + Vector3 p = p_points[p_simplices[p_simplex_indices[i]].vertices[j]]; + if (j == 0) { + bounds.position = p; + } else { + bounds.expand_to(p); + } + } + if (i == 0) { + centers.push_back(bounds.position + bounds.size * 0.5); + } else { + bounds_all.merge_with(bounds); + } + } + Vector3::Axis longest_axis = Vector3::Axis(bounds_all.get_longest_axis_index()); + + //find the simplex that will go under + uint32_t min_d_idx = 0xFFFFFFFF; + float min_d_dist = 1e20; + + for (uint32_t i = 0; i < centers.size(); i++) { + if (centers[i][longest_axis] < min_d_dist) { + min_d_idx = i; + min_d_dist = centers[i][longest_axis]; + } + } + //rebuild best_plane and over/under arrays + best_plane = Plane(); + best_plane.normal[longest_axis] = 1.0; + best_plane.d = min_d_dist; + + indices_under.clear(); + indices_under.push_back(min_d_idx); + + indices_over.clear(); + + for (uint32_t i = 0; i < p_simplex_indices.size(); i++) { + if (i == min_d_idx) { + continue; + } + indices_over.push_back(p_simplex_indices[i]); + } } - set_light_data(new_light_data); + BSPNode node; + node.plane = best_plane; + + if (indices_under.size() == 0) { + //noting to do here + node.under = BSPNode::EMPTY_LEAF; + } else if (indices_under.size() == 1) { + node.under = -(indices_under[0] + 1); + } else { + node.under = _compute_bsp_tree(p_points, p_planes, planes_tested, p_simplices, indices_under, bsp_nodes); + } + + if (indices_over.size() == 0) { + //noting to do here + node.over = BSPNode::EMPTY_LEAF; + } else if (indices_over.size() == 1) { + node.over = -(indices_over[0] + 1); + } else { + node.over = _compute_bsp_tree(p_points, p_planes, planes_tested, p_simplices, indices_over, bsp_nodes); + } + + bsp_nodes[node_index] = node; + + return node_index; +} + +bool BakedLightmap::_lightmap_bake_step_function(float p_completion, const String &p_text, void *ud, bool p_refresh) { + + BakeStepUD *bsud = (BakeStepUD *)ud; + bool ret = false; + if (bsud->func) { + ret = bsud->func(bsud->from_percent + p_completion * (bsud->to_percent - bsud->from_percent), p_text, bsud->ud, p_refresh); + } + return ret; +} + +void BakedLightmap::_plot_triangle_into_octree(GenProbesOctree *p_cell, float p_cell_size, const Vector3 *p_triangle) { + + for (int i = 0; i < 8; i++) { + Vector3i pos = p_cell->offset; + uint32_t half_size = p_cell->size / 2; + if (i & 1) { + pos.x += half_size; + } + if (i & 2) { + pos.y += half_size; + } + if (i & 4) { + pos.z += half_size; + } + + AABB subcell; + subcell.position = Vector3(pos) * p_cell_size; + subcell.size = Vector3(half_size, half_size, half_size) * p_cell_size; + + if (!Geometry::triangle_box_overlap(subcell.position + subcell.size * 0.5, subcell.size * 0.5, p_triangle)) + continue; + + if (p_cell->children[i] == nullptr) { + GenProbesOctree *child = memnew(GenProbesOctree); + child->offset = pos; + child->size = half_size; + p_cell->children[i] = child; + } + + if (half_size > 1) { + //still levels missing + _plot_triangle_into_octree(p_cell->children[i], p_cell_size, p_triangle); + } + } +} +void BakedLightmap::_gen_new_positions_from_octree(const GenProbesOctree *p_cell, float p_cell_size, const Vector &probe_positions, LocalVector &new_probe_positions, HashMap &positions_used, const AABB &p_bounds) { + + for (int i = 0; i < 8; i++) { + + Vector3i pos = p_cell->offset; + if (i & 1) { + pos.x += p_cell->size; + } + if (i & 2) { + pos.y += p_cell->size; + } + if (i & 4) { + pos.z += p_cell->size; + } + + if (p_cell->size == 1 && !positions_used.has(pos)) { + //new position to insert! + Vector3 real_pos = p_bounds.position + Vector3(pos) * p_cell_size; + //see if a user submitted probe is too close + int ppcount = probe_positions.size(); + const Vector3 *pp = probe_positions.ptr(); + bool exists = false; + for (int j = 0; j < ppcount; j++) { + + if (pp[j].distance_to(real_pos) < CMP_EPSILON) { + exists = true; + break; + } + } + + if (!exists) { + new_probe_positions.push_back(real_pos); + } + + positions_used[pos] = true; + } + + if (p_cell->children[i] != nullptr) { + _gen_new_positions_from_octree(p_cell->children[i], p_cell_size, probe_positions, new_probe_positions, positions_used, p_bounds); + } + } +} +BakedLightmap::BakeError BakedLightmap::bake(Node *p_from_node, String p_image_data_path, Lightmapper::BakeStepFunc p_bake_step, void *p_bake_userdata) { + + if (p_image_data_path == "" && (get_light_data().is_null() || !get_light_data()->get_path().is_resource_file())) { + return BAKE_ERROR_NO_SAVE_PATH; + } + + if (p_image_data_path == "") { + + if (get_light_data().is_null()) { + return BAKE_ERROR_NO_SAVE_PATH; + } + + p_image_data_path = get_light_data()->get_path(); + if (!p_image_data_path.is_resource_file()) { + return BAKE_ERROR_NO_SAVE_PATH; + } + } + + Ref lightmapper = Lightmapper::create(); + ERR_FAIL_COND_V(lightmapper.is_null(), BAKE_ERROR_NO_LIGHTMAPPER); + + BakeStepUD bsud; + bsud.func = p_bake_step; + bsud.ud = p_bake_userdata; + bsud.from_percent = 0.2; + bsud.to_percent = 0.8; + + if (p_bake_step) { + p_bake_step(0.0, TTR("Finding meshes, lights and probes"), p_bake_userdata, true); + } + /* STEP 1, FIND MESHES, LIGHTS AND PROBES */ + Vector mesh_data; + Vector lights_found; + Vector probes_found; + AABB bounds; + { + Vector meshes_found; + _find_meshes_and_lights(p_from_node ? p_from_node : get_parent(), meshes_found, lights_found, probes_found); + + if (meshes_found.size() == 0) { + return BAKE_ERROR_NO_MESHES; + } + // create mesh data for insert + + //get the base material textures, help compute altlas size and bounds + for (int m_i = 0; m_i < meshes_found.size(); m_i++) { + + if (p_bake_step) { + float p = (float)(m_i) / meshes_found.size(); + p_bake_step(p * 0.1, vformat(TTR("Preparing geometry %d/%d"), m_i, meshes_found.size()), p_bake_userdata, false); + } + + MeshesFound &mf = meshes_found.write[m_i]; + + Size2i lightmap_size = mf.mesh->get_lightmap_size_hint() * mf.lightmap_scale; + Vector overrides; + overrides.resize(mf.overrides.size()); + for (int i = 0; i < mf.overrides.size(); i++) { + if (mf.overrides[i].is_valid()) { + overrides.write[i] = mf.overrides[i]->get_rid(); + } + } + TypedArray images = RS::get_singleton()->bake_render_uv2(mf.mesh->get_rid(), overrides, lightmap_size); + + ERR_FAIL_COND_V(images.empty(), BAKE_ERROR_CANT_CREATE_IMAGE); + + Ref albedo = images[RS::BAKE_CHANNEL_ALBEDO_ALPHA]; + Ref orm = images[RS::BAKE_CHANNEL_ORM]; + + //multiply albedo by metal + + Lightmapper::MeshData md; + + { + Dictionary d; + d["path"] = mf.node_path; + if (mf.subindex >= 0) { + d["subindex"] = mf.subindex; + } + md.userdata = d; + } + + { + + if (albedo->get_format() != Image::FORMAT_RGBA8) { + albedo->convert(Image::FORMAT_RGBA8); + } + if (orm->get_format() != Image::FORMAT_RGBA8) { + orm->convert(Image::FORMAT_RGBA8); + } + Vector albedo_alpha = albedo->get_data(); + Vector orm_data = orm->get_data(); + + Vector albedom; + uint32_t len = albedo_alpha.size(); + albedom.resize(len); + const uint8_t *r_aa = albedo_alpha.ptr(); + const uint8_t *r_orm = orm_data.ptr(); + uint8_t *w_albedo = albedom.ptrw(); + + for (uint32_t i = 0; i < len; i += 4) { + w_albedo[i + 0] = uint8_t(CLAMP(float(r_aa[i + 0]) * (1.0 - float(r_orm[i + 2] / 255.0)), 0, 255)); + w_albedo[i + 1] = uint8_t(CLAMP(float(r_aa[i + 1]) * (1.0 - float(r_orm[i + 2] / 255.0)), 0, 255)); + w_albedo[i + 2] = uint8_t(CLAMP(float(r_aa[i + 2]) * (1.0 - float(r_orm[i + 2] / 255.0)), 0, 255)); + w_albedo[i + 3] = 255; + } + + md.albedo_on_uv2.instance(); + md.albedo_on_uv2->create(lightmap_size.width, lightmap_size.height, false, Image::FORMAT_RGBA8, albedom); + } + + md.emission_on_uv2 = images[RS::BAKE_CHANNEL_EMISSION]; + if (md.emission_on_uv2->get_format() != Image::FORMAT_RGBAH) { + md.emission_on_uv2->convert(Image::FORMAT_RGBAH); + } + + //get geometry + + Basis normal_xform = mf.xform.basis.inverse().transposed(); + + for (int i = 0; i < mf.mesh->get_surface_count(); i++) { + if (mf.mesh->surface_get_primitive_type(i) != Mesh::PRIMITIVE_TRIANGLES) { + continue; + } + Array a = mf.mesh->surface_get_arrays(i); + + Vector vertices = a[Mesh::ARRAY_VERTEX]; + const Vector3 *vr = vertices.ptr(); + Vector uv = a[Mesh::ARRAY_TEX_UV2]; + const Vector2 *uvr = nullptr; + Vector normals = a[Mesh::ARRAY_NORMAL]; + const Vector3 *nr = nullptr; + Vector index = a[Mesh::ARRAY_INDEX]; + + ERR_CONTINUE(uv.size() == 0); + ERR_CONTINUE(normals.size() == 0); + + uvr = uv.ptr(); + nr = normals.ptr(); + + int facecount; + const int *ir = nullptr; + + if (index.size()) { + + facecount = index.size() / 3; + ir = index.ptr(); + } else { + facecount = vertices.size() / 3; + } + + for (int j = 0; j < facecount; j++) { + + uint32_t vidx[3]; + + if (ir) { + for (int k = 0; k < 3; k++) { + vidx[k] = ir[j * 3 + k]; + } + } else { + for (int k = 0; k < 3; k++) { + vidx[k] = j * 3 + k; + } + } + + for (int k = 0; k < 3; k++) { + Vector3 v = mf.xform.xform(vr[vidx[k]]); + if (bounds == AABB()) { + bounds.position = v; + } else { + bounds.expand_to(v); + } + md.points.push_back(v); + + md.uv2.push_back(uvr[vidx[k]]); + md.normal.push_back(normal_xform.xform(nr[vidx[k]]).normalized()); + } + } + } + + mesh_data.push_back(md); + } + } + + /* STEP 2, CREATE PROBES */ + + if (p_bake_step) { + p_bake_step(0.3, TTR("Creating probes"), p_bake_userdata, true); + } + + //bounds need to include the user probes + for (int i = 0; i < probes_found.size(); i++) { + bounds.expand_to(probes_found[i]); + } + + bounds.grow_by(bounds.size.length() * 0.001); + + if (gen_probes == GENERATE_PROBES_DISABLED) { + // generate 8 probes on bound endpoints + for (int i = 0; i < 8; i++) { + probes_found.push_back(bounds.get_endpoint(i)); + } + } else { + // detect probes from geometry + static const int subdiv_values[6] = { 0, 4, 8, 16, 32 }; + int subdiv = subdiv_values[gen_probes]; + + float subdiv_cell_size; + Vector3i bound_limit; + { + int longest_axis = bounds.get_longest_axis_index(); + subdiv_cell_size = bounds.size[longest_axis] / subdiv; + int axis_n1 = (longest_axis + 1) % 3; + int axis_n2 = (longest_axis + 2) % 3; + + bound_limit[longest_axis] = subdiv; + bound_limit[axis_n1] = int(Math::ceil(bounds.size[axis_n1] / subdiv_cell_size)); + bound_limit[axis_n2] = int(Math::ceil(bounds.size[axis_n2] / subdiv_cell_size)); + //compensate bounds + bounds.size[axis_n1] = bound_limit[axis_n1] * subdiv_cell_size; + bounds.size[axis_n2] = bound_limit[axis_n2] * subdiv_cell_size; + } + + GenProbesOctree octree; + octree.size = subdiv; + + for (int i = 0; i < mesh_data.size(); i++) { + if (p_bake_step) { + float p = (float)(i) / mesh_data.size(); + p_bake_step(0.3 + p * 0.1, vformat(TTR("Creating probes from mesh %d/%d"), i, mesh_data.size()), p_bake_userdata, false); + } + + for (int j = 0; j < mesh_data[i].points.size(); j += 3) { + Vector3 points[3] = { mesh_data[i].points[j + 0] - bounds.position, mesh_data[i].points[j + 1] - bounds.position, mesh_data[i].points[j + 2] - bounds.position }; + _plot_triangle_into_octree(&octree, subdiv_cell_size, points); + } + } + + LocalVector new_probe_positions; + HashMap positions_used; + for (uint32_t i = 0; i < 8; i++) { //insert bounding endpoints + Vector3i pos; + if (i & 1) { + pos.x += bound_limit.x; + } + if (i & 2) { + pos.y += bound_limit.y; + } + if (i & 4) { + pos.z += bound_limit.z; + } + + positions_used[pos] = true; + Vector3 real_pos = bounds.position + Vector3(pos) * subdiv_cell_size; //use same formula for numerical stability + new_probe_positions.push_back(real_pos); + } + //skip first level, since probes are always added at bounds endpoints anyway (code above this) + for (int i = 0; i < 8; i++) { + + if (octree.children[i]) { + _gen_new_positions_from_octree(octree.children[i], subdiv_cell_size, probes_found, new_probe_positions, positions_used, bounds); + } + } + + for (uint32_t i = 0; i < new_probe_positions.size(); i++) { + probes_found.push_back(new_probe_positions[i]); + } + } + + // Add everything to lightmapper + if (p_bake_step) { + p_bake_step(0.4, TTR("Preparing Lightmapper"), p_bake_userdata, true); + } + + { + + for (int i = 0; i < mesh_data.size(); i++) { + lightmapper->add_mesh(mesh_data[i]); + } + for (int i = 0; i < lights_found.size(); i++) { + Light3D *light = lights_found[i].light; + Transform xf = lights_found[i].xform; + + if (Object::cast_to(light)) { + DirectionalLight3D *l = Object::cast_to(light); + lightmapper->add_directional_light(light->get_bake_mode() == Light3D::BAKE_ALL, -xf.basis.get_axis(Vector3::AXIS_Z).normalized(), l->get_color(), l->get_param(Light3D::PARAM_ENERGY), l->get_param(Light3D::PARAM_SIZE)); + } else if (Object::cast_to(light)) { + OmniLight3D *l = Object::cast_to(light); + lightmapper->add_omni_light(light->get_bake_mode() == Light3D::BAKE_ALL, xf.origin, l->get_color(), l->get_param(Light3D::PARAM_ENERGY), l->get_param(Light3D::PARAM_RANGE), l->get_param(Light3D::PARAM_ATTENUATION), l->get_param(Light3D::PARAM_SIZE)); + } else if (Object::cast_to(light)) { + SpotLight3D *l = Object::cast_to(light); + lightmapper->add_spot_light(light->get_bake_mode() == Light3D::BAKE_ALL, xf.origin, -xf.basis.get_axis(Vector3::AXIS_Z).normalized(), l->get_color(), l->get_param(Light3D::PARAM_ENERGY), l->get_param(Light3D::PARAM_RANGE), l->get_param(Light3D::PARAM_ATTENUATION), l->get_param(Light3D::PARAM_SPOT_ANGLE), l->get_param(Light3D::PARAM_SPOT_ATTENUATION), l->get_param(Light3D::PARAM_SIZE)); + } + } + for (int i = 0; i < probes_found.size(); i++) { + lightmapper->add_probe(probes_found[i]); + } + } + + Ref environment_image; + Basis environment_transform; + + // Add everything to lightmapper + if (environment_mode != ENVIRONMENT_MODE_DISABLED) { + if (p_bake_step) { + p_bake_step(4.1, TTR("Preparing Environment"), p_bake_userdata, true); + } + + environment_transform = get_global_transform().basis; + + switch (environment_mode) { + case ENVIRONMENT_MODE_DISABLED: { + //nothing + } break; + case ENVIRONMENT_MODE_SCENE: { + Ref world = get_world_3d(); + if (world.is_valid()) { + Ref env = world->get_environment(); + if (env.is_null()) { + env = world->get_fallback_environment(); + } + + if (env.is_valid()) { + environment_image = RS::get_singleton()->environment_bake_panorama(env->get_rid(), true, Size2i(128, 64)); + } + } + } break; + case ENVIRONMENT_MODE_CUSTOM_SKY: { + if (environment_custom_sky.is_valid()) { + environment_image = RS::get_singleton()->sky_bake_panorama(environment_custom_sky->get_rid(), environment_custom_energy, true, Size2i(128, 64)); + } + + } break; + case ENVIRONMENT_MODE_CUSTOM_COLOR: { + environment_image.instance(); + environment_image->create(128, 64, false, Image::FORMAT_RGBAF); + Color c = environment_custom_color; + c.r *= environment_custom_energy; + c.g *= environment_custom_energy; + c.b *= environment_custom_energy; + for (int i = 0; i < 128; i++) { + for (int j = 0; j < 64; j++) { + environment_image->set_pixel(i, j, c); + } + } + + } break; + } + } + + Lightmapper::BakeError bake_err = lightmapper->bake(Lightmapper::BakeQuality(bake_quality), use_denoiser, bounces, bias, max_texture_size, directional, Lightmapper::GenerateProbes(gen_probes), environment_image, environment_transform, _lightmap_bake_step_function, &bsud); + + if (bake_err == Lightmapper::BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES) { + return BAKE_ERROR_MESHES_INVALID; + } + + /* POSTBAKE: Save Textures */ + + Ref texture; + { + + Vector> images; + for (int i = 0; i < lightmapper->get_bake_texture_count(); i++) { + images.push_back(lightmapper->get_bake_texture(i)); + } + //we assume they are all the same, so lets create a large one for saving + Ref large_image; + large_image.instance(); + + large_image->create(images[0]->get_width(), images[0]->get_height() * images.size(), false, images[0]->get_format()); + + for (int i = 0; i < lightmapper->get_bake_texture_count(); i++) { + large_image->blit_rect(images[i], Rect2(0, 0, images[i]->get_width(), images[i]->get_height()), Point2(0, images[i]->get_height() * i)); + } + + String base_path = p_image_data_path.get_basename() + ".exr"; + + Ref config; + + config.instance(); + if (FileAccess::exists(base_path + ".import")) { + + config->load(base_path + ".import"); + } + + config->set_value("remap", "importer", "2d_array_texture"); + config->set_value("remap", "type", "StreamTexture2DArray"); + if (!config->has_section_key("params", "compress/mode")) { + config->set_value("params", "compress/mode", 2); //user may want another compression, so leave it be + } + config->set_value("params", "compress/channel_pack", 1); + config->set_value("params", "mipmaps/generate", false); + config->set_value("params", "slices/horizontal", 1); + config->set_value("params", "slices/vertical", images.size()); + + config->save(base_path + ".import"); + + Error err = large_image->save_exr(base_path, false); + ERR_FAIL_COND_V(err, BAKE_ERROR_CANT_CREATE_IMAGE); + ResourceLoader::import(base_path); + Ref t = ResourceLoader::load(base_path); //if already loaded, it will be updated on refocus? + ERR_FAIL_COND_V(t.is_null(), BAKE_ERROR_CANT_CREATE_IMAGE); + texture = t; + } + + /* POSTBAKE: Save Light Data */ + + Ref data; + if (get_light_data().is_valid()) { + data = get_light_data(); + set_light_data(Ref()); //clear + data->clear(); + } else { + data.instance(); + } + + data->set_light_texture(texture); + data->set_uses_spherical_harmonics(directional); + + for (int i = 0; i < lightmapper->get_bake_mesh_count(); i++) { + Dictionary d = lightmapper->get_bake_mesh_userdata(i); + NodePath np = d["path"]; + int32_t subindex = -1; + if (d.has("subindex")) { + subindex = d["subindex"]; + } + + Rect2 uv_scale = lightmapper->get_bake_mesh_uv_scale(i); + int slice_index = lightmapper->get_bake_mesh_texture_slice(i); + data->add_user(np, uv_scale, slice_index, subindex); + } + + { + // create tetrahedrons + Vector points; + Vector sh; + points.resize(lightmapper->get_bake_probe_count()); + sh.resize(lightmapper->get_bake_probe_count() * 9); + for (int i = 0; i < lightmapper->get_bake_probe_count(); i++) { + points.write[i] = lightmapper->get_bake_probe_point(i); + Vector colors = lightmapper->get_bake_probe_sh(i); + ERR_CONTINUE(colors.size() != 9); + for (int j = 0; j < 9; j++) { + sh.write[i * 9 + j] = colors[j]; + } + } + + //Obtain solved simplices + + if (p_bake_step) { + p_bake_step(0.8, TTR("Generating Probe Volumes"), p_bake_userdata, true); + } + Vector solved_simplices = Delaunay3D::tetrahedralize(points); + + LocalVector bsp_simplices; + LocalVector bsp_planes; + LocalVector bsp_simplex_indices; + PackedInt32Array tetrahedrons; + + for (int i = 0; i < solved_simplices.size(); i++) { + + //Prepare a special representation of the simplex, which uses a BSP Tree + BSPSimplex bsp_simplex; + for (int j = 0; j < 4; j++) { + bsp_simplex.vertices[j] = solved_simplices[i].points[j]; + } + for (int j = 0; j < 4; j++) { + static const int face_order[4][3] = { + { 0, 1, 2 }, + { 0, 2, 3 }, + { 0, 1, 3 }, + { 1, 2, 3 } + }; + Vector3 a = points[solved_simplices[i].points[face_order[j][0]]]; + Vector3 b = points[solved_simplices[i].points[face_order[j][1]]]; + Vector3 c = points[solved_simplices[i].points[face_order[j][2]]]; + + //store planes in an array, but ensure they are reused, to speed up processing + + Plane p(a, b, c); + int plane_index = -1; + for (uint32_t k = 0; k < bsp_planes.size(); k++) { + + if (bsp_planes[k].is_equal_approx_any_side(p)) { + plane_index = k; + break; + } + } + + if (plane_index == -1) { + plane_index = bsp_planes.size(); + bsp_planes.push_back(p); + } + + bsp_simplex.planes[j] = plane_index; + + //also fill simplex array + tetrahedrons.push_back(solved_simplices[i].points[j]); + } + + bsp_simplex_indices.push_back(bsp_simplices.size()); + bsp_simplices.push_back(bsp_simplex); + } + +//#define DEBUG_SIMPLICES_AS_OBJ_FILE +#ifdef DEBUG_SIMPLICES_AS_OBJ_FILE + { + FileAccessRef f = FileAccess::open("res://bsp.obj", FileAccess::WRITE); + for (uint32_t i = 0; i < bsp_simplices.size(); i++) { + f->store_line("o Simplex" + itos(i)); + for (int j = 0; j < 4; j++) { + f->store_line(vformat("v %f %f %f", points[bsp_simplices[i].vertices[j]].x, points[bsp_simplices[i].vertices[j]].y, points[bsp_simplices[i].vertices[j]].z)); + } + static const int face_order[4][3] = { + { 1, 2, 3 }, + { 1, 3, 4 }, + { 1, 2, 4 }, + { 2, 3, 4 } + }; + + for (int j = 0; j < 4; j++) { + f->store_line(vformat("f %d %d %d", 4 * i + face_order[j][0], 4 * i + face_order[j][1], 4 * i + face_order[j][2])); + } + } + f->close(); + } +#endif + + LocalVector bsp_nodes; + LocalVector planes_tested; + planes_tested.resize(bsp_planes.size()); + for (uint32_t i = 0; i < planes_tested.size(); i++) { + planes_tested[i] = 0x7FFFFFFF; + } + + if (p_bake_step) { + p_bake_step(0.9, TTR("Generating Probe Acceleration Structures"), p_bake_userdata, true); + } + + _compute_bsp_tree(points, bsp_planes, planes_tested, bsp_simplices, bsp_simplex_indices, bsp_nodes); + + PackedInt32Array bsp_array; + bsp_array.resize(bsp_nodes.size() * 6); // six 32 bits values used for each BSP node + { + float *fptr = (float *)bsp_array.ptrw(); + int32_t *iptr = (int32_t *)bsp_array.ptrw(); + for (uint32_t i = 0; i < bsp_nodes.size(); i++) { + fptr[i * 6 + 0] = bsp_nodes[i].plane.normal.x; + fptr[i * 6 + 1] = bsp_nodes[i].plane.normal.y; + fptr[i * 6 + 2] = bsp_nodes[i].plane.normal.z; + fptr[i * 6 + 3] = bsp_nodes[i].plane.d; + iptr[i * 6 + 4] = bsp_nodes[i].over; + iptr[i * 6 + 5] = bsp_nodes[i].under; + } +//#define DEBUG_BSP_TREE +#ifdef DEBUG_BSP_TREE + FileAccessRef f = FileAccess::open("res://bsp.txt", FileAccess::WRITE); + for (uint32_t i = 0; i < bsp_nodes.size(); i++) { + f->store_line(itos(i) + " - plane: " + bsp_nodes[i].plane + " over: " + itos(bsp_nodes[i].over) + " under: " + itos(bsp_nodes[i].under)); + } +#endif + } + + /* Obtain the colors from the images, they will be re-created as cubemaps on the server, depending on the driver */ + + data->set_capture_data(bounds, interior, points, sh, tetrahedrons, bsp_array); + /* Compute a BSP tree of the simplices, so it's easy to find the exact one */ + } + + Error err = ResourceSaver::save(p_image_data_path, data); + data->set_path(p_image_data_path); + + if (err != OK) { + return BAKE_ERROR_CANT_CREATE_IMAGE; + } + + set_light_data(data); return BAKE_ERROR_OK; } void BakedLightmap::_notification(int p_what) { - if (p_what == NOTIFICATION_READY) { + if (p_what == NOTIFICATION_POST_ENTER_TREE) { if (light_data.is_valid()) { _assign_lightmaps(); } - request_ready(); //will need ready again if re-enters tree } if (p_what == NOTIFICATION_EXIT_TREE) { @@ -667,20 +1236,18 @@ void BakedLightmap::_assign_lightmaps() { ERR_FAIL_COND(!light_data.is_valid()); for (int i = 0; i < light_data->get_user_count(); i++) { - Ref lightmap = light_data->get_user_lightmap(i); - ERR_CONTINUE(!lightmap.is_valid()); Node *node = get_node(light_data->get_user_path(i)); - int instance_idx = light_data->get_user_instance(i); + int instance_idx = light_data->get_user_sub_instance(i); if (instance_idx >= 0) { RID instance = node->call("get_bake_mesh_instance", instance_idx); if (instance.is_valid()) { - RS::get_singleton()->instance_set_use_lightmap(instance, get_instance(), lightmap->get_rid()); + RS::get_singleton()->instance_geometry_set_lightmap(instance, get_instance(), light_data->get_user_lightmap_uv_scale(i), light_data->get_user_lightmap_slice_index(i)); } } else { - VisualInstance *vi = Object::cast_to(node); + VisualInstance3D *vi = Object::cast_to(node); ERR_CONTINUE(!vi); - RS::get_singleton()->instance_set_use_lightmap(vi->get_instance(), get_instance(), lightmap->get_rid()); + RS::get_singleton()->instance_geometry_set_lightmap(vi->get_instance(), get_instance(), light_data->get_user_lightmap_uv_scale(i), light_data->get_user_lightmap_slice_index(i)); } } } @@ -689,16 +1256,16 @@ void BakedLightmap::_clear_lightmaps() { ERR_FAIL_COND(!light_data.is_valid()); for (int i = 0; i < light_data->get_user_count(); i++) { Node *node = get_node(light_data->get_user_path(i)); - int instance_idx = light_data->get_user_instance(i); + int instance_idx = light_data->get_user_sub_instance(i); if (instance_idx >= 0) { RID instance = node->call("get_bake_mesh_instance", instance_idx); if (instance.is_valid()) { - RS::get_singleton()->instance_set_use_lightmap(instance, get_instance(), RID()); + RS::get_singleton()->instance_geometry_set_lightmap(instance, RID(), Rect2(), 0); } } else { - VisualInstance *vi = Object::cast_to(node); + VisualInstance3D *vi = Object::cast_to(node); ERR_CONTINUE(!vi); - RS::get_singleton()->instance_set_use_lightmap(vi->get_instance(), get_instance(), RID()); + RS::get_singleton()->instance_geometry_set_lightmap(vi->get_instance(), RID(), Rect2(), 0); } } } @@ -719,6 +1286,8 @@ void BakedLightmap::set_light_data(const Ref &p_data) { _assign_lightmaps(); } } + + update_gizmo(); } Ref BakedLightmap::get_light_data() const { @@ -726,28 +1295,6 @@ Ref BakedLightmap::get_light_data() const { return light_data; } -void BakedLightmap::_debug_bake() { - bake(get_parent(), true); -} - -void BakedLightmap::set_propagation(float p_propagation) { - propagation = p_propagation; -} - -float BakedLightmap::get_propagation() const { - - return propagation; -} - -void BakedLightmap::set_energy(float p_energy) { - energy = p_energy; -} - -float BakedLightmap::get_energy() const { - - return energy; -} - void BakedLightmap::set_bake_quality(BakeQuality p_quality) { bake_quality = p_quality; } @@ -756,109 +1303,206 @@ BakedLightmap::BakeQuality BakedLightmap::get_bake_quality() const { return bake_quality; } -void BakedLightmap::set_bake_mode(BakeMode p_mode) { - bake_mode = p_mode; -} - -BakedLightmap::BakeMode BakedLightmap::get_bake_mode() const { - return bake_mode; -} - -void BakedLightmap::set_image_path(const String &p_path) { - image_path = p_path; -} - -String BakedLightmap::get_image_path() const { - return image_path; -} - AABB BakedLightmap::get_aabb() const { - return AABB(-extents, extents * 2); + return AABB(); } Vector BakedLightmap::get_faces(uint32_t p_usage_flags) const { return Vector(); } +void BakedLightmap::set_use_denoiser(bool p_enable) { + + use_denoiser = p_enable; +} + +bool BakedLightmap::is_using_denoiser() const { + + return use_denoiser; +} + +void BakedLightmap::set_directional(bool p_enable) { + directional = p_enable; +} + +bool BakedLightmap::is_directional() const { + return directional; +} + +void BakedLightmap::set_interior(bool p_enable) { + interior = p_enable; +} +bool BakedLightmap::is_interior() const { + return interior; +} + +void BakedLightmap::set_environment_mode(EnvironmentMode p_mode) { + environment_mode = p_mode; + _change_notify(); +} + +BakedLightmap::EnvironmentMode BakedLightmap::get_environment_mode() const { + return environment_mode; +} + +void BakedLightmap::set_environment_custom_sky(const Ref &p_sky) { + environment_custom_sky = p_sky; +} + +Ref BakedLightmap::get_environment_custom_sky() const { + return environment_custom_sky; +} + +void BakedLightmap::set_environment_custom_color(const Color &p_color) { + environment_custom_color = p_color; +} +Color BakedLightmap::get_environment_custom_color() const { + return environment_custom_color; +} + +void BakedLightmap::set_environment_custom_energy(float p_energy) { + environment_custom_energy = p_energy; +} +float BakedLightmap::get_environment_custom_energy() const { + return environment_custom_energy; +} + +void BakedLightmap::set_bounces(int p_bounces) { + ERR_FAIL_COND(p_bounces < 0 || p_bounces > 16); + bounces = p_bounces; +} + +int BakedLightmap::get_bounces() const { + return bounces; +} + +void BakedLightmap::set_bias(float p_bias) { + ERR_FAIL_COND(p_bias < 0.00001); + bias = p_bias; +} + +float BakedLightmap::get_bias() const { + return bias; +} + +void BakedLightmap::set_max_texture_size(int p_size) { + ERR_FAIL_COND(p_size < 2048); + max_texture_size = p_size; +} + +int BakedLightmap::get_max_texture_size() const { + return max_texture_size; +} + +void BakedLightmap::set_generate_probes(GenerateProbes p_generate_probes) { + gen_probes = p_generate_probes; +} + +BakedLightmap::GenerateProbes BakedLightmap::get_generate_probes() const { + return gen_probes; +} + +void BakedLightmap::_validate_property(PropertyInfo &property) const { + if (property.name == "environment_custom_sky" && environment_mode != ENVIRONMENT_MODE_CUSTOM_SKY) { + property.usage = 0; + } + if (property.name == "environment_custom_color" && environment_mode != ENVIRONMENT_MODE_CUSTOM_COLOR) { + property.usage = 0; + } + if (property.name == "environment_custom_energy" && environment_mode != ENVIRONMENT_MODE_CUSTOM_COLOR && environment_mode != ENVIRONMENT_MODE_CUSTOM_SKY) { + property.usage = 0; + } +} + void BakedLightmap::_bind_methods() { ClassDB::bind_method(D_METHOD("set_light_data", "data"), &BakedLightmap::set_light_data); ClassDB::bind_method(D_METHOD("get_light_data"), &BakedLightmap::get_light_data); - ClassDB::bind_method(D_METHOD("set_bake_cell_size", "bake_cell_size"), &BakedLightmap::set_bake_cell_size); - ClassDB::bind_method(D_METHOD("get_bake_cell_size"), &BakedLightmap::get_bake_cell_size); - - ClassDB::bind_method(D_METHOD("set_capture_cell_size", "capture_cell_size"), &BakedLightmap::set_capture_cell_size); - ClassDB::bind_method(D_METHOD("get_capture_cell_size"), &BakedLightmap::get_capture_cell_size); - ClassDB::bind_method(D_METHOD("set_bake_quality", "bake_quality"), &BakedLightmap::set_bake_quality); ClassDB::bind_method(D_METHOD("get_bake_quality"), &BakedLightmap::get_bake_quality); - ClassDB::bind_method(D_METHOD("set_bake_mode", "bake_mode"), &BakedLightmap::set_bake_mode); - ClassDB::bind_method(D_METHOD("get_bake_mode"), &BakedLightmap::get_bake_mode); + ClassDB::bind_method(D_METHOD("set_bounces", "bounces"), &BakedLightmap::set_bounces); + ClassDB::bind_method(D_METHOD("get_bounces"), &BakedLightmap::get_bounces); - ClassDB::bind_method(D_METHOD("set_extents", "extents"), &BakedLightmap::set_extents); - ClassDB::bind_method(D_METHOD("get_extents"), &BakedLightmap::get_extents); + ClassDB::bind_method(D_METHOD("set_generate_probes", "subdivision"), &BakedLightmap::set_generate_probes); + ClassDB::bind_method(D_METHOD("get_generate_probes"), &BakedLightmap::get_generate_probes); - ClassDB::bind_method(D_METHOD("set_bake_default_texels_per_unit", "texels"), &BakedLightmap::set_bake_default_texels_per_unit); - ClassDB::bind_method(D_METHOD("get_bake_default_texels_per_unit"), &BakedLightmap::get_bake_default_texels_per_unit); + ClassDB::bind_method(D_METHOD("set_bias", "bias"), &BakedLightmap::set_bias); + ClassDB::bind_method(D_METHOD("get_bias"), &BakedLightmap::get_bias); - ClassDB::bind_method(D_METHOD("set_propagation", "propagation"), &BakedLightmap::set_propagation); - ClassDB::bind_method(D_METHOD("get_propagation"), &BakedLightmap::get_propagation); + ClassDB::bind_method(D_METHOD("set_environment_mode", "mode"), &BakedLightmap::set_environment_mode); + ClassDB::bind_method(D_METHOD("get_environment_mode"), &BakedLightmap::get_environment_mode); - ClassDB::bind_method(D_METHOD("set_energy", "energy"), &BakedLightmap::set_energy); - ClassDB::bind_method(D_METHOD("get_energy"), &BakedLightmap::get_energy); + ClassDB::bind_method(D_METHOD("set_environment_custom_sky", "sky"), &BakedLightmap::set_environment_custom_sky); + ClassDB::bind_method(D_METHOD("get_environment_custom_sky"), &BakedLightmap::get_environment_custom_sky); - ClassDB::bind_method(D_METHOD("set_hdr", "hdr"), &BakedLightmap::set_hdr); - ClassDB::bind_method(D_METHOD("is_hdr"), &BakedLightmap::is_hdr); + ClassDB::bind_method(D_METHOD("set_environment_custom_color", "color"), &BakedLightmap::set_environment_custom_color); + ClassDB::bind_method(D_METHOD("get_environment_custom_color"), &BakedLightmap::get_environment_custom_color); - ClassDB::bind_method(D_METHOD("set_image_path", "image_path"), &BakedLightmap::set_image_path); - ClassDB::bind_method(D_METHOD("get_image_path"), &BakedLightmap::get_image_path); + ClassDB::bind_method(D_METHOD("set_environment_custom_energy", "energy"), &BakedLightmap::set_environment_custom_energy); + ClassDB::bind_method(D_METHOD("get_environment_custom_energy"), &BakedLightmap::get_environment_custom_energy); - ClassDB::bind_method(D_METHOD("bake", "from_node", "create_visual_debug"), &BakedLightmap::bake, DEFVAL(Variant()), DEFVAL(false)); - ClassDB::bind_method(D_METHOD("debug_bake"), &BakedLightmap::_debug_bake); - ClassDB::set_method_flags(get_class_static(), _scs_create("debug_bake"), METHOD_FLAGS_DEFAULT | METHOD_FLAG_EDITOR); + ClassDB::bind_method(D_METHOD("set_max_texture_size", "max_texture_size"), &BakedLightmap::set_max_texture_size); + ClassDB::bind_method(D_METHOD("get_max_texture_size"), &BakedLightmap::get_max_texture_size); - ADD_GROUP("Bake", "bake_"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "bake_cell_size", PROPERTY_HINT_RANGE, "0.01,64,0.01"), "set_bake_cell_size", "get_bake_cell_size"); - ADD_PROPERTY(PropertyInfo(Variant::INT, "bake_quality", PROPERTY_HINT_ENUM, "Low,Medium,High"), "set_bake_quality", "get_bake_quality"); - ADD_PROPERTY(PropertyInfo(Variant::INT, "bake_mode", PROPERTY_HINT_ENUM, "ConeTrace,RayTrace"), "set_bake_mode", "get_bake_mode"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "bake_propagation", PROPERTY_HINT_RANGE, "0,1,0.01"), "set_propagation", "get_propagation"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "bake_energy", PROPERTY_HINT_RANGE, "0,32,0.01"), "set_energy", "get_energy"); - ADD_PROPERTY(PropertyInfo(Variant::BOOL, "bake_hdr"), "set_hdr", "is_hdr"); - ADD_PROPERTY(PropertyInfo(Variant::VECTOR3, "bake_extents"), "set_extents", "get_extents"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "bake_default_texels_per_unit"), "set_bake_default_texels_per_unit", "get_bake_default_texels_per_unit"); - ADD_GROUP("Capture", "capture_"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "capture_cell_size", PROPERTY_HINT_RANGE, "0.01,64,0.01"), "set_capture_cell_size", "get_capture_cell_size"); + ClassDB::bind_method(D_METHOD("set_use_denoiser", "use_denoiser"), &BakedLightmap::set_use_denoiser); + ClassDB::bind_method(D_METHOD("is_using_denoiser"), &BakedLightmap::is_using_denoiser); + + ClassDB::bind_method(D_METHOD("set_interior", "enable"), &BakedLightmap::set_interior); + ClassDB::bind_method(D_METHOD("is_interior"), &BakedLightmap::is_interior); + + ClassDB::bind_method(D_METHOD("set_directional", "directional"), &BakedLightmap::set_directional); + ClassDB::bind_method(D_METHOD("is_directional"), &BakedLightmap::is_directional); + + // ClassDB::bind_method(D_METHOD("bake", "from_node"), &BakedLightmap::bake, DEFVAL(Variant())); + + ADD_GROUP("Tweaks", ""); + ADD_PROPERTY(PropertyInfo(Variant::INT, "quality", PROPERTY_HINT_ENUM, "Low,Medium,High,Ultra"), "set_bake_quality", "get_bake_quality"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "bounces", PROPERTY_HINT_RANGE, "0,16,1"), "set_bounces", "get_bounces"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "directional"), "set_directional", "is_directional"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "interior"), "set_interior", "is_interior"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "use_denoiser"), "set_use_denoiser", "is_using_denoiser"); + ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "bias", PROPERTY_HINT_RANGE, "0.00001,0.1,0.00001,or_greater"), "set_bias", "get_bias"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "max_texture_size"), "set_max_texture_size", "get_max_texture_size"); + ADD_GROUP("Environment", "environment_"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "environment_mode", PROPERTY_HINT_ENUM, "Disabled,Scene,Custom Sky,Custom Color"), "set_environment_mode", "get_environment_mode"); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "environment_custom_sky", PROPERTY_HINT_RESOURCE_TYPE, "Sky"), "set_environment_custom_sky", "get_environment_custom_sky"); + ADD_PROPERTY(PropertyInfo(Variant::COLOR, "environment_custom_color", PROPERTY_HINT_COLOR_NO_ALPHA), "set_environment_custom_color", "get_environment_custom_color"); + ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "environment_custom_energy", PROPERTY_HINT_RANGE, "0,64,0.01"), "set_environment_custom_energy", "get_environment_custom_energy"); + ADD_GROUP("Gen Probes", "generate_probes_"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "generate_probes_subdiv", PROPERTY_HINT_ENUM, "Disabled,4,8,16,32"), "set_generate_probes", "get_generate_probes"); ADD_GROUP("Data", ""); - ADD_PROPERTY(PropertyInfo(Variant::STRING, "image_path", PROPERTY_HINT_DIR), "set_image_path", "get_image_path"); ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "light_data", PROPERTY_HINT_RESOURCE_TYPE, "BakedLightmapData"), "set_light_data", "get_light_data"); BIND_ENUM_CONSTANT(BAKE_QUALITY_LOW); BIND_ENUM_CONSTANT(BAKE_QUALITY_MEDIUM); BIND_ENUM_CONSTANT(BAKE_QUALITY_HIGH); - BIND_ENUM_CONSTANT(BAKE_MODE_CONE_TRACE); - BIND_ENUM_CONSTANT(BAKE_MODE_RAY_TRACE); BIND_ENUM_CONSTANT(BAKE_ERROR_OK); BIND_ENUM_CONSTANT(BAKE_ERROR_NO_SAVE_PATH); BIND_ENUM_CONSTANT(BAKE_ERROR_NO_MESHES); BIND_ENUM_CONSTANT(BAKE_ERROR_CANT_CREATE_IMAGE); BIND_ENUM_CONSTANT(BAKE_ERROR_USER_ABORTED); + + BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_DISABLED); + BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_SCENE); + BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_CUSTOM_SKY); + BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_CUSTOM_COLOR); } BakedLightmap::BakedLightmap() { - extents = Vector3(10, 10, 10); - bake_default_texels_per_unit = 20; - bake_cell_size = 0.25; - capture_cell_size = 0.5; + environment_mode = ENVIRONMENT_MODE_DISABLED; + environment_custom_color = Color(0.2, 0.7, 1.0); + environment_custom_energy = 1.0; bake_quality = BAKE_QUALITY_MEDIUM; - bake_mode = BAKE_MODE_CONE_TRACE; - energy = 1; - propagation = 1; - hdr = false; - image_path = "."; - set_disable_scale(true); + interior = false; + directional = false; + + gen_probes = GENERATE_PROBES_DISABLED; + use_denoiser = true; + bounces = 1; + bias = 0.0005; + max_texture_size = 16384; } -#endif diff --git a/scene/3d/baked_lightmap.h b/scene/3d/baked_lightmap.h index bc9e3f55ea..020d5fe1e0 100644 --- a/scene/3d/baked_lightmap.h +++ b/scene/3d/baked_lightmap.h @@ -28,189 +28,257 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#if 0 #ifndef BAKED_INDIRECT_LIGHT_H #define BAKED_INDIRECT_LIGHT_H -#include "multimesh_instance.h" -#include "scene/3d/light.h" -#include "scene/3d/visual_instance.h" +#include "core/local_vector.h" +#include "scene/3d/light_3d.h" +#include "scene/3d/lightmapper.h" +#include "scene/3d/mesh_instance_3d.h" +#include "scene/3d/multimesh_instance_3d.h" +#include "scene/3d/visual_instance_3d.h" +#include "scene/resources/sky.h" class BakedLightmapData : public Resource { GDCLASS(BakedLightmapData, Resource); + RES_BASE_EXTENSION("lmbake") - RID baked_light; + Ref light_texture; + + bool uses_spherical_harmonics = false; + bool interior = false; + + RID lightmap; AABB bounds; - float energy; - int cell_subdiv; - Transform cell_space_xform; struct User { NodePath path; - Ref lightmap; - int instance_index; + int32_t sub_instance; + Rect2 uv_scale; + int slice_index; }; Vector users; void _set_user_data(const Array &p_data); Array _get_user_data() const; + void _set_probe_data(const Dictionary &p_data); + Dictionary _get_probe_data() const; protected: static void _bind_methods(); public: - void set_bounds(const AABB &p_bounds); - AABB get_bounds() const; - - void set_octree(const Vector &p_octree); - Vector get_octree() const; - - void set_cell_space_transform(const Transform &p_xform); - Transform get_cell_space_transform() const; - - void set_cell_subdiv(int p_cell_subdiv); - int get_cell_subdiv() const; - - void set_energy(float p_energy); - float get_energy() const; - - void add_user(const NodePath &p_path, const Ref &p_lightmap, int p_instance = -1); + void add_user(const NodePath &p_path, const Rect2 &p_uv_scale, int p_slice_index, int32_t p_sub_instance = -1); int get_user_count() const; NodePath get_user_path(int p_user) const; - Ref get_user_lightmap(int p_user) const; - int get_user_instance(int p_user) const; + int32_t get_user_sub_instance(int p_user) const; + Rect2 get_user_lightmap_uv_scale(int p_user) const; + int get_user_lightmap_slice_index(int p_user) const; void clear_users(); + void set_light_texture(const Ref &p_light_texture); + Ref get_light_texture() const; + + void set_uses_spherical_harmonics(bool p_enable); + bool is_using_spherical_harmonics() const; + + bool is_interior() const; + + void set_capture_data(const AABB &p_bounds, bool p_interior, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree); + PackedVector3Array get_capture_points() const; + PackedColorArray get_capture_sh() const; + PackedInt32Array get_capture_tetrahedra() const; + PackedInt32Array get_capture_bsp_tree() const; + AABB get_capture_bounds() const; + + void clear(); + virtual RID get_rid() const; BakedLightmapData(); ~BakedLightmapData(); }; -class BakedLightmap : public VisualInstance { - GDCLASS(BakedLightmap, VisualInstance); +class BakedLightmap : public VisualInstance3D { + GDCLASS(BakedLightmap, VisualInstance3D); public: enum BakeQuality { BAKE_QUALITY_LOW, BAKE_QUALITY_MEDIUM, - BAKE_QUALITY_HIGH + BAKE_QUALITY_HIGH, + BAKE_QUALITY_ULTRA, }; - - enum BakeMode { - BAKE_MODE_CONE_TRACE, - BAKE_MODE_RAY_TRACE, + enum GenerateProbes { + GENERATE_PROBES_DISABLED, + GENERATE_PROBES_SUBDIV_4, + GENERATE_PROBES_SUBDIV_8, + GENERATE_PROBES_SUBDIV_16, + GENERATE_PROBES_SUBDIV_32, }; enum BakeError { BAKE_ERROR_OK, + BAKE_ERROR_NO_LIGHTMAPPER, BAKE_ERROR_NO_SAVE_PATH, BAKE_ERROR_NO_MESHES, + BAKE_ERROR_MESHES_INVALID, BAKE_ERROR_CANT_CREATE_IMAGE, BAKE_ERROR_USER_ABORTED }; - typedef void (*BakeBeginFunc)(int); - typedef bool (*BakeStepFunc)(int, const String &); - typedef void (*BakeEndFunc)(); + enum EnvironmentMode { + ENVIRONMENT_MODE_DISABLED, + ENVIRONMENT_MODE_SCENE, + ENVIRONMENT_MODE_CUSTOM_SKY, + ENVIRONMENT_MODE_CUSTOM_COLOR, + }; private: - float bake_cell_size; - float capture_cell_size; - Vector3 extents; - float bake_default_texels_per_unit; - float propagation; - float energy; BakeQuality bake_quality; - BakeMode bake_mode; - bool hdr; - String image_path; + bool use_denoiser; + int bounces; + float bias; + int max_texture_size; + bool interior; + EnvironmentMode environment_mode; + Ref environment_custom_sky; + Color environment_custom_color; + float environment_custom_energy; + bool directional; + GenerateProbes gen_probes; Ref light_data; - struct PlotMesh { - Ref override_material; - Vector > instance_materials; + struct LightsFound { + Transform xform; + Light3D *light; + }; + + struct MeshesFound { + Transform xform; + NodePath node_path; + int32_t subindex; Ref mesh; - Transform local_xform; - NodePath path; - int instance_idx; + int32_t lightmap_scale; + Vector> overrides; }; - struct PlotLight { - Light *light; - Transform local_xform; - }; - - void _find_meshes_and_lights(Node *p_at_node, List &plot_meshes, List &plot_lights); - - void _debug_bake(); + void _find_meshes_and_lights(Node *p_at_node, Vector &meshes, Vector &lights, Vector &probes); void _assign_lightmaps(); void _clear_lightmaps(); - static bool _bake_time(void *ud, float p_secs, float p_progress); - struct BakeTimeData { String text; int pass; uint64_t last_step; }; + struct BSPSimplex { + int vertices[4]; + int planes[4]; + }; + + struct BSPNode { + static const int32_t EMPTY_LEAF = INT32_MIN; + Plane plane; + int32_t over = EMPTY_LEAF, under = EMPTY_LEAF; + }; + + int _bsp_get_simplex_side(const Vector &p_points, const LocalVector &p_simplices, const Plane &p_plane, uint32_t p_simplex) const; + int32_t _compute_bsp_tree(const Vector &p_points, const LocalVector &p_planes, LocalVector &planes_tested, const LocalVector &p_simplices, const LocalVector &p_simplex_indices, LocalVector &bsp_nodes); + + struct BakeStepUD { + Lightmapper::BakeStepFunc func; + void *ud; + float from_percent; + float to_percent; + }; + + static bool _lightmap_bake_step_function(float p_completion, const String &p_text, void *ud, bool p_refresh); + + struct GenProbesOctree { + Vector3i offset; + uint32_t size; + GenProbesOctree *children[8] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + ~GenProbesOctree() { + for (int i = 0; i < 8; i++) { + if (children[i] != nullptr) { + memdelete(children[i]); + } + } + } + }; + + struct Vector3iHash { + _FORCE_INLINE_ static uint32_t hash(const Vector3i &p_vtx) { + uint32_t h = hash_djb2_one_32(p_vtx.x); + h = hash_djb2_one_32(p_vtx.y, h); + return hash_djb2_one_32(p_vtx.z, h); + } + }; + + void _plot_triangle_into_octree(GenProbesOctree *p_cell, float p_cell_size, const Vector3 *p_triangle); + void _gen_new_positions_from_octree(const GenProbesOctree *p_cell, float p_cell_size, const Vector &probe_positions, LocalVector &new_probe_positions, HashMap &positions_used, const AABB &p_bounds); + protected: + void _validate_property(PropertyInfo &property) const; static void _bind_methods(); void _notification(int p_what); public: - static BakeBeginFunc bake_begin_function; - static BakeStepFunc bake_step_function; - static BakeEndFunc bake_end_function; - void set_light_data(const Ref &p_data); Ref get_light_data() const; - void set_bake_cell_size(float p_cell_size); - float get_bake_cell_size() const; - - void set_capture_cell_size(float p_cell_size); - float get_capture_cell_size() const; - - void set_extents(const Vector3 &p_extents); - Vector3 get_extents() const; - - void set_bake_default_texels_per_unit(const float &p_bake_texels_per_unit); - float get_bake_default_texels_per_unit() const; - - void set_propagation(float p_propagation); - float get_propagation() const; - - void set_energy(float p_energy); - float get_energy() const; - void set_bake_quality(BakeQuality p_quality); BakeQuality get_bake_quality() const; - void set_bake_mode(BakeMode p_mode); - BakeMode get_bake_mode() const; + void set_use_denoiser(bool p_enable); + bool is_using_denoiser() const; - void set_hdr(bool p_enable); - bool is_hdr() const; + void set_directional(bool p_enable); + bool is_directional() const; - void set_image_path(const String &p_path); - String get_image_path() const; + void set_interior(bool p_interior); + bool is_interior() const; + + void set_environment_mode(EnvironmentMode p_mode); + EnvironmentMode get_environment_mode() const; + + void set_environment_custom_sky(const Ref &p_sky); + Ref get_environment_custom_sky() const; + + void set_environment_custom_color(const Color &p_color); + Color get_environment_custom_color() const; + + void set_environment_custom_energy(float p_energy); + float get_environment_custom_energy() const; + + void set_bounces(int p_bounces); + int get_bounces() const; + + void set_bias(float p_bias); + float get_bias() const; + + void set_max_texture_size(int p_size); + int get_max_texture_size() const; + + void set_generate_probes(GenerateProbes p_generate_probes); + GenerateProbes get_generate_probes() const; AABB get_aabb() const; Vector get_faces(uint32_t p_usage_flags) const; - BakeError bake(Node *p_from_node, bool p_create_visual_debug = false); + BakeError bake(Node *p_from_node, String p_image_data_path = "", Lightmapper::BakeStepFunc p_bake_step = nullptr, void *p_bake_userdata = nullptr); BakedLightmap(); }; +VARIANT_ENUM_CAST(BakedLightmap::GenerateProbes); VARIANT_ENUM_CAST(BakedLightmap::BakeQuality); -VARIANT_ENUM_CAST(BakedLightmap::BakeMode); VARIANT_ENUM_CAST(BakedLightmap::BakeError); +VARIANT_ENUM_CAST(BakedLightmap::EnvironmentMode); -#endif #endif // BAKED_INDIRECT_LIGHT_H diff --git a/scene/3d/gi_probe.cpp b/scene/3d/gi_probe.cpp index 6d571ee4f2..65a330ddc0 100644 --- a/scene/3d/gi_probe.cpp +++ b/scene/3d/gi_probe.cpp @@ -349,7 +349,7 @@ Vector3 GIProbe::get_extents() const { void GIProbe::_find_meshes(Node *p_at_node, List &plot_meshes) { MeshInstance3D *mi = Object::cast_to(p_at_node); - if (mi && mi->get_flag(GeometryInstance3D::FLAG_USE_BAKED_LIGHT) && mi->is_visible_in_tree()) { + if (mi && mi->get_gi_mode() == GeometryInstance3D::GI_MODE_BAKED && mi->is_visible_in_tree()) { Ref mesh = mi->get_mesh(); if (mesh.is_valid()) { diff --git a/scene/3d/lightmap_probe.cpp b/scene/3d/lightmap_probe.cpp new file mode 100644 index 0000000000..2da81337f0 --- /dev/null +++ b/scene/3d/lightmap_probe.cpp @@ -0,0 +1,4 @@ +#include "lightmap_probe.h" + +LightmapProbe::LightmapProbe() { +} diff --git a/scene/3d/lightmap_probe.h b/scene/3d/lightmap_probe.h new file mode 100644 index 0000000000..65bc6914f4 --- /dev/null +++ b/scene/3d/lightmap_probe.h @@ -0,0 +1,12 @@ +#ifndef LIGHTMAP_PROBE_H +#define LIGHTMAP_PROBE_H + +#include "scene/3d/node_3d.h" + +class LightmapProbe : public Node3D { + GDCLASS(LightmapProbe, Node3D) +public: + LightmapProbe(); +}; + +#endif // LIGHTMAP_PROBE_H diff --git a/scene/3d/lightmapper.cpp b/scene/3d/lightmapper.cpp new file mode 100644 index 0000000000..53ebd5ee2e --- /dev/null +++ b/scene/3d/lightmapper.cpp @@ -0,0 +1,37 @@ +#include "lightmapper.h" + +LightmapDenoiser *(*LightmapDenoiser::create_function)() = nullptr; + +Ref LightmapDenoiser::create() { + if (create_function) { + return Ref(create_function()); + } + return Ref(); +} + +Lightmapper::CreateFunc Lightmapper::create_custom = nullptr; +Lightmapper::CreateFunc Lightmapper::create_gpu = nullptr; +Lightmapper::CreateFunc Lightmapper::create_cpu = nullptr; + +Ref Lightmapper::create() { + Lightmapper *lm = nullptr; + if (create_custom) { + lm = create_custom(); + } + + if (!lm && create_gpu) { + lm = create_gpu(); + } + + if (!lm && create_cpu) { + lm = create_cpu(); + } + if (!lm) { + return Ref(); + } else { + return Ref(lm); + } +} + +Lightmapper::Lightmapper() { +} diff --git a/scene/3d/lightmapper.h b/scene/3d/lightmapper.h new file mode 100644 index 0000000000..7c052a30a0 --- /dev/null +++ b/scene/3d/lightmapper.h @@ -0,0 +1,90 @@ +#ifndef LIGHTMAPPER_H +#define LIGHTMAPPER_H + +#include "scene/resources/mesh.h" +#include "servers/rendering/rendering_device.h" + +class LightmapDenoiser : public Reference { + GDCLASS(LightmapDenoiser, Reference) +protected: + static LightmapDenoiser *(*create_function)(); + +public: + virtual Ref denoise_image(const Ref &p_image) = 0; + static Ref create(); +}; + +class Lightmapper : public Reference { + GDCLASS(Lightmapper, Reference) +public: + enum GenerateProbes { + GENERATE_PROBES_DISABLED, + GENERATE_PROBES_SUBDIV_4, + GENERATE_PROBES_SUBDIV_8, + GENERATE_PROBES_SUBDIV_16, + GENERATE_PROBES_SUBDIV_32, + + }; + + enum LightType { + LIGHT_TYPE_DIRECTIONAL, + LIGHT_TYPE_OMNI, + LIGHT_TYPE_SPOT + }; + + enum BakeError { + BAKE_ERROR_LIGHTMAP_TOO_SMALL, + BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES, + BAKE_OK + }; + + enum BakeQuality { + BAKE_QUALITY_LOW, + BAKE_QUALITY_MEDIUM, + BAKE_QUALITY_HIGH, + BAKE_QUALITY_ULTRA, + }; + + typedef Lightmapper *(*CreateFunc)(); + + static CreateFunc create_custom; + static CreateFunc create_gpu; + static CreateFunc create_cpu; + +protected: +public: + typedef bool (*BakeStepFunc)(float, const String &, void *, bool); //step index, step total, step description, userdata + + struct MeshData { + //triangle data + Vector points; + Vector uv2; + Vector normal; + Ref albedo_on_uv2; + Ref emission_on_uv2; + Variant userdata; + }; + + virtual void add_mesh(const MeshData &p_mesh) = 0; + virtual void add_directional_light(bool p_static, const Vector3 &p_direction, const Color &p_color, float p_energy, float p_angular_distance) = 0; + virtual void add_omni_light(bool p_static, const Vector3 &p_position, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_size) = 0; + virtual void add_spot_light(bool p_static, const Vector3 &p_position, const Vector3 p_direction, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_spot_angle, float p_spot_attenuation, float p_size) = 0; + virtual void add_probe(const Vector3 &p_position) = 0; + virtual BakeError bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function = nullptr, void *p_step_userdata = nullptr) = 0; + + virtual int get_bake_texture_count() const = 0; + virtual Ref get_bake_texture(int p_index) const = 0; + virtual int get_bake_mesh_count() const = 0; + virtual Variant get_bake_mesh_userdata(int p_index) const = 0; + virtual Rect2 get_bake_mesh_uv_scale(int p_index) const = 0; + virtual int get_bake_mesh_texture_slice(int p_index) const = 0; + virtual int get_bake_probe_count() const = 0; + virtual Vector3 get_bake_probe_point(int p_probe) const = 0; + virtual Vector get_bake_probe_sh(int p_probe) const = 0; + + static Ref create(); + + Lightmapper(); +}; + +#endif // LIGHTMAPPER_H diff --git a/scene/3d/visual_instance_3d.cpp b/scene/3d/visual_instance_3d.cpp index ce8672c1dd..4724c88a30 100644 --- a/scene/3d/visual_instance_3d.cpp +++ b/scene/3d/visual_instance_3d.cpp @@ -239,7 +239,17 @@ bool GeometryInstance3D::_set(const StringName &p_name, const Variant &p_value) set_shader_instance_uniform(*r, p_value); return true; } +#ifndef DISABLE_DEPRECATED + if (p_name == SceneStringNames::get_singleton()->use_in_baked_light && bool(p_value)) { + set_gi_mode(GI_MODE_BAKED); + return true; + } + if (p_name == SceneStringNames::get_singleton()->use_dynamic_gi && bool(p_value)) { + set_gi_mode(GI_MODE_DYNAMIC); + return true; + } +#endif return false; } @@ -273,23 +283,6 @@ void GeometryInstance3D::_get_property_list(List *p_list) const { } } -void GeometryInstance3D::set_flag(Flags p_flag, bool p_value) { - - ERR_FAIL_INDEX(p_flag, FLAG_MAX); - if (flags[p_flag] == p_value) - return; - - flags[p_flag] = p_value; - RS::get_singleton()->instance_geometry_set_flag(get_instance(), (RS::InstanceFlags)p_flag, p_value); -} - -bool GeometryInstance3D::get_flag(Flags p_flag) const { - - ERR_FAIL_INDEX_V(p_flag, FLAG_MAX, false); - - return flags[p_flag]; -} - void GeometryInstance3D::set_cast_shadows_setting(ShadowCastingSetting p_shadow_casting_setting) { shadow_casting_setting = p_shadow_casting_setting; @@ -335,14 +328,45 @@ void GeometryInstance3D::set_custom_aabb(AABB aabb) { RS::get_singleton()->instance_set_custom_aabb(get_instance(), aabb); } +void GeometryInstance3D::set_lightmap_scale(LightmapScale p_scale) { + ERR_FAIL_INDEX(p_scale, LIGHTMAP_SCALE_MAX); + lightmap_scale = p_scale; +} + +GeometryInstance3D::LightmapScale GeometryInstance3D::get_lightmap_scale() const { + return lightmap_scale; +} + +void GeometryInstance3D::set_gi_mode(GIMode p_mode) { + + switch (p_mode) { + case GI_MODE_DISABLED: { + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_BAKED_LIGHT, false); + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_DYNAMIC_GI, false); + } break; + case GI_MODE_BAKED: { + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_BAKED_LIGHT, true); + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_DYNAMIC_GI, false); + + } break; + case GI_MODE_DYNAMIC: { + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_BAKED_LIGHT, false); + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_DYNAMIC_GI, true); + } break; + } + + gi_mode = p_mode; +} + +GeometryInstance3D::GIMode GeometryInstance3D::get_gi_mode() const { + return gi_mode; +} + void GeometryInstance3D::_bind_methods() { ClassDB::bind_method(D_METHOD("set_material_override", "material"), &GeometryInstance3D::set_material_override); ClassDB::bind_method(D_METHOD("get_material_override"), &GeometryInstance3D::get_material_override); - ClassDB::bind_method(D_METHOD("set_flag", "flag", "value"), &GeometryInstance3D::set_flag); - ClassDB::bind_method(D_METHOD("get_flag", "flag"), &GeometryInstance3D::get_flag); - ClassDB::bind_method(D_METHOD("set_cast_shadows_setting", "shadow_casting_setting"), &GeometryInstance3D::set_cast_shadows_setting); ClassDB::bind_method(D_METHOD("get_cast_shadows_setting"), &GeometryInstance3D::get_cast_shadows_setting); @@ -364,6 +388,12 @@ void GeometryInstance3D::_bind_methods() { ClassDB::bind_method(D_METHOD("set_extra_cull_margin", "margin"), &GeometryInstance3D::set_extra_cull_margin); ClassDB::bind_method(D_METHOD("get_extra_cull_margin"), &GeometryInstance3D::get_extra_cull_margin); + ClassDB::bind_method(D_METHOD("set_lightmap_scale", "scale"), &GeometryInstance3D::set_lightmap_scale); + ClassDB::bind_method(D_METHOD("get_lightmap_scale"), &GeometryInstance3D::get_lightmap_scale); + + ClassDB::bind_method(D_METHOD("set_gi_mode", "mode"), &GeometryInstance3D::set_gi_mode); + ClassDB::bind_method(D_METHOD("get_gi_mode"), &GeometryInstance3D::get_gi_mode); + ClassDB::bind_method(D_METHOD("set_custom_aabb", "aabb"), &GeometryInstance3D::set_custom_aabb); ClassDB::bind_method(D_METHOD("get_aabb"), &GeometryInstance3D::get_aabb); @@ -372,8 +402,9 @@ void GeometryInstance3D::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "material_override", PROPERTY_HINT_RESOURCE_TYPE, "ShaderMaterial,StandardMaterial3D", PROPERTY_USAGE_DEFAULT | PROPERTY_USAGE_DEFERRED_SET_RESOURCE), "set_material_override", "get_material_override"); ADD_PROPERTY(PropertyInfo(Variant::INT, "cast_shadow", PROPERTY_HINT_ENUM, "Off,On,Double-Sided,Shadows Only"), "set_cast_shadows_setting", "get_cast_shadows_setting"); ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "extra_cull_margin", PROPERTY_HINT_RANGE, "0,16384,0.01"), "set_extra_cull_margin", "get_extra_cull_margin"); - ADD_PROPERTYI(PropertyInfo(Variant::BOOL, "use_in_baked_light"), "set_flag", "get_flag", FLAG_USE_BAKED_LIGHT); - ADD_PROPERTYI(PropertyInfo(Variant::BOOL, "use_dynamic_gi"), "set_flag", "get_flag", FLAG_USE_DYNAMIC_GI); + ADD_GROUP("Global Illumination", "gi_"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "gi_mode", PROPERTY_HINT_ENUM, "Disabled,Baked,Dynamic"), "set_gi_mode", "get_gi_mode"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "gi_lightmap_scale", PROPERTY_HINT_ENUM, "1x,2x,4x,8x"), "set_lightmap_scale", "get_lightmap_scale"); ADD_GROUP("LOD", "lod_"); ADD_PROPERTY(PropertyInfo(Variant::INT, "lod_min_distance", PROPERTY_HINT_RANGE, "0,32768,0.01"), "set_lod_min_distance", "get_lod_min_distance"); @@ -388,10 +419,15 @@ void GeometryInstance3D::_bind_methods() { BIND_ENUM_CONSTANT(SHADOW_CASTING_SETTING_DOUBLE_SIDED); BIND_ENUM_CONSTANT(SHADOW_CASTING_SETTING_SHADOWS_ONLY); - BIND_ENUM_CONSTANT(FLAG_USE_BAKED_LIGHT); - BIND_ENUM_CONSTANT(FLAG_USE_DYNAMIC_GI); - BIND_ENUM_CONSTANT(FLAG_DRAW_NEXT_FRAME_IF_VISIBLE); - BIND_ENUM_CONSTANT(FLAG_MAX); + BIND_ENUM_CONSTANT(GI_MODE_DISABLED); + BIND_ENUM_CONSTANT(GI_MODE_BAKED); + BIND_ENUM_CONSTANT(GI_MODE_DYNAMIC); + + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_1X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_2X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_4X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_8X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_MAX); } GeometryInstance3D::GeometryInstance3D() { @@ -400,9 +436,8 @@ GeometryInstance3D::GeometryInstance3D() { lod_min_hysteresis = 0; lod_max_hysteresis = 0; - for (int i = 0; i < FLAG_MAX; i++) { - flags[i] = false; - } + gi_mode = GI_MODE_DISABLED; + lightmap_scale = LIGHTMAP_SCALE_1X; shadow_casting_setting = SHADOW_CASTING_SETTING_ON; extra_cull_margin = 0; diff --git a/scene/3d/visual_instance_3d.h b/scene/3d/visual_instance_3d.h index cc5f92066f..a871c65b6a 100644 --- a/scene/3d/visual_instance_3d.h +++ b/scene/3d/visual_instance_3d.h @@ -85,13 +85,6 @@ class GeometryInstance3D : public VisualInstance3D { GDCLASS(GeometryInstance3D, VisualInstance3D); public: - enum Flags { - FLAG_USE_BAKED_LIGHT = RS::INSTANCE_FLAG_USE_BAKED_LIGHT, - FLAG_USE_DYNAMIC_GI = RS::INSTANCE_FLAG_USE_DYNAMIC_GI, - FLAG_DRAW_NEXT_FRAME_IF_VISIBLE = RS::INSTANCE_FLAG_DRAW_NEXT_FRAME_IF_VISIBLE, - FLAG_MAX = RS::INSTANCE_FLAG_MAX, - }; - enum ShadowCastingSetting { SHADOW_CASTING_SETTING_OFF = RS::SHADOW_CASTING_SETTING_OFF, SHADOW_CASTING_SETTING_ON = RS::SHADOW_CASTING_SETTING_ON, @@ -99,8 +92,21 @@ public: SHADOW_CASTING_SETTING_SHADOWS_ONLY = RS::SHADOW_CASTING_SETTING_SHADOWS_ONLY }; + enum GIMode { + GI_MODE_DISABLED, + GI_MODE_BAKED, + GI_MODE_DYNAMIC + }; + + enum LightmapScale { + LIGHTMAP_SCALE_1X, + LIGHTMAP_SCALE_2X, + LIGHTMAP_SCALE_4X, + LIGHTMAP_SCALE_8X, + LIGHTMAP_SCALE_MAX, + }; + private: - bool flags[FLAG_MAX]; ShadowCastingSetting shadow_casting_setting; Ref material_override; float lod_min_distance; @@ -112,6 +118,8 @@ private: mutable HashMap instance_uniform_property_remap; float extra_cull_margin; + LightmapScale lightmap_scale; + GIMode gi_mode; const StringName *_instance_uniform_get_remap(const StringName p_name) const; @@ -124,9 +132,6 @@ protected: static void _bind_methods(); public: - void set_flag(Flags p_flag, bool p_value); - bool get_flag(Flags p_flag) const; - void set_cast_shadows_setting(ShadowCastingSetting p_shadow_casting_setting); ShadowCastingSetting get_cast_shadows_setting() const; @@ -148,6 +153,12 @@ public: void set_extra_cull_margin(float p_margin); float get_extra_cull_margin() const; + void set_gi_mode(GIMode p_mode); + GIMode get_gi_mode() const; + + void set_lightmap_scale(LightmapScale p_scale); + LightmapScale get_lightmap_scale() const; + void set_shader_instance_uniform(const StringName &p_uniform, const Variant &p_value); Variant get_shader_instance_uniform(const StringName &p_uniform) const; @@ -156,7 +167,8 @@ public: GeometryInstance3D(); }; -VARIANT_ENUM_CAST(GeometryInstance3D::Flags); VARIANT_ENUM_CAST(GeometryInstance3D::ShadowCastingSetting); +VARIANT_ENUM_CAST(GeometryInstance3D::LightmapScale); +VARIANT_ENUM_CAST(GeometryInstance3D::GIMode); #endif diff --git a/scene/3d/voxelizer.cpp b/scene/3d/voxelizer.cpp index a2d305f3cb..f9c3810843 100644 --- a/scene/3d/voxelizer.cpp +++ b/scene/3d/voxelizer.cpp @@ -29,207 +29,12 @@ /*************************************************************************/ #include "voxelizer.h" +#include "core/math/geometry.h" #include "core/os/os.h" #include "core/os/threaded_array_processor.h" #include -#define FINDMINMAX(x0, x1, x2, min, max) \ - min = max = x0; \ - if (x1 < min) \ - min = x1; \ - if (x1 > max) \ - max = x1; \ - if (x2 < min) \ - min = x2; \ - if (x2 > max) \ - max = x2; - -static bool planeBoxOverlap(Vector3 normal, float d, Vector3 maxbox) { - int q; - Vector3 vmin, vmax; - for (q = 0; q <= 2; q++) { - if (normal[q] > 0.0f) { - vmin[q] = -maxbox[q]; - vmax[q] = maxbox[q]; - } else { - vmin[q] = maxbox[q]; - vmax[q] = -maxbox[q]; - } - } - if (normal.dot(vmin) + d > 0.0f) - return false; - if (normal.dot(vmax) + d >= 0.0f) - return true; - - return false; -} - -/*======================== X-tests ========================*/ -#define AXISTEST_X01(a, b, fa, fb) \ - p0 = a * v0.y - b * v0.z; \ - p2 = a * v2.y - b * v2.z; \ - if (p0 < p2) { \ - min = p0; \ - max = p2; \ - } else { \ - min = p2; \ - max = p0; \ - } \ - rad = fa * boxhalfsize.y + fb * boxhalfsize.z; \ - if (min > rad || max < -rad) \ - return false; - -#define AXISTEST_X2(a, b, fa, fb) \ - p0 = a * v0.y - b * v0.z; \ - p1 = a * v1.y - b * v1.z; \ - if (p0 < p1) { \ - min = p0; \ - max = p1; \ - } else { \ - min = p1; \ - max = p0; \ - } \ - rad = fa * boxhalfsize.y + fb * boxhalfsize.z; \ - if (min > rad || max < -rad) \ - return false; - -/*======================== Y-tests ========================*/ -#define AXISTEST_Y02(a, b, fa, fb) \ - p0 = -a * v0.x + b * v0.z; \ - p2 = -a * v2.x + b * v2.z; \ - if (p0 < p2) { \ - min = p0; \ - max = p2; \ - } else { \ - min = p2; \ - max = p0; \ - } \ - rad = fa * boxhalfsize.x + fb * boxhalfsize.z; \ - if (min > rad || max < -rad) \ - return false; - -#define AXISTEST_Y1(a, b, fa, fb) \ - p0 = -a * v0.x + b * v0.z; \ - p1 = -a * v1.x + b * v1.z; \ - if (p0 < p1) { \ - min = p0; \ - max = p1; \ - } else { \ - min = p1; \ - max = p0; \ - } \ - rad = fa * boxhalfsize.x + fb * boxhalfsize.z; \ - if (min > rad || max < -rad) \ - return false; - -/*======================== Z-tests ========================*/ - -#define AXISTEST_Z12(a, b, fa, fb) \ - p1 = a * v1.x - b * v1.y; \ - p2 = a * v2.x - b * v2.y; \ - if (p2 < p1) { \ - min = p2; \ - max = p1; \ - } else { \ - min = p1; \ - max = p2; \ - } \ - rad = fa * boxhalfsize.x + fb * boxhalfsize.y; \ - if (min > rad || max < -rad) \ - return false; - -#define AXISTEST_Z0(a, b, fa, fb) \ - p0 = a * v0.x - b * v0.y; \ - p1 = a * v1.x - b * v1.y; \ - if (p0 < p1) { \ - min = p0; \ - max = p1; \ - } else { \ - min = p1; \ - max = p0; \ - } \ - rad = fa * boxhalfsize.x + fb * boxhalfsize.y; \ - if (min > rad || max < -rad) \ - return false; - -static bool fast_tri_box_overlap(const Vector3 &boxcenter, const Vector3 boxhalfsize, const Vector3 *triverts) { - - /* use separating axis theorem to test overlap between triangle and box */ - /* need to test for overlap in these directions: */ - /* 1) the {x,y,z}-directions (actually, since we use the AABB of the triangle */ - /* we do not even need to test these) */ - /* 2) normal of the triangle */ - /* 3) crossproduct(edge from tri, {x,y,z}-directin) */ - /* this gives 3x3=9 more tests */ - Vector3 v0, v1, v2; - float min, max, d, p0, p1, p2, rad, fex, fey, fez; - Vector3 normal, e0, e1, e2; - - /* This is the fastest branch on Sun */ - /* move everything so that the boxcenter is in (0,0,0) */ - - v0 = triverts[0] - boxcenter; - v1 = triverts[1] - boxcenter; - v2 = triverts[2] - boxcenter; - - /* compute triangle edges */ - e0 = v1 - v0; /* tri edge 0 */ - e1 = v2 - v1; /* tri edge 1 */ - e2 = v0 - v2; /* tri edge 2 */ - - /* Bullet 3: */ - /* test the 9 tests first (this was faster) */ - fex = Math::abs(e0.x); - fey = Math::abs(e0.y); - fez = Math::abs(e0.z); - AXISTEST_X01(e0.z, e0.y, fez, fey); - AXISTEST_Y02(e0.z, e0.x, fez, fex); - AXISTEST_Z12(e0.y, e0.x, fey, fex); - - fex = Math::abs(e1.x); - fey = Math::abs(e1.y); - fez = Math::abs(e1.z); - AXISTEST_X01(e1.z, e1.y, fez, fey); - AXISTEST_Y02(e1.z, e1.x, fez, fex); - AXISTEST_Z0(e1.y, e1.x, fey, fex); - - fex = Math::abs(e2.x); - fey = Math::abs(e2.y); - fez = Math::abs(e2.z); - AXISTEST_X2(e2.z, e2.y, fez, fey); - AXISTEST_Y1(e2.z, e2.x, fez, fex); - AXISTEST_Z12(e2.y, e2.x, fey, fex); - - /* Bullet 1: */ - /* first test overlap in the {x,y,z}-directions */ - /* find min, max of the triangle each direction, and test for overlap in */ - /* that direction -- this is equivalent to testing a minimal AABB around */ - /* the triangle against the AABB */ - - /* test in X-direction */ - FINDMINMAX(v0.x, v1.x, v2.x, min, max); - if (min > boxhalfsize.x || max < -boxhalfsize.x) - return false; - - /* test in Y-direction */ - FINDMINMAX(v0.y, v1.y, v2.y, min, max); - if (min > boxhalfsize.y || max < -boxhalfsize.y) - return false; - - /* test in Z-direction */ - FINDMINMAX(v0.z, v1.z, v2.z, min, max); - if (min > boxhalfsize.z || max < -boxhalfsize.z) - return false; - - /* Bullet 2: */ - /* test if the box intersects the plane of the triangle */ - /* compute plane equation of triangle: normal*x+d=0 */ - normal = e0.cross(e1); - d = -normal.dot(v0); /* plane eq: normal.x+d=0 */ - return planeBoxOverlap(normal, d, boxhalfsize); /* if true, box and triangle overlaps */ -} - static _FORCE_INLINE_ void get_uv_and_normal(const Vector3 &p_pos, const Vector3 *p_vtx, const Vector2 *p_uv, const Vector3 *p_normal, Vector2 &r_uv, Vector3 &r_normal) { if (p_pos.distance_squared_to(p_vtx[0]) < CMP_EPSILON2) { @@ -324,7 +129,7 @@ void Voxelizer::_plot_face(int p_idx, int p_level, int p_x, int p_y, int p_z, co Vector3 half = (to - from) * 0.5; //is in this cell? - if (!fast_tri_box_overlap(from + half, half, p_vtx)) { + if (!Geometry::triangle_box_overlap(from + half, half, p_vtx)) { continue; //face does not span this cell } @@ -467,7 +272,7 @@ void Voxelizer::_plot_face(int p_idx, int p_level, int p_x, int p_y, int p_z, co //test_aabb.grow_by(test_aabb.get_longest_axis_size()*0.05); //grow a bit to avoid numerical error in real-time Vector3 qsize = test_aabb.size * 0.5; //quarter size, for fast aabb test - if (!fast_tri_box_overlap(test_aabb.position + qsize, qsize, p_vtx)) { + if (!Geometry::triangle_box_overlap(test_aabb.position + qsize, qsize, p_vtx)) { //if (!Face3(p_vtx[0],p_vtx[1],p_vtx[2]).intersects_aabb2(aabb)) { //does not fit in child, go on continue; @@ -648,7 +453,7 @@ void Voxelizer::plot_mesh(const Transform &p_xform, Ref &p_mesh, const Vec } //test against original bounds - if (!fast_tri_box_overlap(original_bounds.position + original_bounds.size * 0.5, original_bounds.size * 0.5, vtxs)) + if (!Geometry::triangle_box_overlap(original_bounds.position + original_bounds.size * 0.5, original_bounds.size * 0.5, vtxs)) continue; //plot _plot_face(0, 0, 0, 0, 0, vtxs, normal, uvs, material, po2_bounds); @@ -681,7 +486,7 @@ void Voxelizer::plot_mesh(const Transform &p_xform, Ref &p_mesh, const Vec } //test against original bounds - if (!fast_tri_box_overlap(original_bounds.position + original_bounds.size * 0.5, original_bounds.size * 0.5, vtxs)) + if (!Geometry::triangle_box_overlap(original_bounds.position + original_bounds.size * 0.5, original_bounds.size * 0.5, vtxs)) continue; //plot face _plot_face(0, 0, 0, 0, 0, vtxs, normal, uvs, material, po2_bounds); diff --git a/scene/register_scene_types.cpp b/scene/register_scene_types.cpp index dc3ef5b508..684df728b8 100644 --- a/scene/register_scene_types.cpp +++ b/scene/register_scene_types.cpp @@ -193,6 +193,7 @@ #include "scene/3d/gpu_particles_3d.h" #include "scene/3d/immediate_geometry_3d.h" #include "scene/3d/light_3d.h" +#include "scene/3d/lightmap_probe.h" #include "scene/3d/listener_3d.h" #include "scene/3d/mesh_instance_3d.h" #include "scene/3d/multimesh_instance_3d.h" @@ -225,8 +226,8 @@ static Ref resource_loader_text; static Ref resource_loader_dynamic_font; -static Ref resource_loader_stream_texture; -static Ref resource_loader_texture_layered; +static Ref resource_loader_stream_texture; +static Ref resource_loader_texture_layered; static Ref resource_loader_bmfont; @@ -432,8 +433,9 @@ void register_scene_types() { ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); - //ClassDB::register_class(); - //ClassDB::register_class(); + ClassDB::register_class(); + ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); @@ -675,7 +677,7 @@ void register_scene_types() { ClassDB::register_virtual_class(); ClassDB::register_virtual_class(); ClassDB::register_class(); - ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); @@ -689,6 +691,11 @@ void register_scene_types() { ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_virtual_class(); + ClassDB::register_class(); + ClassDB::register_class(); + ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_virtual_class(); ClassDB::register_class(); @@ -867,6 +874,7 @@ void register_scene_types() { ClassDB::add_compatibility_class("VisualShaderNodeScalarOp", "VisualShaderNodeFloatOp"); ClassDB::add_compatibility_class("VisualShaderNodeScalarUniform", "VisualShaderNodeFloatUniform"); ClassDB::add_compatibility_class("World", "World3D"); + ClassDB::add_compatibility_class("StreamTexture", "StreamTexture2D"); #endif diff --git a/scene/resources/mesh.cpp b/scene/resources/mesh.cpp index 401b689145..6548c65cd7 100644 --- a/scene/resources/mesh.cpp +++ b/scene/resources/mesh.cpp @@ -472,11 +472,11 @@ Ref Mesh::create_outline(float p_margin) const { return newmesh; } -void Mesh::set_lightmap_size_hint(const Vector2 &p_size) { +void Mesh::set_lightmap_size_hint(const Size2i &p_size) { lightmap_size_hint = p_size; } -Size2 Mesh::get_lightmap_size_hint() const { +Size2i Mesh::get_lightmap_size_hint() const { return lightmap_size_hint; } @@ -486,7 +486,7 @@ void Mesh::_bind_methods() { ClassDB::bind_method(D_METHOD("get_lightmap_size_hint"), &Mesh::get_lightmap_size_hint); ClassDB::bind_method(D_METHOD("get_aabb"), &Mesh::get_aabb); - ADD_PROPERTY(PropertyInfo(Variant::VECTOR2, "lightmap_size_hint"), "set_lightmap_size_hint", "get_lightmap_size_hint"); + ADD_PROPERTY(PropertyInfo(Variant::VECTOR2I, "lightmap_size_hint"), "set_lightmap_size_hint", "get_lightmap_size_hint"); ClassDB::bind_method(D_METHOD("get_surface_count"), &Mesh::get_surface_count); ClassDB::bind_method(D_METHOD("surface_get_arrays", "surf_idx"), &Mesh::surface_get_arrays); diff --git a/scene/resources/mesh.h b/scene/resources/mesh.h index a65cf0a928..80cd57846b 100644 --- a/scene/resources/mesh.h +++ b/scene/resources/mesh.h @@ -43,7 +43,7 @@ class Mesh : public Resource { mutable Ref triangle_mesh; //cached mutable Vector debug_lines; - Size2 lightmap_size_hint; + Size2i lightmap_size_hint; protected: static void _bind_methods(); @@ -138,8 +138,8 @@ public: virtual AABB get_aabb() const = 0; - void set_lightmap_size_hint(const Vector2 &p_size); - Size2 get_lightmap_size_hint() const; + void set_lightmap_size_hint(const Size2i &p_size); + Size2i get_lightmap_size_hint() const; void clear_cache() const; typedef Vector> (*ConvexDecompositionFunc)(const Vector &); diff --git a/scene/resources/texture.cpp b/scene/resources/texture.cpp index 91c40d871d..6e155ddf91 100644 --- a/scene/resources/texture.cpp +++ b/scene/resources/texture.cpp @@ -355,7 +355,7 @@ ImageTexture::~ImageTexture() { ////////////////////////////////////////// -Ref StreamTexture::load_image_from_file(FileAccess *f, int p_size_limit) { +Ref StreamTexture2D::load_image_from_file(FileAccess *f, int p_size_limit) { uint32_t data_format = f->get_32(); uint32_t w = f->get_16(); @@ -492,7 +492,7 @@ Ref StreamTexture::load_image_from_file(FileAccess *f, int p_size_limit) return Ref(); } -void StreamTexture::set_path(const String &p_path, bool p_take_over) { +void StreamTexture2D::set_path(const String &p_path, bool p_take_over) { if (texture.is_valid()) { RenderingServer::get_singleton()->texture_set_path(texture, p_path); @@ -501,40 +501,40 @@ void StreamTexture::set_path(const String &p_path, bool p_take_over) { Resource::set_path(p_path, p_take_over); } -void StreamTexture::_requested_3d(void *p_ud) { +void StreamTexture2D::_requested_3d(void *p_ud) { - StreamTexture *st = (StreamTexture *)p_ud; - Ref stex(st); + StreamTexture2D *st = (StreamTexture2D *)p_ud; + Ref stex(st); ERR_FAIL_COND(!request_3d_callback); request_3d_callback(stex); } -void StreamTexture::_requested_roughness(void *p_ud, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_roughness_channel) { +void StreamTexture2D::_requested_roughness(void *p_ud, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_roughness_channel) { - StreamTexture *st = (StreamTexture *)p_ud; - Ref stex(st); + StreamTexture2D *st = (StreamTexture2D *)p_ud; + Ref stex(st); ERR_FAIL_COND(!request_roughness_callback); request_roughness_callback(stex, p_normal_path, p_roughness_channel); } -void StreamTexture::_requested_normal(void *p_ud) { +void StreamTexture2D::_requested_normal(void *p_ud) { - StreamTexture *st = (StreamTexture *)p_ud; - Ref stex(st); + StreamTexture2D *st = (StreamTexture2D *)p_ud; + Ref stex(st); ERR_FAIL_COND(!request_normal_callback); request_normal_callback(stex); } -StreamTexture::TextureFormatRequestCallback StreamTexture::request_3d_callback = nullptr; -StreamTexture::TextureFormatRoughnessRequestCallback StreamTexture::request_roughness_callback = nullptr; -StreamTexture::TextureFormatRequestCallback StreamTexture::request_normal_callback = nullptr; +StreamTexture2D::TextureFormatRequestCallback StreamTexture2D::request_3d_callback = nullptr; +StreamTexture2D::TextureFormatRoughnessRequestCallback StreamTexture2D::request_roughness_callback = nullptr; +StreamTexture2D::TextureFormatRequestCallback StreamTexture2D::request_normal_callback = nullptr; -Image::Format StreamTexture::get_format() const { +Image::Format StreamTexture2D::get_format() const { return format; } -Error StreamTexture::_load_data(const String &p_path, int &tw, int &th, int &tw_custom, int &th_custom, Ref &image, bool &r_request_3d, bool &r_request_normal, bool &r_request_roughness, int &mipmap_limit, int p_size_limit) { +Error StreamTexture2D::_load_data(const String &p_path, int &tw, int &th, int &tw_custom, int &th_custom, Ref &image, bool &r_request_3d, bool &r_request_normal, bool &r_request_roughness, int &mipmap_limit, int p_size_limit) { alpha_cache.unref(); @@ -595,7 +595,7 @@ Error StreamTexture::_load_data(const String &p_path, int &tw, int &th, int &tw_ return OK; } -Error StreamTexture::load(const String &p_path) { +Error StreamTexture2D::load(const String &p_path) { int lw, lh, lwc, lhc; Ref image; @@ -661,20 +661,20 @@ Error StreamTexture::load(const String &p_path) { emit_changed(); return OK; } -String StreamTexture::get_load_path() const { +String StreamTexture2D::get_load_path() const { return path_to_file; } -int StreamTexture::get_width() const { +int StreamTexture2D::get_width() const { return w; } -int StreamTexture::get_height() const { +int StreamTexture2D::get_height() const { return h; } -RID StreamTexture::get_rid() const { +RID StreamTexture2D::get_rid() const { if (!texture.is_valid()) { texture = RS::get_singleton()->texture_2d_placeholder_create(); @@ -682,7 +682,7 @@ RID StreamTexture::get_rid() const { return texture; } -void StreamTexture::draw(RID p_canvas_item, const Point2 &p_pos, const Color &p_modulate, bool p_transpose, const Ref &p_normal_map, const Ref &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat) const { +void StreamTexture2D::draw(RID p_canvas_item, const Point2 &p_pos, const Color &p_modulate, bool p_transpose, const Ref &p_normal_map, const Ref &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat) const { if ((w | h) == 0) return; @@ -690,7 +690,7 @@ void StreamTexture::draw(RID p_canvas_item, const Point2 &p_pos, const Color &p_ RID specular_rid = p_specular_map.is_valid() ? p_specular_map->get_rid() : RID(); RenderingServer::get_singleton()->canvas_item_add_texture_rect(p_canvas_item, Rect2(p_pos, Size2(w, h)), texture, false, p_modulate, p_transpose, normal_rid, specular_rid, p_specular_color_shininess, p_texture_filter, p_texture_repeat); } -void StreamTexture::draw_rect(RID p_canvas_item, const Rect2 &p_rect, bool p_tile, const Color &p_modulate, bool p_transpose, const Ref &p_normal_map, const Ref &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat) const { +void StreamTexture2D::draw_rect(RID p_canvas_item, const Rect2 &p_rect, bool p_tile, const Color &p_modulate, bool p_transpose, const Ref &p_normal_map, const Ref &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat) const { if ((w | h) == 0) return; @@ -698,7 +698,7 @@ void StreamTexture::draw_rect(RID p_canvas_item, const Rect2 &p_rect, bool p_til RID specular_rid = p_specular_map.is_valid() ? p_specular_map->get_rid() : RID(); RenderingServer::get_singleton()->canvas_item_add_texture_rect(p_canvas_item, p_rect, texture, p_tile, p_modulate, p_transpose, normal_rid, specular_rid, p_specular_color_shininess, p_texture_filter, p_texture_repeat); } -void StreamTexture::draw_rect_region(RID p_canvas_item, const Rect2 &p_rect, const Rect2 &p_src_rect, const Color &p_modulate, bool p_transpose, const Ref &p_normal_map, const Ref &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat, bool p_clip_uv) const { +void StreamTexture2D::draw_rect_region(RID p_canvas_item, const Rect2 &p_rect, const Rect2 &p_src_rect, const Color &p_modulate, bool p_transpose, const Ref &p_normal_map, const Ref &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat, bool p_clip_uv) const { if ((w | h) == 0) return; @@ -707,12 +707,12 @@ void StreamTexture::draw_rect_region(RID p_canvas_item, const Rect2 &p_rect, con RenderingServer::get_singleton()->canvas_item_add_texture_rect_region(p_canvas_item, p_rect, texture, p_src_rect, p_modulate, p_transpose, normal_rid, specular_rid, p_specular_color_shininess, p_clip_uv, p_texture_filter, p_texture_repeat); } -bool StreamTexture::has_alpha() const { +bool StreamTexture2D::has_alpha() const { return false; } -Ref StreamTexture::get_data() const { +Ref StreamTexture2D::get_data() const { if (texture.is_valid()) { return RS::get_singleton()->texture_2d_get(texture); @@ -721,7 +721,7 @@ Ref StreamTexture::get_data() const { } } -bool StreamTexture::is_pixel_opaque(int p_x, int p_y) const { +bool StreamTexture2D::is_pixel_opaque(int p_x, int p_y) const { if (!alpha_cache.is_valid()) { Ref img = get_data(); @@ -757,7 +757,7 @@ bool StreamTexture::is_pixel_opaque(int p_x, int p_y) const { return true; } -void StreamTexture::reload_from_file() { +void StreamTexture2D::reload_from_file() { String path = get_path(); if (!path.is_resource_file()) @@ -771,34 +771,34 @@ void StreamTexture::reload_from_file() { load(path); } -void StreamTexture::_validate_property(PropertyInfo &property) const { +void StreamTexture2D::_validate_property(PropertyInfo &property) const { } -void StreamTexture::_bind_methods() { +void StreamTexture2D::_bind_methods() { - ClassDB::bind_method(D_METHOD("load", "path"), &StreamTexture::load); - ClassDB::bind_method(D_METHOD("get_load_path"), &StreamTexture::get_load_path); + ClassDB::bind_method(D_METHOD("load", "path"), &StreamTexture2D::load); + ClassDB::bind_method(D_METHOD("get_load_path"), &StreamTexture2D::get_load_path); ADD_PROPERTY(PropertyInfo(Variant::STRING, "load_path", PROPERTY_HINT_FILE, "*.stex"), "load", "get_load_path"); } -StreamTexture::StreamTexture() { +StreamTexture2D::StreamTexture2D() { format = Image::FORMAT_MAX; w = 0; h = 0; } -StreamTexture::~StreamTexture() { +StreamTexture2D::~StreamTexture2D() { if (texture.is_valid()) { RS::get_singleton()->free(texture); } } -RES ResourceFormatLoaderStreamTexture::load(const String &p_path, const String &p_original_path, Error *r_error, bool p_use_sub_threads, float *r_progress, bool p_no_cache) { +RES ResourceFormatLoaderStreamTexture2D::load(const String &p_path, const String &p_original_path, Error *r_error, bool p_use_sub_threads, float *r_progress, bool p_no_cache) { - Ref st; + Ref st; st.instance(); Error err = st->load(p_path); if (r_error) @@ -809,17 +809,17 @@ RES ResourceFormatLoaderStreamTexture::load(const String &p_path, const String & return st; } -void ResourceFormatLoaderStreamTexture::get_recognized_extensions(List *p_extensions) const { +void ResourceFormatLoaderStreamTexture2D::get_recognized_extensions(List *p_extensions) const { p_extensions->push_back("stex"); } -bool ResourceFormatLoaderStreamTexture::handles_type(const String &p_type) const { - return p_type == "StreamTexture"; +bool ResourceFormatLoaderStreamTexture2D::handles_type(const String &p_type) const { + return p_type == "StreamTexture2D"; } -String ResourceFormatLoaderStreamTexture::get_resource_type(const String &p_path) const { +String ResourceFormatLoaderStreamTexture2D::get_resource_type(const String &p_path) const { if (p_path.get_extension().to_lower() == "stex") - return "StreamTexture"; + return "StreamTexture2D"; return ""; } @@ -1930,23 +1930,47 @@ AnimatedTexture::~AnimatedTexture() { } /////////////////////////////// -Image::Format TextureLayered::get_format() const { +void TextureLayered::_bind_methods() { + + ClassDB::bind_method(D_METHOD("get_format"), &TextureLayered::get_format); + ClassDB::bind_method(D_METHOD("get_layered_type"), &TextureLayered::get_layered_type); + ClassDB::bind_method(D_METHOD("get_width"), &TextureLayered::get_width); + ClassDB::bind_method(D_METHOD("get_height"), &TextureLayered::get_height); + ClassDB::bind_method(D_METHOD("get_layers"), &TextureLayered::get_layers); + ClassDB::bind_method(D_METHOD("has_mipmaps"), &TextureLayered::has_mipmaps); + ClassDB::bind_method(D_METHOD("get_layer_data"), &TextureLayered::get_layer_data); + + BIND_ENUM_CONSTANT(LAYERED_TYPE_2D_ARRAY); + BIND_ENUM_CONSTANT(LAYERED_TYPE_CUBEMAP); + BIND_ENUM_CONSTANT(LAYERED_TYPE_CUBEMAP_ARRAY); +} + +/////////////////////////////// +Image::Format ImageTextureLayered::get_format() const { return format; } -uint32_t TextureLayered::get_width() const { +int ImageTextureLayered::get_width() const { return width; } -uint32_t TextureLayered::get_height() const { +int ImageTextureLayered::get_height() const { return height; } -uint32_t TextureLayered::get_layers() const { +int ImageTextureLayered::get_layers() const { return layers; } -Error TextureLayered::_create_from_images(const Array &p_images) { +bool ImageTextureLayered::has_mipmaps() const { + return mipmaps; +} + +ImageTextureLayered::LayeredType ImageTextureLayered::get_layered_type() const { + return layered_type; +} + +Error ImageTextureLayered::_create_from_images(const Array &p_images) { Vector> images; for (int i = 0; i < p_images.size(); i++) { Ref img = p_images[i]; @@ -1957,7 +1981,7 @@ Error TextureLayered::_create_from_images(const Array &p_images) { return create_from_images(images); } -Array TextureLayered::_get_images() const { +Array ImageTextureLayered::_get_images() const { Array images; for (int i = 0; i < layers; i++) { images.push_back(get_layer_data(i)); @@ -1965,14 +1989,14 @@ Array TextureLayered::_get_images() const { return images; } -Error TextureLayered::create_from_images(Vector> p_images) { +Error ImageTextureLayered::create_from_images(Vector> p_images) { int new_layers = p_images.size(); ERR_FAIL_COND_V(new_layers == 0, ERR_INVALID_PARAMETER); - if (layered_type == RS::TEXTURE_LAYERED_CUBEMAP) { + if (layered_type == LAYERED_TYPE_CUBEMAP) { ERR_FAIL_COND_V_MSG(new_layers != 6, ERR_INVALID_PARAMETER, "Cubemaps require exactly 6 layers"); - } else if (layered_type == RS::TEXTURE_LAYERED_CUBEMAP_ARRAY) { + } else if (layered_type == LAYERED_TYPE_CUBEMAP_ARRAY) { ERR_FAIL_COND_V_MSG((new_layers % 6) != 0, ERR_INVALID_PARAMETER, "Cubemap array layers must be a multiple of 6"); } @@ -1994,11 +2018,11 @@ Error TextureLayered::create_from_images(Vector> p_images) { } if (texture.is_valid()) { - RID new_texture = RS::get_singleton()->texture_2d_layered_create(p_images, layered_type); + RID new_texture = RS::get_singleton()->texture_2d_layered_create(p_images, RS::TextureLayeredType(layered_type)); ERR_FAIL_COND_V(!new_texture.is_valid(), ERR_CANT_CREATE); RS::get_singleton()->texture_replace(texture, new_texture); } else { - texture = RS::get_singleton()->texture_2d_layered_create(p_images, layered_type); + texture = RS::get_singleton()->texture_2d_layered_create(p_images, RS::TextureLayeredType(layered_type)); ERR_FAIL_COND_V(!texture.is_valid(), ERR_CANT_CREATE); } @@ -2010,7 +2034,7 @@ Error TextureLayered::create_from_images(Vector> p_images) { return OK; } -void TextureLayered::update_layer(const Ref &p_image, int p_layer) { +void ImageTextureLayered::update_layer(const Ref &p_image, int p_layer) { ERR_FAIL_COND(texture.is_valid()); ERR_FAIL_COND(p_image.is_null()); ERR_FAIL_COND(p_image->get_format() != format); @@ -2020,19 +2044,19 @@ void TextureLayered::update_layer(const Ref &p_image, int p_layer) { RS::get_singleton()->texture_2d_update(texture, p_image, p_layer); } -Ref TextureLayered::get_layer_data(int p_layer) const { +Ref ImageTextureLayered::get_layer_data(int p_layer) const { ERR_FAIL_INDEX_V(p_layer, layers, Ref()); return RS::get_singleton()->texture_2d_layer_get(texture, p_layer); } -RID TextureLayered::get_rid() const { +RID ImageTextureLayered::get_rid() const { if (texture.is_null()) { - texture = RS::get_singleton()->texture_2d_layered_placeholder_create(); + texture = RS::get_singleton()->texture_2d_layered_placeholder_create(RS::TextureLayeredType(layered_type)); } return texture; } -void TextureLayered::set_path(const String &p_path, bool p_take_over) { +void ImageTextureLayered::set_path(const String &p_path, bool p_take_over) { if (texture.is_valid()) { RS::get_singleton()->texture_set_path(texture, p_path); } @@ -2040,24 +2064,17 @@ void TextureLayered::set_path(const String &p_path, bool p_take_over) { Resource::set_path(p_path, p_take_over); } -void TextureLayered::_bind_methods() { +void ImageTextureLayered::_bind_methods() { - ClassDB::bind_method(D_METHOD("get_format"), &TextureLayered::get_format); + ClassDB::bind_method(D_METHOD("create_from_images", "images"), &ImageTextureLayered::_create_from_images); + ClassDB::bind_method(D_METHOD("update_layer", "image", "layer"), &ImageTextureLayered::update_layer); - ClassDB::bind_method(D_METHOD("get_width"), &TextureLayered::get_width); - ClassDB::bind_method(D_METHOD("get_height"), &TextureLayered::get_height); - ClassDB::bind_method(D_METHOD("get_layers"), &TextureLayered::get_layers); - - ClassDB::bind_method(D_METHOD("create_from_images", "images"), &TextureLayered::_create_from_images); - ClassDB::bind_method(D_METHOD("update_layer", "image", "layer"), &TextureLayered::update_layer); - ClassDB::bind_method(D_METHOD("get_layer_data", "layer"), &TextureLayered::get_layer_data); - - ClassDB::bind_method(D_METHOD("_get_images"), &TextureLayered::_get_images); + ClassDB::bind_method(D_METHOD("_get_images"), &ImageTextureLayered::_get_images); ADD_PROPERTY(PropertyInfo(Variant::ARRAY, "_images", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_INTERNAL), "create_from_images", "_get_images"); } -TextureLayered::TextureLayered(RenderingServer::TextureLayeredType p_layered_type) { +ImageTextureLayered::ImageTextureLayered(LayeredType p_layered_type) { layered_type = p_layered_type; format = Image::FORMAT_MAX; @@ -2066,193 +2083,241 @@ TextureLayered::TextureLayered(RenderingServer::TextureLayeredType p_layered_typ layers = 0; } -TextureLayered::~TextureLayered() { +ImageTextureLayered::~ImageTextureLayered() { if (texture.is_valid()) { RS::get_singleton()->free(texture); } } -RES ResourceFormatLoaderTextureLayered::load(const String &p_path, const String &p_original_path, Error *r_error, bool p_use_sub_threads, float *r_progress, bool p_no_cache) { +/////////////////////////////////////////// - if (r_error) { - *r_error = ERR_CANT_OPEN; +void StreamTextureLayered::set_path(const String &p_path, bool p_take_over) { + + if (texture.is_valid()) { + RenderingServer::get_singleton()->texture_set_path(texture, p_path); } - Ref lt; + Resource::set_path(p_path, p_take_over); +} - if (p_path.ends_with("cube")) { - Ref cube; - cube.instance(); - lt = cube; - } else if (p_path.ends_with("cubearr")) { - Ref cubearr; - cubearr.instance(); - lt = cubearr; - } else if (p_path.ends_with("tex2darr")) { - Ref t2darr; - t2darr.instance(); - lt = t2darr; - } else { - ERR_FAIL_V_MSG(RES(), "Unrecognized layered texture extension."); +Image::Format StreamTextureLayered::get_format() const { + + return format; +} + +Error StreamTextureLayered::_load_data(const String &p_path, Vector> &images, int &mipmap_limit, int p_size_limit) { + + ERR_FAIL_COND_V(images.size() != 0, ERR_INVALID_PARAMETER); + + FileAccessRef f = FileAccess::open(p_path, FileAccess::READ); + ERR_FAIL_COND_V(!f, ERR_CANT_OPEN); + + uint8_t header[4]; + f->get_buffer(header, 4); + if (header[0] != 'G' || header[1] != 'S' || header[2] != 'T' || header[3] != 'L') { + ERR_FAIL_V_MSG(ERR_FILE_CORRUPT, "Stream texture layered file is corrupt (Bad header)."); } - FileAccess *f = FileAccess::open(p_path, FileAccess::READ); - ERR_FAIL_COND_V_MSG(!f, RES(), "Cannot open file '" + p_path + "'."); + uint32_t version = f->get_32(); - char header[5] = { 0, 0, 0, 0, 0 }; - f->get_buffer((uint8_t *)header, 4); - - if (String(header) != "GDLT") { - f->close(); - memdelete(f); - if (r_error) { - *r_error = ERR_FILE_CORRUPT; - } - // FIXME: It's bogus that we fail in both branches. Seen while rebasing - // vulkan branch on master branch. - ERR_FAIL_V_MSG(RES(), "Unrecognized layered texture."); - } else { - - f->close(); - memdelete(f); - ERR_FAIL_V_MSG(RES(), "Unrecognized layered texture file format '" + String((const char *)header) + "'."); + if (version > FORMAT_VERSION) { + ERR_FAIL_V_MSG(ERR_FILE_CORRUPT, "Stream texture file is too new."); } - int tw = f->get_32(); - int th = f->get_32(); - int td = f->get_32(); - bool use_mipmaps = f->get_32() != 0; //texture flags (deprecated) - Image::Format format = Image::Format(f->get_32()); - uint32_t compression = f->get_32(); // 0 - lossless (PNG), 1 - vram, 2 - uncompressed + uint32_t layer_count = f->get_32(); //layer count + uint32_t type = f->get_32(); //layer count + ERR_FAIL_COND_V(type != layered_type, ERR_INVALID_DATA); + + uint32_t df = f->get_32(); //data format + mipmap_limit = int(f->get_32()); + //reserverd + f->get_32(); + f->get_32(); + f->get_32(); + + if (!(df & FORMAT_BIT_STREAM)) { + p_size_limit = 0; + } + + images.resize(layer_count); + + for (uint32_t i = 0; i < layer_count; i++) { + Ref image = StreamTexture2D::load_image_from_file(f, p_size_limit); + ERR_FAIL_COND_V(image.is_null() || image->empty(), ERR_CANT_OPEN); + images.write[i] = image; + } + + return OK; +} + +Error StreamTextureLayered::load(const String &p_path) { Vector> images; - for (int layer = 0; layer < td; layer++) { - Ref image; - image.instance(); + int mipmap_limit; - if (compression == COMPRESSION_LOSSLESS) { - //look for a PNG file inside + Error err = _load_data(p_path, images, mipmap_limit); + if (err) + return err; - int mipmaps = f->get_32(); - Vector> mipmap_images; - - for (int i = 0; i < mipmaps; i++) { - uint32_t size = f->get_32(); - - Vector pv; - pv.resize(size); - { - uint8_t *w = pv.ptrw(); - f->get_buffer(w, size); - } - - Ref img = Image::lossless_unpacker(pv); - - if (img.is_null() || img->empty() || format != img->get_format()) { - if (r_error) { - *r_error = ERR_FILE_CORRUPT; - } - f->close(); - memdelete(f); - ERR_FAIL_V(RES()); - } - - mipmap_images.push_back(img); - } - - if (mipmap_images.size() == 1) { - - image = mipmap_images[0]; - - } else { - int total_size = Image::get_image_data_size(tw, th, format, true); - Vector img_data; - img_data.resize(total_size); - - { - uint8_t *w = img_data.ptrw(); - - int ofs = 0; - for (int i = 0; i < mipmap_images.size(); i++) { - - Vector id = mipmap_images[i]->get_data(); - int len = id.size(); - const uint8_t *r = id.ptr(); - copymem(&w[ofs], r, len); - ofs += len; - } - } - - image->create(tw, th, true, format, img_data); - if (image->empty()) { - if (r_error) { - *r_error = ERR_FILE_CORRUPT; - } - f->close(); - memdelete(f); - ERR_FAIL_V(RES()); - } - } - - } else { - - //look for regular format - - int total_size = Image::get_image_data_size(tw, th, format, use_mipmaps); - - Vector img_data; - img_data.resize(total_size); - - { - uint8_t *w = img_data.ptrw(); - int bytes = f->get_buffer(w, total_size); - if (bytes != total_size) { - if (r_error) { - *r_error = ERR_FILE_CORRUPT; - } - f->close(); - memdelete(f); - ERR_FAIL_V(RES()); - } - } - - image->create(tw, th, use_mipmaps, format, img_data); - } - - images.push_back(image); - } - - Error err = lt->create_from_images(images); - if (err != OK) { - *r_error = err; - return RES(); + if (texture.is_valid()) { + RID new_texture = RS::get_singleton()->texture_2d_layered_create(images, RS::TextureLayeredType(layered_type)); + RS::get_singleton()->texture_replace(texture, new_texture); } else { - - if (r_error) - *r_error = OK; + texture = RS::get_singleton()->texture_2d_layered_create(images, RS::TextureLayeredType(layered_type)); } - return lt; + w = images[0]->get_width(); + h = images[0]->get_height(); + mipmaps = images[0]->has_mipmaps(); + format = images[0]->get_format(); + layers = images.size(); + + path_to_file = p_path; + + if (get_path() == String()) { + //temporarily set path if no path set for resource, helps find errors + RenderingServer::get_singleton()->texture_set_path(texture, p_path); + } + + _change_notify(); + emit_changed(); + return OK; +} +String StreamTextureLayered::get_load_path() const { + + return path_to_file; } -void ResourceFormatLoaderTextureLayered::get_recognized_extensions(List *p_extensions) const { +int StreamTextureLayered::get_width() const { - p_extensions->push_back("cube"); - p_extensions->push_back("cubearr"); - p_extensions->push_back("tex2darr"); + return w; } -bool ResourceFormatLoaderTextureLayered::handles_type(const String &p_type) const { - return p_type == "Texture2DArray" || p_type == "Cubemap" || p_type == "CubemapArray"; -} -String ResourceFormatLoaderTextureLayered::get_resource_type(const String &p_path) const { +int StreamTextureLayered::get_height() const { - if (p_path.get_extension().to_lower() == "cube") - return "Cubemap"; - if (p_path.get_extension().to_lower() == "cubearr") - return "CubemapArray"; - if (p_path.get_extension().to_lower() == "tex2darr") - return "Texture2DArray"; + return h; +} +int StreamTextureLayered::get_layers() const { + + return layers; +} +bool StreamTextureLayered::has_mipmaps() const { + return mipmaps; +} + +TextureLayered::LayeredType StreamTextureLayered::get_layered_type() const { + + return layered_type; +} + +RID StreamTextureLayered::get_rid() const { + + if (!texture.is_valid()) { + texture = RS::get_singleton()->texture_2d_layered_placeholder_create(RS::TextureLayeredType(layered_type)); + } + return texture; +} + +Ref StreamTextureLayered::get_layer_data(int p_layer) const { + + if (texture.is_valid()) { + return RS::get_singleton()->texture_2d_layer_get(texture, p_layer); + } else { + return Ref(); + } +} + +void StreamTextureLayered::reload_from_file() { + + String path = get_path(); + if (!path.is_resource_file()) + return; + + path = ResourceLoader::path_remap(path); //remap for translation + path = ResourceLoader::import_remap(path); //remap for import + if (!path.is_resource_file()) + return; + + load(path); +} + +void StreamTextureLayered::_validate_property(PropertyInfo &property) const { +} + +void StreamTextureLayered::_bind_methods() { + + ClassDB::bind_method(D_METHOD("load", "path"), &StreamTextureLayered::load); + ClassDB::bind_method(D_METHOD("get_load_path"), &StreamTextureLayered::get_load_path); + + ADD_PROPERTY(PropertyInfo(Variant::STRING, "load_path", PROPERTY_HINT_FILE, "*.stex"), "load", "get_load_path"); +} + +StreamTextureLayered::StreamTextureLayered(LayeredType p_type) { + + layered_type = p_type; + format = Image::FORMAT_MAX; + w = 0; + h = 0; + layers = 0; + mipmaps = false; +} + +StreamTextureLayered::~StreamTextureLayered() { + + if (texture.is_valid()) { + RS::get_singleton()->free(texture); + } +} + +///////////////////////////////////////////////// + +RES ResourceFormatLoaderStreamTextureLayered::load(const String &p_path, const String &p_original_path, Error *r_error, bool p_use_sub_threads, float *r_progress, bool p_no_cache) { + + Ref st; + if (p_path.get_extension().to_lower() == "stexarray") { + Ref s; + s.instance(); + st = s; + } else if (p_path.get_extension().to_lower() == "scube") { + Ref s; + s.instance(); + st = s; + } else if (p_path.get_extension().to_lower() == "scubearray") { + Ref s; + s.instance(); + st = s; + } else { + if (r_error) { + *r_error = ERR_FILE_UNRECOGNIZED; + } + return RES(); + } + Error err = st->load(p_path); + if (r_error) + *r_error = err; + if (err != OK) + return RES(); + + return st; +} + +void ResourceFormatLoaderStreamTextureLayered::get_recognized_extensions(List *p_extensions) const { + + p_extensions->push_back("stexarray"); + p_extensions->push_back("scube"); + p_extensions->push_back("scubearray"); +} +bool ResourceFormatLoaderStreamTextureLayered::handles_type(const String &p_type) const { + return p_type == "StreamTexture2DArray" || p_type == "StreamCubemap" || p_type == "StreamCubemapArray"; +} +String ResourceFormatLoaderStreamTextureLayered::get_resource_type(const String &p_path) const { + + if (p_path.get_extension().to_lower() == "stexarray") + return "StreamTexture2DArray"; + if (p_path.get_extension().to_lower() == "scube") + return "StreamCubemap"; + if (p_path.get_extension().to_lower() == "scubearray") + return "StreamCubemapArray"; return ""; } diff --git a/scene/resources/texture.h b/scene/resources/texture.h index 5d5f438eba..5d4131ec4c 100644 --- a/scene/resources/texture.h +++ b/scene/resources/texture.h @@ -132,9 +132,9 @@ public: ~ImageTexture(); }; -class StreamTexture : public Texture2D { +class StreamTexture2D : public Texture2D { - GDCLASS(StreamTexture, Texture2D); + GDCLASS(StreamTexture2D, Texture2D); public: enum DataFormat { @@ -181,8 +181,8 @@ protected: public: static Ref load_image_from_file(FileAccess *p_file, int p_size_limit); - typedef void (*TextureFormatRequestCallback)(const Ref &); - typedef void (*TextureFormatRoughnessRequestCallback)(const Ref &, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_roughness_channel); + typedef void (*TextureFormatRequestCallback)(const Ref &); + typedef void (*TextureFormatRoughnessRequestCallback)(const Ref &, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_roughness_channel); static TextureFormatRequestCallback request_3d_callback; static TextureFormatRoughnessRequestCallback request_roughness_callback; @@ -207,11 +207,11 @@ public: virtual Ref get_data() const; - StreamTexture(); - ~StreamTexture(); + StreamTexture2D(); + ~StreamTexture2D(); }; -class ResourceFormatLoaderStreamTexture : public ResourceFormatLoader { +class ResourceFormatLoaderStreamTexture2D : public ResourceFormatLoader { public: virtual RES load(const String &p_path, const String &p_original_path = "", Error *r_error = nullptr, bool p_use_sub_threads = false, float *r_progress = nullptr, bool p_no_cache = false); virtual void get_recognized_extensions(List *p_extensions) const; @@ -349,10 +349,34 @@ public: }; class TextureLayered : public Texture { - GDCLASS(TextureLayered, Texture); - RS::TextureLayeredType layered_type; +protected: + static void _bind_methods(); + +public: + enum LayeredType { + LAYERED_TYPE_2D_ARRAY, + LAYERED_TYPE_CUBEMAP, + LAYERED_TYPE_CUBEMAP_ARRAY + }; + + virtual Image::Format get_format() const = 0; + virtual LayeredType get_layered_type() const = 0; + virtual int get_width() const = 0; + virtual int get_height() const = 0; + virtual int get_layers() const = 0; + virtual bool has_mipmaps() const = 0; + virtual Ref get_layer_data(int p_layer) const = 0; +}; + +VARIANT_ENUM_CAST(TextureLayered::LayeredType) + +class ImageTextureLayered : public TextureLayered { + + GDCLASS(ImageTextureLayered, TextureLayered); + + LayeredType layered_type; mutable RID texture; Image::Format format; @@ -370,57 +394,137 @@ protected: static void _bind_methods(); public: - Image::Format get_format() const; - uint32_t get_width() const; - uint32_t get_height() const; - uint32_t get_layers() const; - bool has_mipmaps() const; + virtual Image::Format get_format() const; + virtual int get_width() const; + virtual int get_height() const; + virtual int get_layers() const; + virtual bool has_mipmaps() const; + virtual LayeredType get_layered_type() const; Error create_from_images(Vector> p_images); void update_layer(const Ref &p_image, int p_layer); - Ref get_layer_data(int p_layer) const; + virtual Ref get_layer_data(int p_layer) const; virtual RID get_rid() const; virtual void set_path(const String &p_path, bool p_take_over = false); - TextureLayered(RS::TextureLayeredType p_layered_type); - ~TextureLayered(); + ImageTextureLayered(LayeredType p_layered_type); + ~ImageTextureLayered(); }; -class Texture2DArray : public TextureLayered { +class Texture2DArray : public ImageTextureLayered { - GDCLASS(Texture2DArray, TextureLayered) + GDCLASS(Texture2DArray, ImageTextureLayered) public: Texture2DArray() : - TextureLayered(RS::TEXTURE_LAYERED_2D_ARRAY) {} + ImageTextureLayered(LAYERED_TYPE_2D_ARRAY) {} }; -class Cubemap : public TextureLayered { +class Cubemap : public ImageTextureLayered { - GDCLASS(Cubemap, TextureLayered); + GDCLASS(Cubemap, ImageTextureLayered); public: Cubemap() : - TextureLayered(RS::TEXTURE_LAYERED_CUBEMAP) {} + ImageTextureLayered(LAYERED_TYPE_CUBEMAP) {} }; -class CubemapArray : public TextureLayered { +class CubemapArray : public ImageTextureLayered { - GDCLASS(CubemapArray, TextureLayered); + GDCLASS(CubemapArray, ImageTextureLayered); public: CubemapArray() : - TextureLayered(RS::TEXTURE_LAYERED_CUBEMAP_ARRAY) {} + ImageTextureLayered(LAYERED_TYPE_CUBEMAP_ARRAY) {} }; -class ResourceFormatLoaderTextureLayered : public ResourceFormatLoader { +class StreamTextureLayered : public TextureLayered { + + GDCLASS(StreamTextureLayered, TextureLayered); + public: - enum Compression { - COMPRESSION_LOSSLESS, - COMPRESSION_VRAM, - COMPRESSION_UNCOMPRESSED + enum DataFormat { + DATA_FORMAT_IMAGE, + DATA_FORMAT_LOSSLESS, + DATA_FORMAT_LOSSY, + DATA_FORMAT_BASIS_UNIVERSAL, }; + enum { + FORMAT_VERSION = 1 + }; + + enum FormatBits { + FORMAT_MASK_IMAGE_FORMAT = (1 << 20) - 1, + FORMAT_BIT_LOSSLESS = 1 << 20, + FORMAT_BIT_LOSSY = 1 << 21, + FORMAT_BIT_STREAM = 1 << 22, + FORMAT_BIT_HAS_MIPMAPS = 1 << 23, + }; + +private: + Error _load_data(const String &p_path, Vector> &images, int &mipmap_limit, int p_size_limit = 0); + String path_to_file; + mutable RID texture; + Image::Format format; + int w, h, layers; + bool mipmaps; + LayeredType layered_type; + + virtual void reload_from_file(); + +protected: + static void _bind_methods(); + void _validate_property(PropertyInfo &property) const; + +public: + Image::Format get_format() const; + Error load(const String &p_path); + String get_load_path() const; + virtual LayeredType get_layered_type() const; + + int get_width() const; + int get_height() const; + int get_layers() const; + virtual bool has_mipmaps() const; + virtual RID get_rid() const; + + virtual void set_path(const String &p_path, bool p_take_over); + + virtual Ref get_layer_data(int p_layer) const; + + StreamTextureLayered(LayeredType p_layered_type); + ~StreamTextureLayered(); +}; + +class StreamTexture2DArray : public StreamTextureLayered { + + GDCLASS(StreamTexture2DArray, StreamTextureLayered) +public: + StreamTexture2DArray() : + StreamTextureLayered(LAYERED_TYPE_2D_ARRAY) {} +}; + +class StreamCubemap : public StreamTextureLayered { + + GDCLASS(StreamCubemap, StreamTextureLayered); + +public: + StreamCubemap() : + StreamTextureLayered(LAYERED_TYPE_CUBEMAP) {} +}; + +class StreamCubemapArray : public StreamTextureLayered { + + GDCLASS(StreamCubemapArray, StreamTextureLayered); + +public: + StreamCubemapArray() : + StreamTextureLayered(LAYERED_TYPE_CUBEMAP_ARRAY) {} +}; + +class ResourceFormatLoaderStreamTextureLayered : public ResourceFormatLoader { +public: virtual RES load(const String &p_path, const String &p_original_path = "", Error *r_error = nullptr, bool p_use_sub_threads = false, float *r_progress = nullptr, bool p_no_cache = false); virtual void get_recognized_extensions(List *p_extensions) const; virtual bool handles_type(const String &p_type) const; diff --git a/scene/scene_string_names.cpp b/scene/scene_string_names.cpp index 5e3f8b803b..e761354cc9 100644 --- a/scene/scene_string_names.cpp +++ b/scene/scene_string_names.cpp @@ -205,4 +205,9 @@ SceneStringNames::SceneStringNames() { shader_overrides_group = StaticCString::create("_shader_overrides_group_"); shader_overrides_group_active = StaticCString::create("_shader_overrides_group_active_"); + +#ifndef DISABLE_DEPRECATED + use_in_baked_light = StaticCString::create("use_in_baked_light"); + use_dynamic_gi = StaticCString::create("use_dynamic_gi"); +#endif } diff --git a/scene/scene_string_names.h b/scene/scene_string_names.h index c5de10a6f6..c5c98ba9e5 100644 --- a/scene/scene_string_names.h +++ b/scene/scene_string_names.h @@ -210,6 +210,10 @@ public: StringName shader_overrides_group; StringName shader_overrides_group_active; +#ifndef DISABLE_DEPRECATED + StringName use_in_baked_light; + StringName use_dynamic_gi; +#endif enum { MAX_MATERIALS = 32 }; diff --git a/servers/rendering/rasterizer.h b/servers/rendering/rasterizer.h index 099e155553..48ee6e02d0 100644 --- a/servers/rendering/rasterizer.h +++ b/servers/rendering/rasterizer.h @@ -57,6 +57,7 @@ public: virtual void sky_set_radiance_size(RID p_sky, int p_radiance_size) = 0; virtual void sky_set_mode(RID p_sky, RS::SkyMode p_samples) = 0; virtual void sky_set_material(RID p_sky, RID p_material) = 0; + virtual Ref sky_bake_panorama(RID p_sky, float p_energy, bool p_bake_irradiance, const Size2i &p_size) = 0; /* ENVIRONMENT API */ @@ -94,6 +95,8 @@ public: virtual void environment_set_fog_depth(RID p_env, bool p_enable, float p_depth_begin, float p_depth_end, float p_depth_curve, bool p_transmit, float p_transmit_curve) = 0; virtual void environment_set_fog_height(RID p_env, bool p_enable, float p_min_height, float p_max_height, float p_height_curve) = 0; + virtual Ref environment_bake_panorama(RID p_env, bool p_bake_irradiance, const Size2i &p_size) = 0; + virtual bool is_environment(RID p_env) const = 0; virtual RS::EnvironmentBG environment_get_background(RID p_env) const = 0; virtual int environment_get_canvas_max_layer(RID p_env) const = 0; @@ -160,9 +163,11 @@ public: SelfList dependency_item; - InstanceBase *lightmap_capture; - RID lightmap; - Vector lightmap_capture_data; //in a array (12 values) to avoid wasting space if unused. Alpha is unused, but needed to send to shader + InstanceBase *lightmap; + Rect2 lightmap_uv_scale; + int lightmap_slice_index; + uint32_t lightmap_cull_index; + Vector lightmap_sh; //spherical harmonic AABB aabb; AABB transformed_aabb; @@ -178,8 +183,8 @@ public: bool instance_allocated_shader_parameters = false; int32_t instance_allocated_shader_parameters_offset = -1; - virtual void dependency_deleted(RID p_dependency) = 0; - virtual void dependency_changed(bool p_aabb, bool p_dependencies) = 0; + virtual void dependency_deleted(RID p_dependency) {} + virtual void dependency_changed(bool p_aabb, bool p_dependencies) {} Set dependencies; @@ -233,7 +238,9 @@ public: baked_light = false; dynamic_gi = false; redraw_if_visible = false; - lightmap_capture = nullptr; + lightmap_slice_index = 0; + lightmap = nullptr; + lightmap_cull_index = 0; } virtual ~InstanceBase() { @@ -268,7 +275,7 @@ public: virtual bool gi_probe_needs_update(RID p_probe) const = 0; virtual void gi_probe_update(RID p_probe, bool p_update_light_instances, const Vector &p_light_instances, int p_dynamic_object_count, InstanceBase **p_dynamic_objects) = 0; - virtual void render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass) = 0; + virtual void render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass) = 0; virtual void render_shadow(RID p_light, RID p_shadow_atlas, int p_pass, InstanceBase **p_cull_result, int p_cull_count) = 0; virtual void render_material(const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region) = 0; @@ -286,6 +293,8 @@ public: virtual void sub_surface_scattering_set_quality(RS::SubSurfaceScatteringQuality p_quality) = 0; virtual void sub_surface_scattering_set_scale(float p_scale, float p_depth_scale) = 0; + virtual TypedArray bake_render_uv2(RID p_base, const Vector &p_material_overrides, const Size2i &p_image_size) = 0; + virtual bool free(RID p_rid) = 0; virtual void update() = 0; @@ -311,7 +320,7 @@ public: //these two APIs can be used together or in combination with the others. virtual RID texture_2d_placeholder_create() = 0; - virtual RID texture_2d_layered_placeholder_create() = 0; + virtual RID texture_2d_layered_placeholder_create(RenderingServer::TextureLayeredType p_layered_type) = 0; virtual RID texture_3d_placeholder_create() = 0; virtual Ref texture_2d_get(RID p_texture) const = 0; @@ -593,29 +602,21 @@ public: /* LIGHTMAP CAPTURE */ - struct LightmapCaptureOctree { + virtual RID lightmap_create() = 0; - enum { - CHILD_EMPTY = 0xFFFFFFFF - }; - - uint16_t light[6][3]; //anisotropic light - float alpha; - uint32_t children[8]; - }; - - virtual RID lightmap_capture_create() = 0; - virtual void lightmap_capture_set_bounds(RID p_capture, const AABB &p_bounds) = 0; - virtual AABB lightmap_capture_get_bounds(RID p_capture) const = 0; - virtual void lightmap_capture_set_octree(RID p_capture, const Vector &p_octree) = 0; - virtual Vector lightmap_capture_get_octree(RID p_capture) const = 0; - virtual void lightmap_capture_set_octree_cell_transform(RID p_capture, const Transform &p_xform) = 0; - virtual Transform lightmap_capture_get_octree_cell_transform(RID p_capture) const = 0; - virtual void lightmap_capture_set_octree_cell_subdiv(RID p_capture, int p_subdiv) = 0; - virtual int lightmap_capture_get_octree_cell_subdiv(RID p_capture) const = 0; - virtual void lightmap_capture_set_energy(RID p_capture, float p_energy) = 0; - virtual float lightmap_capture_get_energy(RID p_capture) const = 0; - virtual const Vector *lightmap_capture_get_octree_ptr(RID p_capture) const = 0; + virtual void lightmap_set_textures(RID p_lightmap, RID p_light, bool p_uses_spherical_haromics) = 0; + virtual void lightmap_set_probe_bounds(RID p_lightmap, const AABB &p_bounds) = 0; + virtual void lightmap_set_probe_interior(RID p_lightmap, bool p_interior) = 0; + virtual void lightmap_set_probe_capture_data(RID p_lightmap, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree) = 0; + virtual PackedVector3Array lightmap_get_probe_capture_points(RID p_lightmap) const = 0; + virtual PackedColorArray lightmap_get_probe_capture_sh(RID p_lightmap) const = 0; + virtual PackedInt32Array lightmap_get_probe_capture_tetrahedra(RID p_lightmap) const = 0; + virtual PackedInt32Array lightmap_get_probe_capture_bsp_tree(RID p_lightmap) const = 0; + virtual AABB lightmap_get_aabb(RID p_lightmap) const = 0; + virtual void lightmap_tap_sh_light(RID p_lightmap, const Vector3 &p_point, Color *r_sh) = 0; + virtual bool lightmap_is_interior(RID p_lightmap) const = 0; + virtual void lightmap_set_probe_capture_update_speed(float p_speed) = 0; + virtual float lightmap_get_probe_capture_update_speed() const = 0; /* PARTICLES */ @@ -1370,6 +1371,8 @@ public: virtual void end_frame(bool p_swap_buffers) = 0; virtual void finalize() = 0; + virtual uint64_t get_frame_number() const = 0; + virtual float get_frame_delta_time() const = 0; virtual bool is_low_end() const = 0; diff --git a/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.cpp index 956bf54d01..3505b18c8a 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.cpp @@ -1539,8 +1539,8 @@ void RasterizerCanvasRD::canvas_render_items(RID p_to_render_target, Item *p_ite } } - if (md->last_frame != RasterizerRD::get_frame_number()) { - md->last_frame = RasterizerRD::get_frame_number(); + if (md->last_frame != RasterizerRD::singleton->get_frame_number()) { + md->last_frame = RasterizerRD::singleton->get_frame_number(); if (!RD::get_singleton()->uniform_set_is_valid(md->uniform_set)) { // uniform set may be gone because a dependency was erased. In this case, it will happen // if a texture is deleted, so just re-create it. diff --git a/servers/rendering/rasterizer_rd/rasterizer_effects_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_effects_rd.cpp index d469dd97ca..ed25cc4139 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_effects_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_effects_rd.cpp @@ -282,6 +282,30 @@ void RasterizerEffectsRD::copy_to_rect(RID p_source_rd_texture, RID p_dest_textu RD::get_singleton()->compute_list_end(); } +void RasterizerEffectsRD::copy_cubemap_to_panorama(RID p_source_cube, RID p_dest_panorama, const Size2i &p_panorama_size, float p_lod, bool p_is_array) { + + zeromem(©.push_constant, sizeof(CopyPushConstant)); + + copy.push_constant.section[0] = 0; + copy.push_constant.section[1] = 0; + copy.push_constant.section[2] = p_panorama_size.width; + copy.push_constant.section[3] = p_panorama_size.height; + copy.push_constant.target[0] = 0; + copy.push_constant.target[1] = 0; + copy.push_constant.camera_z_far = p_lod; + + int32_t x_groups = (p_panorama_size.width - 1) / 8 + 1; + int32_t y_groups = (p_panorama_size.height - 1) / 8 + 1; + + RD::ComputeListID compute_list = RD::get_singleton()->compute_list_begin(); + RD::get_singleton()->compute_list_bind_compute_pipeline(compute_list, copy.pipelines[p_is_array ? COPY_MODE_CUBE_ARRAY_TO_PANORAMA : COPY_MODE_CUBE_TO_PANORAMA]); + RD::get_singleton()->compute_list_bind_uniform_set(compute_list, _get_compute_uniform_set_from_texture(p_source_cube), 0); + RD::get_singleton()->compute_list_bind_uniform_set(compute_list, _get_uniform_set_from_image(p_dest_panorama), 3); + RD::get_singleton()->compute_list_set_push_constant(compute_list, ©.push_constant, sizeof(CopyPushConstant)); + RD::get_singleton()->compute_list_dispatch(compute_list, x_groups, y_groups, 1); + RD::get_singleton()->compute_list_end(); +} + void RasterizerEffectsRD::copy_depth_to_rect_and_linearize(RID p_source_rd_texture, RID p_dest_texture, const Rect2i &p_rect, bool p_flip_y, float p_z_near, float p_z_far) { zeromem(©.push_constant, sizeof(CopyPushConstant)); @@ -1202,7 +1226,9 @@ void RasterizerEffectsRD::render_sky(RD::DrawListID p_list, float p_time, RID p_ RD::get_singleton()->draw_list_bind_render_pipeline(draw_list, p_pipeline->get_render_pipeline(RD::INVALID_ID, fb_format)); RD::get_singleton()->draw_list_bind_uniform_set(draw_list, p_samplers, 0); - RD::get_singleton()->draw_list_bind_uniform_set(draw_list, p_uniform_set, 1); + if (p_uniform_set.is_valid()) { //material may not have uniform set + RD::get_singleton()->draw_list_bind_uniform_set(draw_list, p_uniform_set, 1); + } RD::get_singleton()->draw_list_bind_uniform_set(draw_list, p_texture_set, 2); RD::get_singleton()->draw_list_bind_uniform_set(draw_list, p_lights, 3); @@ -1226,6 +1252,8 @@ RasterizerEffectsRD::RasterizerEffectsRD() { copy_modes.push_back("\n#define MODE_SIMPLE_COPY_DEPTH\n"); copy_modes.push_back("\n#define MODE_MIPMAP\n"); copy_modes.push_back("\n#define MODE_LINEARIZE_DEPTH_COPY\n"); + copy_modes.push_back("\n#define MODE_CUBEMAP_TO_PANORAMA\n"); + copy_modes.push_back("\n#define MODE_CUBEMAP_ARRAY_TO_PANORAMA\n"); copy.shader.initialize(copy_modes); zeromem(©.push_constant, sizeof(CopyPushConstant)); diff --git a/servers/rendering/rasterizer_rd/rasterizer_effects_rd.h b/servers/rendering/rasterizer_rd/rasterizer_effects_rd.h index 531591442b..1b16648ca6 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_effects_rd.h +++ b/servers/rendering/rasterizer_rd/rasterizer_effects_rd.h @@ -66,6 +66,8 @@ class RasterizerEffectsRD { COPY_MODE_SIMPLY_COPY_DEPTH, COPY_MODE_MIPMAP, COPY_MODE_LINEARIZE_DEPTH, + COPY_MODE_CUBE_TO_PANORAMA, + COPY_MODE_CUBE_ARRAY_TO_PANORAMA, COPY_MODE_MAX, }; @@ -564,6 +566,7 @@ class RasterizerEffectsRD { public: void copy_to_fb_rect(RID p_source_rd_texture, RID p_dest_framebuffer, const Rect2i &p_rect, bool p_flip_y = false, bool p_force_luminance = false, bool p_alpha_to_zero = false); void copy_to_rect(RID p_source_rd_texture, RID p_dest_texture, const Rect2i &p_rect, bool p_flip_y = false, bool p_force_luminance = false, bool p_all_source = false, bool p_8_bit_dst = false); + void copy_cubemap_to_panorama(RID p_source_cube, RID p_dest_panorama, const Size2i &p_panorama_size, float p_lod, bool p_is_array); void copy_depth_to_rect(RID p_source_rd_texture, RID p_dest_framebuffer, const Rect2i &p_rect, bool p_flip_y = false); void copy_depth_to_rect_and_linearize(RID p_source_rd_texture, RID p_dest_texture, const Rect2i &p_rect, bool p_flip_y, float p_z_near, float p_z_far); void copy_to_atlas_fb(RID p_source_rd_texture, RID p_dest_framebuffer, const Rect2 &p_uv_rect, RD::DrawListID p_draw_list, bool p_flip_y = false, bool p_panorama = false); diff --git a/servers/rendering/rasterizer_rd/rasterizer_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_rd.cpp index 4c92912e9c..4267a087b6 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_rd.cpp @@ -79,6 +79,7 @@ void RasterizerRD::blit_render_targets_to_screen(DisplayServer::WindowID p_scree void RasterizerRD::begin_frame(double frame_step) { frame++; + delta = frame_step; time += frame_step; double time_roll_over = GLOBAL_GET("rendering/limits/time/time_rollover_secs"); @@ -157,7 +158,7 @@ void RasterizerRD::initialize() { } ThreadWorkPool RasterizerRD::thread_work_pool; -uint32_t RasterizerRD::frame = 1; +uint64_t RasterizerRD::frame = 1; void RasterizerRD::finalize() { @@ -173,7 +174,10 @@ void RasterizerRD::finalize() { RD::get_singleton()->free(copy_viewports_sampler); } +RasterizerRD *RasterizerRD::singleton = nullptr; + RasterizerRD::RasterizerRD() { + singleton = this; thread_work_pool.init(); time = 0; diff --git a/servers/rendering/rasterizer_rd/rasterizer_rd.h b/servers/rendering/rasterizer_rd/rasterizer_rd.h index 756b9499ca..cb53a531ac 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_rd.h +++ b/servers/rendering/rasterizer_rd/rasterizer_rd.h @@ -53,8 +53,9 @@ protected: Map render_target_descriptors; double time; + float delta; - static uint32_t frame; + static uint64_t frame; public: RasterizerStorage *get_storage() { return storage; } @@ -71,7 +72,8 @@ public: void end_frame(bool p_swap_buffers); void finalize(); - static _ALWAYS_INLINE_ uint64_t get_frame_number() { return frame; } + _ALWAYS_INLINE_ uint64_t get_frame_number() const { return frame; } + _ALWAYS_INLINE_ float get_frame_delta_time() const { return delta; } static Error is_viable() { return OK; @@ -89,6 +91,7 @@ public: static ThreadWorkPool thread_work_pool; + static RasterizerRD *singleton; RasterizerRD(); ~RasterizerRD() {} }; diff --git a/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.cpp index 6986f82065..3f0062b2ae 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.cpp @@ -67,18 +67,18 @@ static _FORCE_INLINE_ void store_basis_3x4(const Basis &p_mtx, float *p_array) { p_array[11] = 0; } -static _FORCE_INLINE_ void store_transform_3x3(const Transform &p_mtx, float *p_array) { - p_array[0] = p_mtx.basis.elements[0][0]; - p_array[1] = p_mtx.basis.elements[1][0]; - p_array[2] = p_mtx.basis.elements[2][0]; +static _FORCE_INLINE_ void store_transform_3x3(const Basis &p_mtx, float *p_array) { + p_array[0] = p_mtx.elements[0][0]; + p_array[1] = p_mtx.elements[1][0]; + p_array[2] = p_mtx.elements[2][0]; p_array[3] = 0; - p_array[4] = p_mtx.basis.elements[0][1]; - p_array[5] = p_mtx.basis.elements[1][1]; - p_array[6] = p_mtx.basis.elements[2][1]; + p_array[4] = p_mtx.elements[0][1]; + p_array[5] = p_mtx.elements[1][1]; + p_array[6] = p_mtx.elements[2][1]; p_array[7] = 0; - p_array[8] = p_mtx.basis.elements[0][2]; - p_array[9] = p_mtx.basis.elements[1][2]; - p_array[10] = p_mtx.basis.elements[2][2]; + p_array[8] = p_mtx.elements[0][2]; + p_array[9] = p_mtx.elements[1][2]; + p_array[10] = p_mtx.elements[2][2]; p_array[11] = 0; } @@ -841,6 +841,8 @@ bool RasterizerSceneHighEndRD::free(RID p_rid) { void RasterizerSceneHighEndRD::_fill_instances(RenderList::Element **p_elements, int p_element_count, bool p_for_depth) { + uint32_t lightmap_captures_used = 0; + for (int i = 0; i < p_element_count; i++) { const RenderList::Element *e = p_elements[i]; @@ -898,6 +900,7 @@ void RasterizerSceneHighEndRD::_fill_instances(RenderList::Element **p_elements, if (written == 0) { id.gi_offset = index; + id.flags |= INSTANCE_DATA_FLAG_USE_GIPROBE; written = 1; } else { id.gi_offset = index << 16; @@ -910,17 +913,53 @@ void RasterizerSceneHighEndRD::_fill_instances(RenderList::Element **p_elements, } else if (written == 1) { id.gi_offset |= 0xFFFF0000; } + } else if (e->instance->lightmap) { + + int32_t lightmap_index = storage->lightmap_get_array_index(e->instance->lightmap->base); + if (lightmap_index >= 0) { + id.gi_offset = lightmap_index; + id.gi_offset |= e->instance->lightmap_slice_index << 12; + id.gi_offset |= e->instance->lightmap_cull_index << 20; + id.lightmap_uv_scale[0] = e->instance->lightmap_uv_scale.position.x; + id.lightmap_uv_scale[1] = e->instance->lightmap_uv_scale.position.y; + id.lightmap_uv_scale[2] = e->instance->lightmap_uv_scale.size.width; + id.lightmap_uv_scale[3] = e->instance->lightmap_uv_scale.size.height; + id.flags |= INSTANCE_DATA_FLAG_USE_LIGHTMAP; + if (storage->lightmap_uses_spherical_harmonics(e->instance->lightmap->base)) { + id.flags |= INSTANCE_DATA_FLAG_USE_SH_LIGHTMAP; + } + } else { + id.gi_offset = 0xFFFFFFFF; + } + } else if (!e->instance->lightmap_sh.empty()) { + if (lightmap_captures_used < scene_state.max_lightmap_captures) { + + const Color *src_capture = e->instance->lightmap_sh.ptr(); + LightmapCaptureData &lcd = scene_state.lightmap_captures[lightmap_captures_used]; + for (int j = 0; j < 9; j++) { + lcd.sh[j * 4 + 0] = src_capture[j].r; + lcd.sh[j * 4 + 1] = src_capture[j].g; + lcd.sh[j * 4 + 2] = src_capture[j].b; + lcd.sh[j * 4 + 3] = src_capture[j].a; + } + id.flags |= INSTANCE_DATA_FLAG_USE_LIGHTMAP_CAPTURE; + id.gi_offset = lightmap_captures_used; + lightmap_captures_used++; + } } else { id.gi_offset = 0xFFFFFFFF; } } RD::get_singleton()->buffer_update(scene_state.instance_buffer, 0, sizeof(InstanceData) * p_element_count, scene_state.instances, true); + if (lightmap_captures_used) { + RD::get_singleton()->buffer_update(scene_state.lightmap_capture_buffer, 0, sizeof(LightmapCaptureData) * lightmap_captures_used, scene_state.lightmap_captures, true); + } } /// RENDERING /// -void RasterizerSceneHighEndRD::_render_list(RenderingDevice::DrawListID p_draw_list, RenderingDevice::FramebufferFormatID p_framebuffer_Format, RenderList::Element **p_elements, int p_element_count, bool p_reverse_cull, PassMode p_pass_mode, bool p_no_gi, RID p_radiance_uniform_set, RID p_render_buffers_uniform_set) { +void RasterizerSceneHighEndRD::_render_list(RenderingDevice::DrawListID p_draw_list, RenderingDevice::FramebufferFormatID p_framebuffer_Format, RenderList::Element **p_elements, int p_element_count, bool p_reverse_cull, PassMode p_pass_mode, bool p_no_gi, RID p_radiance_uniform_set, RID p_render_buffers_uniform_set, bool p_force_wireframe, const Vector2 &p_uv_offset) { RD::DrawListID draw_list = p_draw_list; RD::FramebufferFormatID framebuffer_format = p_framebuffer_Format; @@ -949,6 +988,8 @@ void RasterizerSceneHighEndRD::_render_list(RenderingDevice::DrawListID p_draw_l PushConstant push_constant; zeromem(&push_constant, sizeof(PushConstant)); + push_constant.bake_uv2_offset[0] = p_uv_offset.x; + push_constant.bake_uv2_offset[1] = p_uv_offset.y; for (int i = 0; i < p_element_count; i++) { @@ -961,7 +1002,7 @@ void RasterizerSceneHighEndRD::_render_list(RenderingDevice::DrawListID p_draw_l //find cull variant ShaderData::CullVariant cull_variant; - if ((p_pass_mode == PASS_MODE_SHADOW || p_pass_mode == PASS_MODE_SHADOW_DP) && e->instance->cast_shadows == RS::SHADOW_CASTING_SETTING_DOUBLE_SIDED) { + if (p_pass_mode == PASS_MODE_DEPTH_MATERIAL || ((p_pass_mode == PASS_MODE_SHADOW || p_pass_mode == PASS_MODE_SHADOW_DP) && e->instance->cast_shadows == RS::SHADOW_CASTING_SETTING_DOUBLE_SIDED)) { cull_variant = ShaderData::CULL_VARIANT_DOUBLE_SIDED; } else { bool mirror = e->instance->mirror; @@ -1080,7 +1121,7 @@ void RasterizerSceneHighEndRD::_render_list(RenderingDevice::DrawListID p_draw_l prev_index_array_rd = index_array_rd; } - RID pipeline_rd = pipeline->get_render_pipeline(vertex_format, framebuffer_format); + RID pipeline_rd = pipeline->get_render_pipeline(vertex_format, framebuffer_format, p_force_wireframe); if (pipeline_rd != prev_pipeline_rd) { // checking with prev shader does not make so much sense, as @@ -1255,6 +1296,7 @@ void RasterizerSceneHighEndRD::_setup_environment(RID p_environment, const Camer scene_state.ubo.use_ambient_cubemap = false; scene_state.ubo.use_reflection_cubemap = false; + scene_state.ubo.ssao_enabled = false; } scene_state.ubo.roughness_limiter_enabled = p_opaque_render_buffers && screen_space_roughness_limiter_is_active(); @@ -1271,8 +1313,6 @@ void RasterizerSceneHighEndRD::_add_geometry(InstanceBase *p_instance, uint32_t if (unlikely(get_debug_draw_mode() != RS::VIEWPORT_DEBUG_DRAW_DISABLED)) { if (get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_OVERDRAW) { m_src = overdraw_material; - } else if (get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_WIREFRAME) { - m_src = wireframe_material; } else if (get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_LIGHTING) { m_src = default_material; } @@ -1374,7 +1414,7 @@ void RasterizerSceneHighEndRD::_add_geometry_with_material(InstanceBase *p_insta e->geometry_index = p_geometry_index; e->material_index = e->material->index; e->uses_instancing = e->instance->base_type == RS::INSTANCE_MULTIMESH; - e->uses_lightmap = e->instance->lightmap.is_valid(); + e->uses_lightmap = e->instance->lightmap != nullptr || !e->instance->lightmap_sh.empty(); e->uses_vct = e->instance->gi_probe_instances.size(); e->shader_index = e->shader_index; e->depth_layer = e->instance->depth_layer; @@ -1575,6 +1615,26 @@ void RasterizerSceneHighEndRD::_setup_reflections(RID *p_reflection_probe_cull_r } } +void RasterizerSceneHighEndRD::_setup_lightmaps(InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, const Transform &p_cam_transform) { + + uint32_t lightmaps_used = 0; + for (int i = 0; i < p_lightmap_cull_count; i++) { + if (i >= (int)scene_state.max_lightmaps) { + break; + } + + InstanceBase *lm = p_lightmap_cull_result[i]; + Basis to_lm = lm->transform.basis.inverse() * p_cam_transform.basis; + to_lm = to_lm.inverse().transposed(); //will transform normals + store_transform_3x3(to_lm, scene_state.lightmaps[i].normal_xform); + lm->lightmap_cull_index = i; + lightmaps_used++; + } + if (lightmaps_used > 0) { + RD::get_singleton()->buffer_update(scene_state.lightmap_buffer, 0, sizeof(LightmapData) * lightmaps_used, scene_state.lightmaps, true); + } +} + void RasterizerSceneHighEndRD::_setup_gi_probes(RID *p_gi_probe_probe_cull_result, int p_gi_probe_probe_cull_count, const Transform &p_camera_transform) { int index = 0; @@ -2118,7 +2178,7 @@ void RasterizerSceneHighEndRD::_setup_decals(const RID *p_decal_instances, int p } } -void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_bg_color) { +void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_bg_color) { RenderBufferDataHighEnd *render_buffer = nullptr; if (p_render_buffer.is_valid()) { @@ -2238,6 +2298,7 @@ void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transfor _setup_decals(p_decal_cull_result, p_decal_cull_count, p_cam_transform.affine_inverse()); _setup_reflections(p_reflection_probe_cull_result, p_reflection_probe_cull_count, p_cam_transform.affine_inverse(), p_environment); _setup_gi_probes(p_gi_probe_cull_result, p_gi_probe_cull_count, p_cam_transform); + _setup_lightmaps(p_lightmap_cull_result, p_lightmap_cull_count, p_cam_transform); _setup_environment(p_environment, p_cam_projection, p_cam_transform, p_reflection_probe, p_reflection_probe.is_valid(), screen_pixel_size, p_shadow_atlas, !p_reflection_probe.is_valid(), p_default_bg_color, p_cam_projection.get_z_near(), p_cam_projection.get_z_far(), false); cluster_builder.bake_cluster(); //bake to cluster @@ -2338,7 +2399,7 @@ void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transfor bool finish_depth = using_ssao; RD::DrawListID draw_list = RD::get_singleton()->draw_list_begin(depth_framebuffer, RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_READ, RD::INITIAL_ACTION_CLEAR, finish_depth ? RD::FINAL_ACTION_READ : RD::FINAL_ACTION_CONTINUE, depth_pass_clear); - _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(depth_framebuffer), render_list.elements, render_list.element_count, false, depth_pass_mode, render_buffer == nullptr, radiance_uniform_set, RID()); + _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(depth_framebuffer), render_list.elements, render_list.element_count, false, depth_pass_mode, render_buffer == nullptr, radiance_uniform_set, RID(), get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_WIREFRAME); RD::get_singleton()->draw_list_end(); if (render_buffer && render_buffer->msaa != RS::VIEWPORT_MSAA_DISABLED) { @@ -2394,7 +2455,7 @@ void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transfor RID framebuffer = using_separate_specular ? opaque_specular_framebuffer : opaque_framebuffer; RD::DrawListID draw_list = RD::get_singleton()->draw_list_begin(framebuffer, keep_color ? RD::INITIAL_ACTION_KEEP : RD::INITIAL_ACTION_CLEAR, will_continue_color ? RD::FINAL_ACTION_CONTINUE : RD::FINAL_ACTION_READ, depth_pre_pass ? (using_ssao ? RD::INITIAL_ACTION_KEEP : RD::INITIAL_ACTION_CONTINUE) : RD::INITIAL_ACTION_CLEAR, will_continue_depth ? RD::FINAL_ACTION_CONTINUE : RD::FINAL_ACTION_READ, c, 1.0, 0); - _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(framebuffer), render_list.elements, render_list.element_count, false, using_separate_specular ? PASS_MODE_COLOR_SPECULAR : PASS_MODE_COLOR, render_buffer == nullptr, radiance_uniform_set, render_buffers_uniform_set); + _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(framebuffer), render_list.elements, render_list.element_count, false, using_separate_specular ? PASS_MODE_COLOR_SPECULAR : PASS_MODE_COLOR, render_buffer == nullptr, radiance_uniform_set, render_buffers_uniform_set, get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_WIREFRAME); RD::get_singleton()->draw_list_end(); if (will_continue_color && using_separate_specular) { @@ -2472,7 +2533,7 @@ void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transfor { RD::DrawListID draw_list = RD::get_singleton()->draw_list_begin(alpha_framebuffer, can_continue_color ? RD::INITIAL_ACTION_CONTINUE : RD::INITIAL_ACTION_KEEP, RD::FINAL_ACTION_READ, can_continue_depth ? RD::INITIAL_ACTION_CONTINUE : RD::INITIAL_ACTION_KEEP, RD::FINAL_ACTION_READ); - _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(alpha_framebuffer), &render_list.elements[render_list.max_elements - render_list.alpha_element_count], render_list.alpha_element_count, false, PASS_MODE_COLOR, render_buffer == nullptr, radiance_uniform_set, render_buffers_uniform_set); + _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(alpha_framebuffer), &render_list.elements[render_list.max_elements - render_list.alpha_element_count], render_list.alpha_element_count, false, PASS_MODE_COLOR, render_buffer == nullptr, radiance_uniform_set, render_buffers_uniform_set, get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_WIREFRAME); RD::get_singleton()->draw_list_end(); } @@ -2517,13 +2578,14 @@ void RasterizerSceneHighEndRD::_render_shadow(RID p_framebuffer, InstanceBase ** } void RasterizerSceneHighEndRD::_render_material(const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region) { - RENDER_TIMESTAMP("Setup Rendering Shadow"); + RENDER_TIMESTAMP("Setup Rendering Material"); _update_render_base_uniform_set(); render_pass++; scene_state.ubo.dual_paraboloid_side = 0; + scene_state.ubo.material_uv2_mode = true; _setup_environment(RID(), p_cam_projection, p_cam_transform, RID(), true, Vector2(1, 1), RID(), false, Color(), 0, 0); @@ -2554,6 +2616,67 @@ void RasterizerSceneHighEndRD::_render_material(const Transform &p_cam_transform } } +void RasterizerSceneHighEndRD::_render_uv2(InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region) { + RENDER_TIMESTAMP("Setup Rendering UV2"); + + _update_render_base_uniform_set(); + + render_pass++; + + scene_state.ubo.dual_paraboloid_side = 0; + scene_state.ubo.material_uv2_mode = true; + + _setup_environment(RID(), CameraMatrix(), Transform(), RID(), true, Vector2(1, 1), RID(), false, Color(), 0, 0); + + render_list.clear(); + + PassMode pass_mode = PASS_MODE_DEPTH_MATERIAL; + _fill_render_list(p_cull_result, p_cull_count, pass_mode, true); + + _setup_view_dependant_uniform_set(RID(), RID()); + + RENDER_TIMESTAMP("Render Material"); + + render_list.sort_by_key(false); + + _fill_instances(render_list.elements, render_list.element_count, true); + + { + //regular forward for now + Vector clear; + clear.push_back(Color(0, 0, 0, 0)); + clear.push_back(Color(0, 0, 0, 0)); + clear.push_back(Color(0, 0, 0, 0)); + clear.push_back(Color(0, 0, 0, 0)); + clear.push_back(Color(0, 0, 0, 0)); + RD::DrawListID draw_list = RD::get_singleton()->draw_list_begin(p_framebuffer, RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_READ, RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_READ, clear, 1.0, 0, p_region); + + const int uv_offset_count = 9; + static const Vector2 uv_offsets[uv_offset_count] = { + Vector2(-1, 1), + Vector2(1, 1), + Vector2(1, -1), + Vector2(-1, -1), + Vector2(-1, 0), + Vector2(1, 0), + Vector2(0, -1), + Vector2(0, 1), + Vector2(0, 0), + + }; + + for (int i = 0; i < uv_offset_count; i++) { + Vector2 ofs = uv_offsets[i]; + ofs.x /= p_region.size.width; + ofs.y /= p_region.size.height; + _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(p_framebuffer), render_list.elements, render_list.element_count, true, pass_mode, true, RID(), RID(), true, ofs); //first wireframe, for pseudo conservative + } + _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(p_framebuffer), render_list.elements, render_list.element_count, true, pass_mode, true, RID(), RID(), false); //second regular triangles + + RD::get_singleton()->draw_list_end(); + } +} + void RasterizerSceneHighEndRD::_base_uniforms_changed() { if (!render_base_uniform_set.is_null() && RD::get_singleton()->uniform_set_is_valid(render_base_uniform_set)) { @@ -2564,12 +2687,14 @@ void RasterizerSceneHighEndRD::_base_uniforms_changed() { void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { - if (render_base_uniform_set.is_null() || !RD::get_singleton()->uniform_set_is_valid(render_base_uniform_set)) { + if (render_base_uniform_set.is_null() || !RD::get_singleton()->uniform_set_is_valid(render_base_uniform_set) || (lightmap_texture_array_version != storage->lightmap_array_get_version())) { if (render_base_uniform_set.is_valid() && RD::get_singleton()->uniform_set_is_valid(render_base_uniform_set)) { RD::get_singleton()->free(render_base_uniform_set); } + lightmap_texture_array_version = storage->lightmap_array_get_version(); + Vector uniforms; { @@ -2685,6 +2810,27 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { { RD::Uniform u; u.binding = 10; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.ids.push_back(scene_state.lightmap_buffer); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.binding = 11; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.ids = storage->lightmap_array_get_textures(); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.binding = 12; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.ids.push_back(scene_state.lightmap_capture_buffer); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.binding = 13; u.type = RD::UNIFORM_TYPE_TEXTURE; RID decal_atlas = storage->decal_atlas_get_texture(); u.ids.push_back(decal_atlas); @@ -2692,7 +2838,7 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { } { RD::Uniform u; - u.binding = 11; + u.binding = 14; u.type = RD::UNIFORM_TYPE_TEXTURE; RID decal_atlas = storage->decal_atlas_get_texture_srgb(); u.ids.push_back(decal_atlas); @@ -2700,7 +2846,7 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { } { RD::Uniform u; - u.binding = 12; + u.binding = 15; u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; u.ids.push_back(scene_state.decal_buffer); uniforms.push_back(u); @@ -2708,14 +2854,14 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { { RD::Uniform u; - u.binding = 13; + u.binding = 16; u.type = RD::UNIFORM_TYPE_TEXTURE; u.ids.push_back(cluster_builder.get_cluster_texture()); uniforms.push_back(u); } { RD::Uniform u; - u.binding = 14; + u.binding = 17; u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; u.ids.push_back(cluster_builder.get_cluster_indices_buffer()); uniforms.push_back(u); @@ -2723,7 +2869,7 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { { RD::Uniform u; - u.binding = 15; + u.binding = 18; u.type = RD::UNIFORM_TYPE_TEXTURE; if (directional_shadow_get_texture().is_valid()) { u.ids.push_back(directional_shadow_get_texture()); @@ -2736,7 +2882,7 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { { RD::Uniform u; u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; - u.binding = 16; + u.binding = 19; u.ids.push_back(storage->global_variables_get_storage_buffer()); uniforms.push_back(u); } @@ -2951,7 +3097,21 @@ RasterizerSceneHighEndRD::RasterizerSceneHighEndRD(RasterizerStorageRD *p_storag scene_state.gi_probe_buffer = RD::get_singleton()->uniform_buffer_create(sizeof(GIProbeData) * scene_state.max_gi_probes); defines += "\n#define MAX_GI_PROBES " + itos(scene_state.max_gi_probes) + "\n"; } + { + //lightmaps + scene_state.max_lightmaps = storage->lightmap_array_get_size(); + defines += "\n#define MAX_LIGHTMAP_TEXTURES " + itos(scene_state.max_lightmaps) + "\n"; + defines += "\n#define MAX_LIGHTMAPS " + itos(scene_state.max_lightmaps) + "\n"; + scene_state.lightmaps = memnew_arr(LightmapData, scene_state.max_lightmaps); + scene_state.lightmap_buffer = RD::get_singleton()->storage_buffer_create(sizeof(LightmapData) * scene_state.max_lightmaps); + } + { + //captures + scene_state.max_lightmap_captures = 2048; + scene_state.lightmap_captures = memnew_arr(LightmapCaptureData, scene_state.max_lightmap_captures); + scene_state.lightmap_capture_buffer = RD::get_singleton()->storage_buffer_create(sizeof(LightmapCaptureData) * scene_state.max_lightmap_captures); + } { //decals scene_state.max_decals = MIN(1024 * 1024, uniform_max_size) / sizeof(DecalData); //1mb of decals uint32_t decal_buffer_size = scene_state.max_decals * sizeof(DecalData); @@ -2959,6 +3119,11 @@ RasterizerSceneHighEndRD::RasterizerSceneHighEndRD(RasterizerStorageRD *p_storag scene_state.decal_buffer = RD::get_singleton()->storage_buffer_create(decal_buffer_size); } + { + + defines += "\n#define MATERIAL_UNIFORM_SET " + itos(MATERIAL_UNIFORM_SET) + "\n"; + } + Vector shader_versions; shader_versions.push_back("\n#define MODE_RENDER_DEPTH\n"); shader_versions.push_back("\n#define MODE_RENDER_DEPTH\n#define MODE_DUAL_PARABOLOID\n"); diff --git a/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.h b/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.h index a48e2e2259..e8736a0e53 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.h +++ b/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.h @@ -193,7 +193,8 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { struct PushConstant { uint32_t index; - uint32_t pad[3]; + uint32_t pad; + float bake_uv2_offset[2]; }; /* Framebuffer */ @@ -241,6 +242,8 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { RID render_base_uniform_set; RID view_dependant_uniform_set; + uint64_t lightmap_texture_array_version = 0xFFFFFFFF; + virtual void _base_uniforms_changed(); void _render_buffers_clear_uniform_set(RenderBufferDataHighEnd *rb); virtual void _render_buffers_uniform_set_changed(RID p_render_buffers); @@ -331,6 +334,10 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { uint32_t pad[1]; }; + struct LightmapData { + float normal_xform[12]; + }; + struct DecalData { float xform[16]; float inv_extents[3]; @@ -349,7 +356,15 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { float normal_fade; }; + struct LightmapCaptureData { + float sh[9 * 4]; + }; + enum { + INSTANCE_DATA_FLAG_USE_LIGHTMAP_CAPTURE = 1 << 8, + INSTANCE_DATA_FLAG_USE_LIGHTMAP = 1 << 9, + INSTANCE_DATA_FLAG_USE_SH_LIGHTMAP = 1 << 10, + INSTANCE_DATA_FLAG_USE_GIPROBE = 1 << 11, INSTANCE_DATA_FLAG_MULTIMESH = 1 << 12, INSTANCE_DATA_FLAG_MULTIMESH_FORMAT_2D = 1 << 13, INSTANCE_DATA_FLAG_MULTIMESH_HAS_COLOR = 1 << 14, @@ -366,6 +381,7 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { uint32_t instance_uniforms_ofs; //instance_offset in instancing/skeleton buffer uint32_t gi_offset; //GI information when using lightmapping (VCT or lightmap) uint32_t mask; + float lightmap_uv_scale[4]; }; struct SceneState { @@ -418,6 +434,9 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { uint32_t roughness_limiter_enabled; float ao_color[4]; + + uint32_t material_uv2_mode; + uint32_t pad_material[3]; }; UBO ubo; @@ -434,6 +453,10 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { RID gi_probe_buffer; uint32_t max_gi_probe_probes_per_instance; + LightmapData *lightmaps; + uint32_t max_lightmaps; + RID lightmap_buffer; + DecalData *decals; uint32_t max_decals; RID decal_buffer; @@ -446,6 +469,10 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { uint32_t max_directional_lights; RID directional_light_buffer; + LightmapCaptureData *lightmap_captures; + uint32_t max_lightmap_captures; + RID lightmap_capture_buffer; + RID instance_buffer; InstanceData *instances; uint32_t max_instances; @@ -456,6 +483,7 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { bool used_sss = false; uint32_t current_shader_index = 0; uint32_t current_material_index = 0; + } scene_state; /* Render List */ @@ -632,18 +660,20 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { void _setup_decals(const RID *p_decal_instances, int p_decal_count, const Transform &p_camera_inverse_xform); void _setup_reflections(RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, const Transform &p_camera_inverse_transform, RID p_environment); void _setup_gi_probes(RID *p_gi_probe_probe_cull_result, int p_gi_probe_probe_cull_count, const Transform &p_camera_transform); + void _setup_lightmaps(InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, const Transform &p_cam_transform); void _fill_instances(RenderList::Element **p_elements, int p_element_count, bool p_for_depth); - void _render_list(RenderingDevice::DrawListID p_draw_list, RenderingDevice::FramebufferFormatID p_framebuffer_Format, RenderList::Element **p_elements, int p_element_count, bool p_reverse_cull, PassMode p_pass_mode, bool p_no_gi, RID p_radiance_uniform_set, RID p_render_buffers_uniform_set); + void _render_list(RenderingDevice::DrawListID p_draw_list, RenderingDevice::FramebufferFormatID p_framebuffer_Format, RenderList::Element **p_elements, int p_element_count, bool p_reverse_cull, PassMode p_pass_mode, bool p_no_gi, RID p_radiance_uniform_set, RID p_render_buffers_uniform_set, bool p_force_wireframe = false, const Vector2 &p_uv_offset = Vector2()); _FORCE_INLINE_ void _add_geometry(InstanceBase *p_instance, uint32_t p_surface, RID p_material, PassMode p_pass_mode, uint32_t p_geometry_index); _FORCE_INLINE_ void _add_geometry_with_material(InstanceBase *p_instance, uint32_t p_surface, MaterialData *p_material, RID p_material_rid, PassMode p_pass_mode, uint32_t p_geometry_index); void _fill_render_list(InstanceBase **p_cull_result, int p_cull_count, PassMode p_pass_mode, bool p_no_gi); protected: - virtual void _render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_bg_color); + virtual void _render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_bg_color); virtual void _render_shadow(RID p_framebuffer, InstanceBase **p_cull_result, int p_cull_count, const CameraMatrix &p_projection, const Transform &p_transform, float p_zfar, float p_bias, float p_normal_bias, bool p_use_dp, bool p_use_dp_flip, bool p_use_pancake); virtual void _render_material(const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region); + virtual void _render_uv2(InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region); public: virtual void set_time(double p_time, double p_step); diff --git a/servers/rendering/rasterizer_rd/rasterizer_scene_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_scene_rd.cpp index ab669e7647..02221d1536 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_scene_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_scene_rd.cpp @@ -263,7 +263,47 @@ void RasterizerSceneRD::sky_set_material(RID p_sky, RID p_material) { Sky *sky = sky_owner.getornull(p_sky); ERR_FAIL_COND(!sky); sky->material = p_material; + _sky_invalidate(sky); } + +Ref RasterizerSceneRD::sky_bake_panorama(RID p_sky, float p_energy, bool p_bake_irradiance, const Size2i &p_size) { + + Sky *sky = sky_owner.getornull(p_sky); + ERR_FAIL_COND_V(!sky, Ref()); + + _update_dirty_skys(); + + if (sky->radiance.is_valid()) { + + RD::TextureFormat tf; + tf.format = RD::DATA_FORMAT_R32G32B32A32_SFLOAT; + tf.width = p_size.width; + tf.height = p_size.height; + tf.usage_bits = RD::TEXTURE_USAGE_STORAGE_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT; + + RID rad_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + storage->get_effects()->copy_cubemap_to_panorama(sky->radiance, rad_tex, p_size, p_bake_irradiance ? roughness_layers : 0, sky->reflection.layers.size() > 1); + Vector data = RD::get_singleton()->texture_get_data(rad_tex, 0); + RD::get_singleton()->free(rad_tex); + + Ref img; + img.instance(); + img->create(p_size.width, p_size.height, false, Image::FORMAT_RGBAF, data); + for (int i = 0; i < p_size.width; i++) { + for (int j = 0; j < p_size.height; j++) { + Color c = img->get_pixel(i, j); + c.r *= p_energy; + c.g *= p_energy; + c.b *= p_energy; + img->set_pixel(i, j, c); + } + } + return img; + } + + return Ref(); +} + void RasterizerSceneRD::_update_dirty_skys() { Sky *sky = dirty_sky_list; @@ -1336,6 +1376,43 @@ bool RasterizerSceneRD::is_environment(RID p_env) const { return environment_owner.owns(p_env); } +Ref RasterizerSceneRD::environment_bake_panorama(RID p_env, bool p_bake_irradiance, const Size2i &p_size) { + Environent *env = environment_owner.getornull(p_env); + ERR_FAIL_COND_V(!env, Ref()); + + if (env->background == RS::ENV_BG_CAMERA_FEED || env->background == RS::ENV_BG_CANVAS || env->background == RS::ENV_BG_KEEP) { + return Ref(); //nothing to bake + } + + if (env->background == RS::ENV_BG_CLEAR_COLOR || env->background == RS::ENV_BG_COLOR) { + Color color; + if (env->background == RS::ENV_BG_CLEAR_COLOR) { + color = storage->get_default_clear_color(); + } else { + color = env->bg_color; + } + color.r *= env->bg_energy; + color.g *= env->bg_energy; + color.b *= env->bg_energy; + + Ref ret; + ret.instance(); + ret->create(p_size.width, p_size.height, false, Image::FORMAT_RGBAF); + for (int i = 0; i < p_size.width; i++) { + for (int j = 0; j < p_size.height; j++) { + ret->set_pixel(i, j, color); + } + } + return ret; + } + + if (env->background == RS::ENV_BG_SKY && env->sky.is_valid()) { + return sky_bake_panorama(env->sky, env->bg_energy, p_bake_irradiance, p_size); + } + + return Ref(); +} + //////////////////////////////////////////////////////////// RID RasterizerSceneRD::reflection_atlas_create() { @@ -3741,7 +3818,7 @@ RasterizerSceneRD::RenderBufferData *RasterizerSceneRD::render_buffers_get_data( return rb->data; } -void RasterizerSceneRD::render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass) { +void RasterizerSceneRD::render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass) { Color clear_color; if (p_render_buffers.is_valid()) { @@ -3752,7 +3829,7 @@ void RasterizerSceneRD::render_scene(RID p_render_buffers, const Transform &p_ca clear_color = storage->get_default_clear_color(); } - _render_scene(p_render_buffers, p_cam_transform, p_cam_projection, p_cam_ortogonal, p_cull_result, p_cull_count, p_light_cull_result, p_light_cull_count, p_reflection_probe_cull_result, p_reflection_probe_cull_count, p_gi_probe_cull_result, p_gi_probe_cull_count, p_decal_cull_result, p_decal_cull_count, p_environment, p_camera_effects, p_shadow_atlas, p_reflection_atlas, p_reflection_probe, p_reflection_probe_pass, clear_color); + _render_scene(p_render_buffers, p_cam_transform, p_cam_projection, p_cam_ortogonal, p_cull_result, p_cull_count, p_light_cull_result, p_light_cull_count, p_reflection_probe_cull_result, p_reflection_probe_cull_count, p_gi_probe_cull_result, p_gi_probe_cull_count, p_decal_cull_result, p_decal_cull_count, p_lightmap_cull_result, p_lightmap_cull_count, p_environment, p_camera_effects, p_shadow_atlas, p_reflection_atlas, p_reflection_probe, p_reflection_probe_pass, clear_color); if (p_render_buffers.is_valid()) { RENDER_TIMESTAMP("Tonemap"); @@ -4079,6 +4156,98 @@ float RasterizerSceneRD::screen_space_roughness_limiter_get_curve() const { return screen_space_roughness_limiter_curve; } +TypedArray RasterizerSceneRD::bake_render_uv2(RID p_base, const Vector &p_material_overrides, const Size2i &p_image_size) { + + RD::TextureFormat tf; + tf.format = RD::DATA_FORMAT_R8G8B8A8_UNORM; + tf.width = p_image_size.width; // Always 64x64 + tf.height = p_image_size.height; + tf.usage_bits = RD::TEXTURE_USAGE_COLOR_ATTACHMENT_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT; + + RID albedo_alpha_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + RID normal_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + RID orm_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + + tf.format = RD::DATA_FORMAT_R16G16B16A16_SFLOAT; + RID emission_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + + tf.format = RD::DATA_FORMAT_R32_SFLOAT; + RID depth_write_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + + tf.usage_bits = RD::TEXTURE_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT; + tf.format = RD::get_singleton()->texture_is_format_supported_for_usage(RD::DATA_FORMAT_D32_SFLOAT, RD::TEXTURE_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT) ? RD::DATA_FORMAT_D32_SFLOAT : RD::DATA_FORMAT_X8_D24_UNORM_PACK32; + RID depth_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + + Vector fb_tex; + fb_tex.push_back(albedo_alpha_tex); + fb_tex.push_back(normal_tex); + fb_tex.push_back(orm_tex); + fb_tex.push_back(emission_tex); + fb_tex.push_back(depth_write_tex); + fb_tex.push_back(depth_tex); + + RID fb = RD::get_singleton()->framebuffer_create(fb_tex); + + //RID sampled_light; + + InstanceBase ins; + + ins.base_type = RSG::storage->get_base_type(p_base); + ins.base = p_base; + ins.materials.resize(RSG::storage->mesh_get_surface_count(p_base)); + for (int i = 0; i < ins.materials.size(); i++) { + if (i < p_material_overrides.size()) { + ins.materials.write[i] = p_material_overrides[i]; + } + } + + InstanceBase *cull = &ins; + _render_uv2(&cull, 1, fb, Rect2i(0, 0, p_image_size.width, p_image_size.height)); + + TypedArray ret; + + { + PackedByteArray data = RD::get_singleton()->texture_get_data(albedo_alpha_tex, 0); + Ref img; + img.instance(); + img->create(p_image_size.width, p_image_size.height, false, Image::FORMAT_RGBA8, data); + RD::get_singleton()->free(albedo_alpha_tex); + ret.push_back(img); + } + + { + PackedByteArray data = RD::get_singleton()->texture_get_data(normal_tex, 0); + Ref img; + img.instance(); + img->create(p_image_size.width, p_image_size.height, false, Image::FORMAT_RGBA8, data); + RD::get_singleton()->free(normal_tex); + ret.push_back(img); + } + + { + PackedByteArray data = RD::get_singleton()->texture_get_data(orm_tex, 0); + Ref img; + img.instance(); + img->create(p_image_size.width, p_image_size.height, false, Image::FORMAT_RGBA8, data); + RD::get_singleton()->free(orm_tex); + ret.push_back(img); + } + + { + PackedByteArray data = RD::get_singleton()->texture_get_data(emission_tex, 0); + Ref img; + img.instance(); + img->create(p_image_size.width, p_image_size.height, false, Image::FORMAT_RGBAH, data); + RD::get_singleton()->free(emission_tex); + ret.push_back(img); + } + + RD::get_singleton()->free(depth_write_tex); + RD::get_singleton()->free(depth_tex); + + return ret; +} + RasterizerSceneRD *RasterizerSceneRD::singleton = nullptr; RasterizerSceneRD::RasterizerSceneRD(RasterizerStorageRD *p_storage) { diff --git a/servers/rendering/rasterizer_rd/rasterizer_scene_rd.h b/servers/rendering/rasterizer_rd/rasterizer_scene_rd.h index a511838e16..5aaa15f441 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_scene_rd.h +++ b/servers/rendering/rasterizer_rd/rasterizer_scene_rd.h @@ -80,9 +80,10 @@ protected: }; virtual RenderBufferData *_create_render_buffer_data() = 0; - virtual void _render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_color) = 0; + virtual void _render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_color) = 0; virtual void _render_shadow(RID p_framebuffer, InstanceBase **p_cull_result, int p_cull_count, const CameraMatrix &p_projection, const Transform &p_transform, float p_zfar, float p_bias, float p_normal_bias, bool p_use_dp, bool use_dp_flip, bool p_use_pancake) = 0; virtual void _render_material(const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region) = 0; + virtual void _render_uv2(InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region) = 0; virtual void _debug_giprobe(RID p_gi_probe, RenderingDevice::DrawListID p_draw_list, RID p_framebuffer, const CameraMatrix &p_camera_with_transform, bool p_lighting, bool p_emission, float p_alpha); @@ -843,6 +844,7 @@ public: void sky_set_radiance_size(RID p_sky, int p_radiance_size); void sky_set_mode(RID p_sky, RS::SkyMode p_mode); void sky_set_material(RID p_sky, RID p_material); + Ref sky_bake_panorama(RID p_sky, float p_energy, bool p_bake_irradiance, const Size2i &p_size); RID sky_get_radiance_texture_rd(RID p_sky) const; RID sky_get_radiance_uniform_set_rd(RID p_sky, RID p_shader, int p_set) const; @@ -900,6 +902,8 @@ public: void environment_set_fog_depth(RID p_env, bool p_enable, float p_depth_begin, float p_depth_end, float p_depth_curve, bool p_transmit, float p_transmit_curve) {} void environment_set_fog_height(RID p_env, bool p_enable, float p_min_height, float p_max_height, float p_height_curve) {} + virtual Ref environment_bake_panorama(RID p_env, bool p_bake_irradiance, const Size2i &p_size); + virtual RID camera_effects_create(); virtual void camera_effects_set_dof_blur_quality(RS::DOFBlurQuality p_quality, bool p_use_jitter); @@ -1194,7 +1198,7 @@ public: RID render_buffers_get_ao_texture(RID p_render_buffers); RID render_buffers_get_back_buffer_texture(RID p_render_buffers); - void render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_shadow_atlas, RID p_camera_effects, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass); + void render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_shadow_atlas, RID p_camera_effects, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass); void render_shadow(RID p_light, RID p_shadow_atlas, int p_pass, InstanceBase **p_cull_result, int p_cull_count); @@ -1235,6 +1239,8 @@ public: int get_roughness_layers() const; bool is_using_radiance_cubemap_array() const; + virtual TypedArray bake_render_uv2(RID p_base, const Vector &p_material_overrides, const Size2i &p_image_size); + virtual bool free(RID p_rid); virtual void update(); diff --git a/servers/rendering/rasterizer_rd/rasterizer_storage_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_storage_rd.cpp index 8d299d623a..0203293a76 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_storage_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_storage_rd.cpp @@ -610,7 +610,113 @@ RID RasterizerStorageRD::texture_2d_create(const Ref &p_image) { RID RasterizerStorageRD::texture_2d_layered_create(const Vector> &p_layers, RS::TextureLayeredType p_layered_type) { - return RID(); + ERR_FAIL_COND_V(p_layers.size() == 0, RID()); + + ERR_FAIL_COND_V(p_layered_type == RS::TEXTURE_LAYERED_CUBEMAP && p_layers.size() != 6, RID()); + ERR_FAIL_COND_V(p_layered_type == RS::TEXTURE_LAYERED_CUBEMAP_ARRAY && (p_layers.size() < 6 || (p_layers.size() % 6) != 0), RID()); + + TextureToRDFormat ret_format; + Vector> images; + { + int valid_width = 0; + int valid_height = 0; + bool valid_mipmaps = false; + Image::Format valid_format = Image::FORMAT_MAX; + + for (int i = 0; i < p_layers.size(); i++) { + ERR_FAIL_COND_V(p_layers[i]->empty(), RID()); + + if (i == 0) { + valid_width = p_layers[i]->get_width(); + valid_height = p_layers[i]->get_height(); + valid_format = p_layers[i]->get_format(); + valid_mipmaps = p_layers[i]->has_mipmaps(); + } else { + ERR_FAIL_COND_V(p_layers[i]->get_width() != valid_width, RID()); + ERR_FAIL_COND_V(p_layers[i]->get_height() != valid_height, RID()); + ERR_FAIL_COND_V(p_layers[i]->get_format() != valid_format, RID()); + ERR_FAIL_COND_V(p_layers[i]->has_mipmaps() != valid_mipmaps, RID()); + } + + images.push_back(_validate_texture_format(p_layers[i], ret_format)); + } + } + + Texture texture; + + texture.type = Texture::TYPE_LAYERED; + texture.layered_type = p_layered_type; + + texture.width = p_layers[0]->get_width(); + texture.height = p_layers[0]->get_height(); + texture.layers = p_layers.size(); + texture.mipmaps = p_layers[0]->get_mipmap_count() + 1; + texture.depth = 1; + texture.format = p_layers[0]->get_format(); + texture.validated_format = images[0]->get_format(); + + switch (p_layered_type) { + case RS::TEXTURE_LAYERED_2D_ARRAY: { + texture.rd_type = RD::TEXTURE_TYPE_2D_ARRAY; + } break; + case RS::TEXTURE_LAYERED_CUBEMAP: { + texture.rd_type = RD::TEXTURE_TYPE_CUBE; + } break; + case RS::TEXTURE_LAYERED_CUBEMAP_ARRAY: { + texture.rd_type = RD::TEXTURE_TYPE_CUBE_ARRAY; + } break; + } + + texture.rd_format = ret_format.format; + texture.rd_format_srgb = ret_format.format_srgb; + + RD::TextureFormat rd_format; + RD::TextureView rd_view; + { //attempt register + rd_format.format = texture.rd_format; + rd_format.width = texture.width; + rd_format.height = texture.height; + rd_format.depth = 1; + rd_format.array_layers = texture.layers; + rd_format.mipmaps = texture.mipmaps; + rd_format.type = texture.rd_type; + rd_format.samples = RD::TEXTURE_SAMPLES_1; + rd_format.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT; + if (texture.rd_format_srgb != RD::DATA_FORMAT_MAX) { + rd_format.shareable_formats.push_back(texture.rd_format); + rd_format.shareable_formats.push_back(texture.rd_format_srgb); + } + } + { + rd_view.swizzle_r = ret_format.swizzle_r; + rd_view.swizzle_g = ret_format.swizzle_g; + rd_view.swizzle_b = ret_format.swizzle_b; + rd_view.swizzle_a = ret_format.swizzle_a; + } + Vector> data_slices; + for (int i = 0; i < images.size(); i++) { + Vector data = images[i]->get_data(); //use image data + data_slices.push_back(data); + } + texture.rd_texture = RD::get_singleton()->texture_create(rd_format, rd_view, data_slices); + ERR_FAIL_COND_V(texture.rd_texture.is_null(), RID()); + if (texture.rd_format_srgb != RD::DATA_FORMAT_MAX) { + rd_view.format_override = texture.rd_format_srgb; + texture.rd_texture_srgb = RD::get_singleton()->texture_create_shared(rd_view, texture.rd_texture); + if (texture.rd_texture_srgb.is_null()) { + RD::get_singleton()->free(texture.rd_texture); + ERR_FAIL_COND_V(texture.rd_texture_srgb.is_null(), RID()); + } + } + + //used for 2D, overridable + texture.width_2d = texture.width; + texture.height_2d = texture.height; + texture.is_render_target = false; + texture.rd_view = rd_view; + texture.is_proxy = false; + + return texture_owner.make_rid(texture); } RID RasterizerStorageRD::texture_3d_create(const Vector> &p_slices) { @@ -729,9 +835,31 @@ RID RasterizerStorageRD::texture_2d_placeholder_create() { return texture_2d_create(image); } -RID RasterizerStorageRD::texture_2d_layered_placeholder_create() { +RID RasterizerStorageRD::texture_2d_layered_placeholder_create(RS::TextureLayeredType p_layered_type) { - return RID(); + //this could be better optimized to reuse an existing image , done this way + //for now to get it working + Ref image; + image.instance(); + image->create(4, 4, false, Image::FORMAT_RGBA8); + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + image->set_pixel(i, j, Color(1, 0, 1, 1)); + } + } + + Vector> images; + if (p_layered_type == RS::TEXTURE_LAYERED_2D_ARRAY) { + images.push_back(image); + } else { + //cube + for (int i = 0; i < 6; i++) { + images.push_back(image); + } + } + + return texture_2d_layered_create(images, p_layered_type); } RID RasterizerStorageRD::texture_3d_placeholder_create() { @@ -4139,6 +4267,180 @@ RID RasterizerStorageRD::gi_probe_get_sdf_texture(RID p_gi_probe) { return gi_probe->sdf_texture; } +/* LIGHTMAP API */ + +RID RasterizerStorageRD::lightmap_create() { + return lightmap_owner.make_rid(Lightmap()); +} + +void RasterizerStorageRD::lightmap_set_textures(RID p_lightmap, RID p_light, bool p_uses_spherical_haromics) { + + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND(!lm); + + lightmap_array_version++; + + //erase lightmap users + if (lm->light_texture.is_valid()) { + Texture *t = texture_owner.getornull(lm->light_texture); + if (t) { + t->lightmap_users.erase(p_lightmap); + } + } + + Texture *t = texture_owner.getornull(p_light); + lm->light_texture = p_light; + lm->uses_spherical_harmonics = p_uses_spherical_haromics; + + RID default_2d_array = default_rd_textures[DEFAULT_RD_TEXTURE_2D_ARRAY_WHITE]; + if (!t) { + + if (using_lightmap_array) { + if (lm->array_index >= 0) { + lightmap_textures.write[lm->array_index] = default_2d_array; + lm->array_index = -1; + } + } + + return; + } + + t->lightmap_users.insert(p_lightmap); + + if (using_lightmap_array) { + if (lm->array_index < 0) { + //not in array, try to put in array + for (int i = 0; i < lightmap_textures.size(); i++) { + if (lightmap_textures[i] == default_2d_array) { + lm->array_index = i; + break; + } + } + } + ERR_FAIL_COND_MSG(lm->array_index < 0, "Maximum amount of lightmaps in use (" + itos(lightmap_textures.size()) + ") has been exceeded, lightmap will nod display properly."); + + lightmap_textures.write[lm->array_index] = t->rd_texture; + } +} + +void RasterizerStorageRD::lightmap_set_probe_bounds(RID p_lightmap, const AABB &p_bounds) { + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND(!lm); + lm->bounds = p_bounds; +} + +void RasterizerStorageRD::lightmap_set_probe_interior(RID p_lightmap, bool p_interior) { + + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND(!lm); + lm->interior = p_interior; +} + +void RasterizerStorageRD::lightmap_set_probe_capture_data(RID p_lightmap, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree) { + + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND(!lm); + + if (p_points.size()) { + ERR_FAIL_COND(p_points.size() * 9 != p_point_sh.size()); + ERR_FAIL_COND((p_tetrahedra.size() % 4) != 0); + ERR_FAIL_COND((p_bsp_tree.size() % 6) != 0); + } + + lm->points = p_points; + lm->bsp_tree = p_bsp_tree; + lm->point_sh = p_point_sh; + lm->tetrahedra = p_tetrahedra; +} + +PackedVector3Array RasterizerStorageRD::lightmap_get_probe_capture_points(RID p_lightmap) const { + + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, PackedVector3Array()); + + return lm->points; +} +PackedColorArray RasterizerStorageRD::lightmap_get_probe_capture_sh(RID p_lightmap) const { + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, PackedColorArray()); + return lm->point_sh; +} +PackedInt32Array RasterizerStorageRD::lightmap_get_probe_capture_tetrahedra(RID p_lightmap) const { + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, PackedInt32Array()); + return lm->tetrahedra; +} +PackedInt32Array RasterizerStorageRD::lightmap_get_probe_capture_bsp_tree(RID p_lightmap) const { + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, PackedInt32Array()); + return lm->bsp_tree; +} + +void RasterizerStorageRD::lightmap_set_probe_capture_update_speed(float p_speed) { + lightmap_probe_capture_update_speed = p_speed; +} + +void RasterizerStorageRD::lightmap_tap_sh_light(RID p_lightmap, const Vector3 &p_point, Color *r_sh) { + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND(!lm); + + for (int i = 0; i < 9; i++) { + r_sh[i] = Color(0, 0, 0, 0); + } + + if (!lm->points.size() || !lm->bsp_tree.size() || !lm->tetrahedra.size()) { + return; + } + + static_assert(sizeof(Lightmap::BSP) == 24); + + const Lightmap::BSP *bsp = (const Lightmap::BSP *)lm->bsp_tree.ptr(); + int32_t node = 0; + while (node >= 0) { + + if (Plane(bsp[node].plane[0], bsp[node].plane[1], bsp[node].plane[2], bsp[node].plane[3]).is_point_over(p_point)) { +#ifdef DEBUG_ENABLED + ERR_FAIL_COND(bsp[node].over >= 0 && bsp[node].over < node); +#endif + + node = bsp[node].over; + } else { +#ifdef DEBUG_ENABLED + ERR_FAIL_COND(bsp[node].under >= 0 && bsp[node].under < node); +#endif + node = bsp[node].under; + } + } + + if (node == Lightmap::BSP::EMPTY_LEAF) { + return; //nothing could be done + } + + node = ABS(node) - 1; + + uint32_t *tetrahedron = (uint32_t *)&lm->tetrahedra[node * 4]; + Vector3 points[4] = { lm->points[tetrahedron[0]], lm->points[tetrahedron[1]], lm->points[tetrahedron[2]], lm->points[tetrahedron[3]] }; + const Color *sh_colors[4]{ &lm->point_sh[tetrahedron[0] * 9], &lm->point_sh[tetrahedron[1] * 9], &lm->point_sh[tetrahedron[2] * 9], &lm->point_sh[tetrahedron[3] * 9] }; + Color barycentric = Geometry::tetrahedron_get_barycentric_coords(points[0], points[1], points[2], points[3], p_point); + + for (int i = 0; i < 4; i++) { + float c = CLAMP(barycentric[i], 0.0, 1.0); + for (int j = 0; j < 9; j++) { + r_sh[j] += sh_colors[i][j] * c; + } + } +} + +bool RasterizerStorageRD::lightmap_is_interior(RID p_lightmap) const { + const Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, false); + return lm->interior; +} +AABB RasterizerStorageRD::lightmap_get_aabb(RID p_lightmap) const { + const Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, AABB()); + return lm->bounds; +} /* RENDER TARGET API */ @@ -4491,6 +4793,9 @@ void RasterizerStorageRD::base_update_dependency(RID p_base, RasterizerScene::In } else if (gi_probe_owner.owns(p_base)) { GIProbe *gip = gi_probe_owner.getornull(p_base); p_instance->update_dependency(&gip->instance_dependency); + } else if (lightmap_owner.owns(p_base)) { + Lightmap *lm = lightmap_owner.getornull(p_base); + p_instance->update_dependency(&lm->instance_dependency); } else if (light_owner.owns(p_base)) { Light *l = light_owner.getornull(p_base); p_instance->update_dependency(&l->instance_dependency); @@ -4525,6 +4830,9 @@ RS::InstanceType RasterizerStorageRD::get_base_type(RID p_rid) const { if (light_owner.owns(p_rid)) { return RS::INSTANCE_LIGHT; } + if (lightmap_owner.owns(p_rid)) { + return RS::INSTANCE_LIGHTMAP; + } return RS::INSTANCE_NONE; } @@ -4678,7 +4986,7 @@ void RasterizerStorageRD::_update_decal_atlas() { DecalAtlas::Texture *t = decal_atlas.textures.getptr(items[i].texture); t->uv_rect.position = items[i].pos * border + Vector2i(border / 2, border / 2); t->uv_rect.size = items[i].pixel_size; - //print_line("blitrect: " + t->uv_rect); + t->uv_rect.position /= Size2(decal_atlas.size); t->uv_rect.size /= Size2(decal_atlas.size); } @@ -5563,6 +5871,11 @@ bool RasterizerStorageRD::free(RID p_rid) { GIProbe *gi_probe = gi_probe_owner.getornull(p_rid); gi_probe->instance_dependency.instance_notify_deleted(p_rid); gi_probe_owner.free(p_rid); + } else if (lightmap_owner.owns(p_rid)) { + lightmap_set_textures(p_rid, RID(), false); + Lightmap *lightmap = lightmap_owner.getornull(p_rid); + lightmap->instance_dependency.instance_notify_deleted(p_rid); + lightmap_owner.free(p_rid); } else if (light_owner.owns(p_rid)) { @@ -5801,6 +6114,32 @@ RasterizerStorageRD::RasterizerStorageRD() { } } + { //create default array + + RD::TextureFormat tformat; + tformat.format = RD::DATA_FORMAT_R8G8B8A8_UNORM; + tformat.width = 4; + tformat.height = 4; + tformat.array_layers = 1; + tformat.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT; + tformat.type = RD::TEXTURE_TYPE_2D_ARRAY; + + Vector pv; + pv.resize(16 * 4); + for (int i = 0; i < 16; i++) { + pv.set(i * 4 + 0, 255); + pv.set(i * 4 + 1, 255); + pv.set(i * 4 + 2, 255); + pv.set(i * 4 + 3, 255); + } + + { + Vector> vpv; + vpv.push_back(pv); + default_rd_textures[DEFAULT_RD_TEXTURE_2D_ARRAY_WHITE] = RD::get_singleton()->texture_create(tformat, RD::TextureView(), vpv); + } + } + //default samplers for (int i = 1; i < RS::CANVAS_ITEM_TEXTURE_FILTER_MAX; i++) { for (int j = 1; j < RS::CANVAS_ITEM_TEXTURE_REPEAT_MAX; j++) { @@ -5872,124 +6211,133 @@ RasterizerStorageRD::RasterizerStorageRD() { //default rd buffers { - //vertex + Vector buffer; { - Vector buffer; + buffer.resize(sizeof(float) * 3); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 0.0; + fptr[1] = 0.0; + fptr[2] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_VERTEX] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } - buffer.resize(sizeof(float) * 3); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 0.0; - fptr[1] = 0.0; - fptr[2] = 0.0; - } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_VERTEX] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} + { //normal + buffer.resize(sizeof(float) * 3); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 1.0; + fptr[1] = 0.0; + fptr[2] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_NORMAL] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } -{ //normal - Vector buffer; - buffer.resize(sizeof(float) * 3); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 1.0; - fptr[1] = 0.0; - fptr[2] = 0.0; - } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_NORMAL] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} + { //tangent + buffer.resize(sizeof(float) * 4); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 1.0; + fptr[1] = 0.0; + fptr[2] = 0.0; + fptr[3] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TANGENT] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } -{ //tangent - Vector buffer; - buffer.resize(sizeof(float) * 4); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 1.0; - fptr[1] = 0.0; - fptr[2] = 0.0; - fptr[3] = 0.0; - } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TANGENT] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} + { //color + buffer.resize(sizeof(float) * 4); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 1.0; + fptr[1] = 1.0; + fptr[2] = 1.0; + fptr[3] = 1.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_COLOR] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } -{ //color - Vector buffer; - buffer.resize(sizeof(float) * 4); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 1.0; - fptr[1] = 1.0; - fptr[2] = 1.0; - fptr[3] = 1.0; - } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_COLOR] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} + { //tex uv 1 + buffer.resize(sizeof(float) * 2); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 0.0; + fptr[1] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TEX_UV] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } + { //tex uv 2 + buffer.resize(sizeof(float) * 2); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 0.0; + fptr[1] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TEX_UV2] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } -{ //tex uv 1 - Vector buffer; - buffer.resize(sizeof(float) * 2); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 0.0; - fptr[1] = 0.0; - } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TEX_UV] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} -{ //tex uv 2 - Vector buffer; - buffer.resize(sizeof(float) * 2); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 0.0; - fptr[1] = 0.0; - } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TEX_UV2] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} + { //bones + buffer.resize(sizeof(uint32_t) * 4); + { + uint8_t *w = buffer.ptrw(); + uint32_t *fptr = (uint32_t *)w; + fptr[0] = 0; + fptr[1] = 0; + fptr[2] = 0; + fptr[3] = 0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_BONES] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } -{ //bones - Vector buffer; - buffer.resize(sizeof(uint32_t) * 4); - { - uint8_t *w = buffer.ptrw(); - uint32_t *fptr = (uint32_t *)w; - fptr[0] = 0; - fptr[1] = 0; - fptr[2] = 0; - fptr[3] = 0; + { //weights + buffer.resize(sizeof(float) * 4); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 0.0; + fptr[1] = 0.0; + fptr[2] = 0.0; + fptr[3] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_WEIGHTS] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_BONES] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} -{ //weights - Vector buffer; - buffer.resize(sizeof(float) * 4); { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 0.0; - fptr[1] = 0.0; - fptr[2] = 0.0; - fptr[3] = 0.0; + Vector sdf_versions; + sdf_versions.push_back(""); //one only + giprobe_sdf_shader.initialize(sdf_versions); + giprobe_sdf_shader_version = giprobe_sdf_shader.version_create(); + giprobe_sdf_shader.version_set_compute_code(giprobe_sdf_shader_version, "", "", "", Vector()); + giprobe_sdf_shader_version_shader = giprobe_sdf_shader.version_get_shader(giprobe_sdf_shader_version, 0); + giprobe_sdf_shader_pipeline = RD::get_singleton()->compute_pipeline_create(giprobe_sdf_shader_version_shader); } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_WEIGHTS] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} -} -{ - Vector sdf_versions; - sdf_versions.push_back(""); //one only - giprobe_sdf_shader.initialize(sdf_versions); - giprobe_sdf_shader_version = giprobe_sdf_shader.version_create(); - giprobe_sdf_shader.version_set_compute_code(giprobe_sdf_shader_version, "", "", "", Vector()); - giprobe_sdf_shader_version_shader = giprobe_sdf_shader.version_get_shader(giprobe_sdf_shader_version, 0); - giprobe_sdf_shader_pipeline = RD::get_singleton()->compute_pipeline_create(giprobe_sdf_shader_version_shader); -} + using_lightmap_array = true; // high end + if (using_lightmap_array) { + + uint32_t textures_per_stage = RD::get_singleton()->limit_get(RD::LIMIT_MAX_TEXTURES_PER_SHADER_STAGE); + + if (textures_per_stage <= 256) { + lightmap_textures.resize(32); + } else { + lightmap_textures.resize(1024); + } + + for (int i = 0; i < lightmap_textures.size(); i++) { + lightmap_textures.write[i] = default_rd_textures[DEFAULT_RD_TEXTURE_2D_ARRAY_WHITE]; + } + } + + lightmap_probe_capture_update_speed = GLOBAL_GET("rendering/lightmapper/probe_capture_update_speed"); } RasterizerStorageRD::~RasterizerStorageRD() { diff --git a/servers/rendering/rasterizer_rd/rasterizer_storage_rd.h b/servers/rendering/rasterizer_rd/rasterizer_storage_rd.h index f874c3baf8..94b373247f 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_storage_rd.h +++ b/servers/rendering/rasterizer_rd/rasterizer_storage_rd.h @@ -92,6 +92,7 @@ public: DEFAULT_RD_TEXTURE_CUBEMAP_BLACK, DEFAULT_RD_TEXTURE_CUBEMAP_ARRAY_BLACK, DEFAULT_RD_TEXTURE_3D_WHITE, + DEFAULT_RD_TEXTURE_2D_ARRAY_WHITE, DEFAULT_RD_TEXTURE_MAX }; @@ -118,6 +119,7 @@ private: }; Type type; + RS::TextureLayeredType layered_type = RS::TEXTURE_LAYERED_2D_ARRAY; RenderingDevice::TextureType rd_type; RID rd_texture; @@ -147,6 +149,7 @@ private: RID proxy_to; Vector proxies; + Set lightmap_users; RS::TextureDetectCallback detect_3d_callback = nullptr; void *detect_3d_callback_ud = nullptr; @@ -524,6 +527,40 @@ private: mutable RID_Owner gi_probe_owner; + /* REFLECTION PROBE */ + + struct Lightmap { + + RID light_texture; + bool uses_spherical_harmonics = false; + bool interior = false; + AABB bounds = AABB(Vector3(), Vector3(1, 1, 1)); + int32_t array_index = -1; //unassigned + PackedVector3Array points; + PackedColorArray point_sh; + PackedInt32Array tetrahedra; + PackedInt32Array bsp_tree; + + struct BSP { + static const int32_t EMPTY_LEAF = INT32_MIN; + float plane[4]; + int32_t over = EMPTY_LEAF, under = EMPTY_LEAF; + }; + + RasterizerScene::InstanceDependency instance_dependency; + }; + + bool using_lightmap_array; //high end uses this + /* for high end */ + + Vector lightmap_textures; + + uint64_t lightmap_array_version = 0; + + mutable RID_Owner lightmap_owner; + + float lightmap_probe_capture_update_speed = 4; + /* RENDER TARGET */ struct RenderTarget { @@ -653,7 +690,7 @@ public: //these two APIs can be used together or in combination with the others. virtual RID texture_2d_placeholder_create(); - virtual RID texture_2d_layered_placeholder_create(); + virtual RID texture_2d_layered_placeholder_create(RenderingServer::TextureLayeredType p_layered_type); virtual RID texture_3d_placeholder_create(); virtual Ref texture_2d_get(RID p_texture) const; @@ -1270,23 +1307,47 @@ public: /* LIGHTMAP CAPTURE */ - void lightmap_capture_set_bounds(RID p_capture, const AABB &p_bounds) {} - AABB lightmap_capture_get_bounds(RID p_capture) const { return AABB(); } - void lightmap_capture_set_octree(RID p_capture, const Vector &p_octree) {} - RID lightmap_capture_create() { - return RID(); + virtual RID lightmap_create(); + + virtual void lightmap_set_textures(RID p_lightmap, RID p_light, bool p_uses_spherical_haromics); + virtual void lightmap_set_probe_bounds(RID p_lightmap, const AABB &p_bounds); + virtual void lightmap_set_probe_interior(RID p_lightmap, bool p_interior); + virtual void lightmap_set_probe_capture_data(RID p_lightmap, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree); + virtual PackedVector3Array lightmap_get_probe_capture_points(RID p_lightmap) const; + virtual PackedColorArray lightmap_get_probe_capture_sh(RID p_lightmap) const; + virtual PackedInt32Array lightmap_get_probe_capture_tetrahedra(RID p_lightmap) const; + virtual PackedInt32Array lightmap_get_probe_capture_bsp_tree(RID p_lightmap) const; + virtual AABB lightmap_get_aabb(RID p_lightmap) const; + virtual bool lightmap_is_interior(RID p_lightmap) const; + virtual void lightmap_tap_sh_light(RID p_lightmap, const Vector3 &p_point, Color *r_sh); + virtual void lightmap_set_probe_capture_update_speed(float p_speed); + _FORCE_INLINE_ float lightmap_get_probe_capture_update_speed() const { + return lightmap_probe_capture_update_speed; } - Vector lightmap_capture_get_octree(RID p_capture) const { - return Vector(); + + _FORCE_INLINE_ int32_t lightmap_get_array_index(RID p_lightmap) const { + ERR_FAIL_COND_V(!using_lightmap_array, -1); //only for arrays + const Lightmap *lm = lightmap_owner.getornull(p_lightmap); + return lm->array_index; } - void lightmap_capture_set_octree_cell_transform(RID p_capture, const Transform &p_xform) {} - Transform lightmap_capture_get_octree_cell_transform(RID p_capture) const { return Transform(); } - void lightmap_capture_set_octree_cell_subdiv(RID p_capture, int p_subdiv) {} - int lightmap_capture_get_octree_cell_subdiv(RID p_capture) const { return 0; } - void lightmap_capture_set_energy(RID p_capture, float p_energy) {} - float lightmap_capture_get_energy(RID p_capture) const { return 0.0; } - const Vector *lightmap_capture_get_octree_ptr(RID p_capture) const { - return nullptr; + _FORCE_INLINE_ bool lightmap_uses_spherical_harmonics(RID p_lightmap) const { + ERR_FAIL_COND_V(!using_lightmap_array, false); //only for arrays + const Lightmap *lm = lightmap_owner.getornull(p_lightmap); + return lm->uses_spherical_harmonics; + } + _FORCE_INLINE_ uint64_t lightmap_array_get_version() const { + ERR_FAIL_COND_V(!using_lightmap_array, 0); //only for arrays + return lightmap_array_version; + } + + _FORCE_INLINE_ int lightmap_array_get_size() const { + ERR_FAIL_COND_V(!using_lightmap_array, 0); //only for arrays + return lightmap_textures.size(); + } + + _FORCE_INLINE_ const Vector &lightmap_array_get_textures() const { + ERR_FAIL_COND_V(!using_lightmap_array, lightmap_textures); //only for arrays + return lightmap_textures; } /* PARTICLES */ diff --git a/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.cpp b/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.cpp index 2bfdb7fffe..5838936f35 100644 --- a/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.cpp +++ b/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.cpp @@ -31,16 +31,20 @@ #include "render_pipeline_vertex_format_cache_rd.h" #include "core/os/memory.h" -RID RenderPipelineVertexFormatCacheRD::_generate_version(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id) { +RID RenderPipelineVertexFormatCacheRD::_generate_version(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id, bool p_wireframe) { RD::PipelineMultisampleState multisample_state_version = multisample_state; multisample_state_version.sample_count = RD::get_singleton()->framebuffer_format_get_texture_samples(p_framebuffer_format_id); - RID pipeline = RD::get_singleton()->render_pipeline_create(shader, p_framebuffer_format_id, p_vertex_format_id, render_primitive, rasterization_state, multisample_state_version, depth_stencil_state, blend_state, dynamic_state_flags); + RD::PipelineRasterizationState raster_state_version = rasterization_state; + raster_state_version.wireframe = p_wireframe; + + RID pipeline = RD::get_singleton()->render_pipeline_create(shader, p_framebuffer_format_id, p_vertex_format_id, render_primitive, raster_state_version, multisample_state_version, depth_stencil_state, blend_state, dynamic_state_flags); ERR_FAIL_COND_V(pipeline.is_null(), RID()); versions = (Version *)memrealloc(versions, sizeof(Version) * (version_count + 1)); versions[version_count].framebuffer_id = p_framebuffer_format_id; versions[version_count].vertex_id = p_vertex_format_id; + versions[version_count].wireframe = p_wireframe; versions[version_count].pipeline = pipeline; version_count++; return pipeline; diff --git a/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.h b/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.h index ecb1b42b06..a8bfdb5a26 100644 --- a/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.h +++ b/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.h @@ -51,13 +51,14 @@ class RenderPipelineVertexFormatCacheRD { struct Version { RD::VertexFormatID vertex_id; RD::FramebufferFormatID framebuffer_id; + bool wireframe; RID pipeline; }; Version *versions; uint32_t version_count; - RID _generate_version(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id); + RID _generate_version(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id, bool p_wireframe); void _clear(); @@ -65,7 +66,7 @@ public: void setup(RID p_shader, RD::RenderPrimitive p_primitive, const RD::PipelineRasterizationState &p_rasterization_state, RD::PipelineMultisampleState p_multisample, const RD::PipelineDepthStencilState &p_depth_stencil_state, const RD::PipelineColorBlendState &p_blend_state, int p_dynamic_state_flags = 0); void update_shader(RID p_shader); - _FORCE_INLINE_ RID get_render_pipeline(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id) { + _FORCE_INLINE_ RID get_render_pipeline(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id, bool p_wireframe = false) { #ifdef DEBUG_ENABLED ERR_FAIL_COND_V_MSG(shader.is_null(), RID(), "Attempted to use an unused shader variant (shader is null),"); @@ -74,13 +75,13 @@ public: spin_lock.lock(); RID result; for (uint32_t i = 0; i < version_count; i++) { - if (versions[i].vertex_id == p_vertex_format_id && versions[i].framebuffer_id == p_framebuffer_format_id) { + if (versions[i].vertex_id == p_vertex_format_id && versions[i].framebuffer_id == p_framebuffer_format_id && versions[i].wireframe == p_wireframe) { result = versions[i].pipeline; spin_lock.unlock(); return result; } } - result = _generate_version(p_vertex_format_id, p_framebuffer_format_id); + result = _generate_version(p_vertex_format_id, p_framebuffer_format_id, p_wireframe); spin_lock.unlock(); return result; } diff --git a/servers/rendering/rasterizer_rd/shader_compiler_rd.cpp b/servers/rendering/rasterizer_rd/shader_compiler_rd.cpp index 25856c92c7..2ef29e97ff 100644 --- a/servers/rendering/rasterizer_rd/shader_compiler_rd.cpp +++ b/servers/rendering/rasterizer_rd/shader_compiler_rd.cpp @@ -120,8 +120,11 @@ static int _get_datatype_size(SL::DataType p_type) { return 16; case SL::TYPE_SAMPLERCUBE: return 16; + case SL::TYPE_SAMPLERCUBEARRAY: + return 16; case SL::TYPE_STRUCT: return 0; + case SL::TYPE_MAX: { ERR_FAIL_V(0); }; @@ -194,6 +197,8 @@ static int _get_datatype_alignment(SL::DataType p_type) { return 16; case SL::TYPE_SAMPLERCUBE: return 16; + case SL::TYPE_SAMPLERCUBEARRAY: + return 16; case SL::TYPE_STRUCT: return 0; case SL::TYPE_MAX: { diff --git a/servers/rendering/rasterizer_rd/shaders/copy.glsl b/servers/rendering/rasterizer_rd/shaders/copy.glsl index 2d7661f65f..075ee2af22 100644 --- a/servers/rendering/rasterizer_rd/shaders/copy.glsl +++ b/servers/rendering/rasterizer_rd/shaders/copy.glsl @@ -39,7 +39,13 @@ layout(push_constant, binding = 1, std430) uniform Params { } params; +#ifdef MODE_CUBEMAP_ARRAY_TO_PANORAMA +layout(set = 0, binding = 0) uniform samplerCubeArray source_color; +#elif defined(MODE_CUBEMAP_TO_PANORAMA) +layout(set = 0, binding = 0) uniform samplerCube source_color; +#else layout(set = 0, binding = 0) uniform sampler2D source_color; +#endif #ifdef GLOW_USE_AUTO_EXPOSURE layout(set = 1, binding = 0) uniform sampler2D source_auto_exposure; @@ -57,7 +63,7 @@ void main() { // Pixel being shaded ivec2 pos = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(pos, params.section.zw))) { //too large, do nothing + if (any(greaterThanEqual(pos, params.section.zw))) { //too large, do nothing return; } @@ -217,4 +223,25 @@ void main() { imageStore(dest_buffer, pos + params.target, color); #endif + +#if defined(MODE_CUBEMAP_TO_PANORAMA) || defined(MODE_CUBEMAP_ARRAY_TO_PANORAMA) + + const float PI = 3.14159265359; + vec2 uv = vec2(pos) / vec2(params.section.zw); + uv.y = 1.0 - uv.y; + float phi = uv.x * 2.0 * PI; + float theta = uv.y * PI; + + vec3 normal; + normal.x = sin(phi) * sin(theta) * -1.0; + normal.y = cos(theta); + normal.z = cos(phi) * sin(theta) * -1.0; + +#ifdef MODE_CUBEMAP_TO_PANORAMA + vec4 color = textureLod(source_color, normal, params.camera_z_far); //the biggest the lod the least the acne +#else + vec4 color = textureLod(source_color, vec4(normal, params.camera_z_far), 0.0); //the biggest the lod the least the acne +#endif + imageStore(dest_buffer, pos + params.target, color); +#endif } diff --git a/servers/rendering/rasterizer_rd/shaders/scene_high_end.glsl b/servers/rendering/rasterizer_rd/shaders/scene_high_end.glsl index 4eba5d41d8..90e37b3ec4 100644 --- a/servers/rendering/rasterizer_rd/shaders/scene_high_end.glsl +++ b/servers/rendering/rasterizer_rd/shaders/scene_high_end.glsl @@ -22,7 +22,7 @@ layout(location = 3) in vec4 color_attrib; layout(location = 4) in vec2 uv_attrib; -#if defined(UV2_USED) || defined(USE_LIGHTMAP) +#if defined(UV2_USED) || defined(USE_LIGHTMAP) || defined(MODE_RENDER_MATERIAL) layout(location = 5) in vec2 uv2_attrib; #endif @@ -49,7 +49,7 @@ layout(location = 6) out vec3 binormal_interp; #endif #ifdef USE_MATERIAL_UNIFORMS -layout(set = 5, binding = 0, std140) uniform MaterialUniforms{ +layout(set = MATERIAL_UNIFORM_SET, binding = 0, std140) uniform MaterialUniforms{ /* clang-format off */ MATERIAL_UNIFORMS /* clang-format on */ @@ -263,6 +263,14 @@ VERTEX_SHADER_CODE } } #endif + +#ifdef MODE_RENDER_MATERIAL + if (scene_data.material_uv2_mode) { + gl_Position.xy = (uv2_attrib.xy + draw_call.bake_uv2_offset) * 2.0 - 1.0; + gl_Position.z = 0.00001; + gl_Position.w = 1.0; + } +#endif } /* clang-format off */ @@ -315,7 +323,7 @@ layout(location = 8) in float dp_clip; #endif #ifdef USE_MATERIAL_UNIFORMS -layout(set = 5, binding = 0, std140) uniform MaterialUniforms{ +layout(set = MATERIAL_UNIFORM_SET, binding = 0, std140) uniform MaterialUniforms{ /* clang-format off */ MATERIAL_UNIFORMS /* clang-format on */ @@ -1917,42 +1925,96 @@ FRAGMENT_SHADER_CODE #if !defined(MODE_RENDER_DEPTH) && !defined(MODE_UNSHADED) //gi probes - //lightmap +#ifdef USE_LIGHTMAP + //lightmap + if (bool(instances.data[instance_index].flags & INSTANCE_FLAGS_USE_LIGHTMAP_CAPTURE)) { //has lightmap capture + uint index = instances.data[instance_index].gi_offset; + + vec3 wnormal = mat3(scene_data.camera_matrix) * normal; + const float c1 = 0.429043; + const float c2 = 0.511664; + const float c3 = 0.743125; + const float c4 = 0.886227; + const float c5 = 0.247708; + ambient_light += (c1 * lightmap_captures.data[index].sh[8].rgb * (wnormal.x * wnormal.x - wnormal.y * wnormal.y) + + c3 * lightmap_captures.data[index].sh[6].rgb * wnormal.z * wnormal.z + + c4 * lightmap_captures.data[index].sh[0].rgb - + c5 * lightmap_captures.data[index].sh[6].rgb + + 2.0 * c1 * lightmap_captures.data[index].sh[4].rgb * wnormal.x * wnormal.y + + 2.0 * c1 * lightmap_captures.data[index].sh[7].rgb * wnormal.x * wnormal.z + + 2.0 * c1 * lightmap_captures.data[index].sh[5].rgb * wnormal.y * wnormal.z + + 2.0 * c2 * lightmap_captures.data[index].sh[3].rgb * wnormal.x + + 2.0 * c2 * lightmap_captures.data[index].sh[1].rgb * wnormal.y + + 2.0 * c2 * lightmap_captures.data[index].sh[2].rgb * wnormal.z); + + } else if (bool(instances.data[instance_index].flags & INSTANCE_FLAGS_USE_LIGHTMAP)) { // has actual lightmap + bool uses_sh = bool(instances.data[instance_index].flags & INSTANCE_FLAGS_USE_SH_LIGHTMAP); + uint ofs = instances.data[instance_index].gi_offset & 0xFFF; + vec3 uvw; + uvw.xy = uv2 * instances.data[instance_index].lightmap_uv_scale.zw + instances.data[instance_index].lightmap_uv_scale.xy; + uvw.z = float((instances.data[instance_index].gi_offset >> 12) & 0xFF); + + if (uses_sh) { + uvw.z *= 4.0; //SH textures use 4 times more data + vec3 lm_light_l0 = textureLod(sampler2DArray(lightmap_textures[ofs], material_samplers[SAMPLER_LINEAR_CLAMP]), uvw + vec3(0.0, 0.0, 0.0), 0.0).rgb; + vec3 lm_light_l1n1 = textureLod(sampler2DArray(lightmap_textures[ofs], material_samplers[SAMPLER_LINEAR_CLAMP]), uvw + vec3(0.0, 0.0, 1.0), 0.0).rgb; + vec3 lm_light_l1_0 = textureLod(sampler2DArray(lightmap_textures[ofs], material_samplers[SAMPLER_LINEAR_CLAMP]), uvw + vec3(0.0, 0.0, 2.0), 0.0).rgb; + vec3 lm_light_l1p1 = textureLod(sampler2DArray(lightmap_textures[ofs], material_samplers[SAMPLER_LINEAR_CLAMP]), uvw + vec3(0.0, 0.0, 3.0), 0.0).rgb; + + uint idx = instances.data[instance_index].gi_offset >> 20; + vec3 n = normalize(lightmaps.data[idx].normal_xform * normal); + + ambient_light += lm_light_l0 * 0.282095f; + ambient_light += lm_light_l1n1 * 0.32573 * n.y; + ambient_light += lm_light_l1_0 * 0.32573 * n.z; + ambient_light += lm_light_l1p1 * 0.32573 * n.x; + if (metallic > 0.01) { // since the more direct bounced light is lost, we can kind of fake it with this trick + vec3 r = reflect(normalize(-vertex), normal); + specular_light += lm_light_l1n1 * 0.32573 * r.y; + specular_light += lm_light_l1_0 * 0.32573 * r.z; + specular_light += lm_light_l1p1 * 0.32573 * r.x; + } + + } else { + + ambient_light += textureLod(sampler2DArray(lightmap_textures[ofs], material_samplers[SAMPLER_LINEAR_CLAMP]), uvw, 0.0).rgb; + } + } +#endif //lightmap capture #ifdef USE_VOXEL_CONE_TRACING - { // process giprobes + if (bool(instances.data[instance_index].flags & INSTANCE_FLAGS_USE_GIPROBE)) { // process giprobes + uint index1 = instances.data[instance_index].gi_offset & 0xFFFF; - if (index1 != 0xFFFF) { - vec3 ref_vec = normalize(reflect(normalize(vertex), normal)); - //find arbitrary tangent and bitangent, then build a matrix - vec3 v0 = abs(normal.z) < 0.999 ? vec3(0.0, 0.0, 1.0) : vec3(0.0, 1.0, 0.0); - vec3 tangent = normalize(cross(v0, normal)); - vec3 bitangent = normalize(cross(tangent, normal)); - mat3 normal_mat = mat3(tangent, bitangent, normal); + vec3 ref_vec = normalize(reflect(normalize(vertex), normal)); + //find arbitrary tangent and bitangent, then build a matrix + vec3 v0 = abs(normal.z) < 0.999 ? vec3(0.0, 0.0, 1.0) : vec3(0.0, 1.0, 0.0); + vec3 tangent = normalize(cross(v0, normal)); + vec3 bitangent = normalize(cross(tangent, normal)); + mat3 normal_mat = mat3(tangent, bitangent, normal); - vec4 amb_accum = vec4(0.0); - vec4 spec_accum = vec4(0.0); - gi_probe_compute(index1, vertex, normal, ref_vec, normal_mat, roughness * roughness, ambient_light, specular_light, spec_accum, amb_accum); + vec4 amb_accum = vec4(0.0); + vec4 spec_accum = vec4(0.0); + gi_probe_compute(index1, vertex, normal, ref_vec, normal_mat, roughness * roughness, ambient_light, specular_light, spec_accum, amb_accum); - uint index2 = instances.data[instance_index].gi_offset >> 16; + uint index2 = instances.data[instance_index].gi_offset >> 16; - if (index2 != 0xFFFF) { - gi_probe_compute(index2, vertex, normal, ref_vec, normal_mat, roughness * roughness, ambient_light, specular_light, spec_accum, amb_accum); - } - - if (amb_accum.a > 0.0) { - amb_accum.rgb /= amb_accum.a; - } - - if (spec_accum.a > 0.0) { - spec_accum.rgb /= spec_accum.a; - } - - specular_light = spec_accum.rgb; - ambient_light = amb_accum.rgb; + if (index2 != 0xFFFF) { + gi_probe_compute(index2, vertex, normal, ref_vec, normal_mat, roughness * roughness, ambient_light, specular_light, spec_accum, amb_accum); } + + if (amb_accum.a > 0.0) { + amb_accum.rgb /= amb_accum.a; + } + + if (spec_accum.a > 0.0) { + spec_accum.rgb /= spec_accum.a; + } + + specular_light = spec_accum.rgb; + ambient_light = amb_accum.rgb; } #endif @@ -2424,7 +2486,6 @@ FRAGMENT_SHADER_CODE ao_light_affect = mix(1.0, ao, ao_light_affect); specular_light = mix(scene_data.ao_color.rgb, specular_light, ao_light_affect); diffuse_light = mix(scene_data.ao_color.rgb, diffuse_light, ao_light_affect); - #else if (scene_data.ssao_enabled) { diff --git a/servers/rendering/rasterizer_rd/shaders/scene_high_end_inc.glsl b/servers/rendering/rasterizer_rd/shaders/scene_high_end_inc.glsl index ce4fabf9f2..89706b74d6 100644 --- a/servers/rendering/rasterizer_rd/shaders/scene_high_end_inc.glsl +++ b/servers/rendering/rasterizer_rd/shaders/scene_high_end_inc.glsl @@ -3,7 +3,8 @@ layout(push_constant, binding = 0, std430) uniform DrawCall { uint instance_index; - uint pad[3]; //16 bits minimum size + uint pad; //16 bits minimum size + vec2 bake_uv2_offset; //used for bake to uv2, ignored otherwise } draw_call; @@ -77,6 +78,10 @@ layout(set = 0, binding = 3, std140) uniform SceneData { bool roughness_limiter_enabled; vec4 ao_color; + bool material_uv2_mode; + uint pad_material0; + uint pad_material1; + uint pad_material2; #if 0 vec4 ambient_light_color; @@ -115,11 +120,10 @@ layout(set = 0, binding = 3, std140) uniform SceneData { } scene_data; -#define INSTANCE_FLAGS_FORWARD_MASK 0x7 -#define INSTANCE_FLAGS_FORWARD_OMNI_LIGHT_SHIFT 3 -#define INSTANCE_FLAGS_FORWARD_SPOT_LIGHT_SHIFT 6 -#define INSTANCE_FLAGS_FORWARD_DECAL_SHIFT 9 - +#define INSTANCE_FLAGS_USE_LIGHTMAP_CAPTURE (1 << 8) +#define INSTANCE_FLAGS_USE_LIGHTMAP (1 << 9) +#define INSTANCE_FLAGS_USE_SH_LIGHTMAP (1 << 10) +#define INSTANCE_FLAGS_USE_GIPROBE (1 << 11) #define INSTANCE_FLAGS_MULTIMESH (1 << 12) #define INSTANCE_FLAGS_MULTIMESH_FORMAT_2D (1 << 13) #define INSTANCE_FLAGS_MULTIMESH_HAS_COLOR (1 << 14) @@ -135,8 +139,9 @@ struct InstanceData { mat4 normal_transform; uint flags; uint instance_uniforms_ofs; //base offset in global buffer for instance variables - uint gi_offset; //GI information when using lightmapping (VCT or lightmap) + uint gi_offset; //GI information when using lightmapping (VCT or lightmap index) uint layer_mask; + vec4 lightmap_uv_scale; }; layout(set = 0, binding = 4, std430) restrict readonly buffer Instances { @@ -248,12 +253,35 @@ gi_probes; layout(set = 0, binding = 9) uniform texture3D gi_probe_textures[MAX_GI_PROBE_TEXTURES]; +#define LIGHTMAP_FLAG_USE_DIRECTION 1 +#define LIGHTMAP_FLAG_USE_SPECULAR_DIRECTION 2 + +struct Lightmap { + mat3 normal_xform; +}; + +layout(set = 0, binding = 10, std140) restrict readonly buffer Lightmaps { + Lightmap data[]; +} +lightmaps; + +layout(set = 0, binding = 11) uniform texture2DArray lightmap_textures[MAX_LIGHTMAP_TEXTURES]; + +struct LightmapCapture { + vec4 sh[9]; +}; + +layout(set = 0, binding = 12, std140) restrict readonly buffer LightmapCaptures { + LightmapCapture data[]; +} +lightmap_captures; + #define CLUSTER_COUNTER_SHIFT 20 #define CLUSTER_POINTER_MASK ((1 << CLUSTER_COUNTER_SHIFT) - 1) #define CLUSTER_COUNTER_MASK 0xfff -layout(set = 0, binding = 10) uniform texture2D decal_atlas; -layout(set = 0, binding = 11) uniform texture2D decal_atlas_srgb; +layout(set = 0, binding = 13) uniform texture2D decal_atlas; +layout(set = 0, binding = 14) uniform texture2D decal_atlas_srgb; struct DecalData { mat4 xform; //to decal transform @@ -273,21 +301,21 @@ struct DecalData { float normal_fade; }; -layout(set = 0, binding = 12, std430) restrict readonly buffer Decals { +layout(set = 0, binding = 15, std430) restrict readonly buffer Decals { DecalData data[]; } decals; -layout(set = 0, binding = 13) uniform utexture3D cluster_texture; +layout(set = 0, binding = 16) uniform utexture3D cluster_texture; -layout(set = 0, binding = 14, std430) restrict readonly buffer ClusterData { +layout(set = 0, binding = 17, std430) restrict readonly buffer ClusterData { uint indices[]; } cluster_data; -layout(set = 0, binding = 15) uniform texture2D directional_shadow_atlas; +layout(set = 0, binding = 18) uniform texture2D directional_shadow_atlas; -layout(set = 0, binding = 16, std430) restrict readonly buffer GlobalVariableData { +layout(set = 0, binding = 19, std430) restrict readonly buffer GlobalVariableData { vec4 data[]; } global_variables; @@ -312,7 +340,7 @@ layout(set = 2, binding = 0) uniform textureCubeArray reflection_atlas; layout(set = 2, binding = 1) uniform texture2D shadow_atlas; -/* Set 1, Render Buffers */ +/* Set 3, Render Buffers */ layout(set = 3, binding = 0) uniform texture2D depth_buffer; layout(set = 3, binding = 1) uniform texture2D color_buffer; diff --git a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection.glsl b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection.glsl index e3c26c9b72..11a0d85c58 100644 --- a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection.glsl +++ b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection.glsl @@ -68,7 +68,7 @@ void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } diff --git a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_filter.glsl b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_filter.glsl index 1a5dd5ab55..8571d9d6d1 100644 --- a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_filter.glsl +++ b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_filter.glsl @@ -120,7 +120,7 @@ void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } diff --git a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_scale.glsl b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_scale.glsl index cec6c14c76..f2c3230679 100644 --- a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_scale.glsl +++ b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_scale.glsl @@ -34,7 +34,7 @@ void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } //do not filter, SSR will generate arctifacts if this is done diff --git a/servers/rendering/rasterizer_rd/shaders/ssao.glsl b/servers/rendering/rasterizer_rd/shaders/ssao.glsl index c9d7134610..0175e26b85 100644 --- a/servers/rendering/rasterizer_rd/shaders/ssao.glsl +++ b/servers/rendering/rasterizer_rd/shaders/ssao.glsl @@ -212,7 +212,7 @@ float sampleAO(in ivec2 ssC, in vec3 C, in vec3 n_C, in float ssDiskRadius, in f void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } diff --git a/servers/rendering/rasterizer_rd/shaders/ssao_blur.glsl b/servers/rendering/rasterizer_rd/shaders/ssao_blur.glsl index e90c788e08..877e5d50fe 100644 --- a/servers/rendering/rasterizer_rd/shaders/ssao_blur.glsl +++ b/servers/rendering/rasterizer_rd/shaders/ssao_blur.glsl @@ -49,7 +49,7 @@ void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } diff --git a/servers/rendering/rasterizer_rd/shaders/subsurface_scattering.glsl b/servers/rendering/rasterizer_rd/shaders/subsurface_scattering.glsl index 41f8fde3ca..4cb486a499 100644 --- a/servers/rendering/rasterizer_rd/shaders/subsurface_scattering.glsl +++ b/servers/rendering/rasterizer_rd/shaders/subsurface_scattering.glsl @@ -142,7 +142,7 @@ void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } diff --git a/servers/rendering/rendering_device.cpp b/servers/rendering/rendering_device.cpp index a3bb39cd90..aeac6f2eff 100644 --- a/servers/rendering/rendering_device.cpp +++ b/servers/rendering/rendering_device.cpp @@ -147,7 +147,7 @@ Ref RenderingDevice::_shader_compile_from_source(const Ref &p_bytecode) { +RID RenderingDevice::shader_create_from_bytecode(const Ref &p_bytecode) { ERR_FAIL_COND_V(p_bytecode.is_null(), RID()); Vector stage_data; @@ -276,7 +276,7 @@ void RenderingDevice::_bind_methods() { ClassDB::bind_method(D_METHOD("index_array_create", "index_buffer", "index_offset", "index_count"), &RenderingDevice::index_array_create); ClassDB::bind_method(D_METHOD("shader_compile_from_source", "shader_source", "allow_cache"), &RenderingDevice::_shader_compile_from_source, DEFVAL(true)); - ClassDB::bind_method(D_METHOD("shader_create", "shader_data"), &RenderingDevice::_shader_create); + ClassDB::bind_method(D_METHOD("shader_create", "shader_data"), &RenderingDevice::shader_create_from_bytecode); ClassDB::bind_method(D_METHOD("shader_get_vertex_input_attribute_mask", "shader"), &RenderingDevice::shader_get_vertex_input_attribute_mask); ClassDB::bind_method(D_METHOD("uniform_buffer_create", "size_bytes", "data"), &RenderingDevice::uniform_buffer_create, DEFVAL(Vector())); diff --git a/servers/rendering/rendering_device.h b/servers/rendering/rendering_device.h index c76fce5b5c..c7d0a1cdd2 100644 --- a/servers/rendering/rendering_device.h +++ b/servers/rendering/rendering_device.h @@ -596,6 +596,7 @@ public: } }; + RID shader_create_from_bytecode(const Ref &p_bytecode); virtual RID shader_create(const Vector &p_stages) = 0; virtual uint32_t shader_get_vertex_input_attribute_mask(RID p_shader) = 0; @@ -1045,6 +1046,8 @@ public: virtual void submit() = 0; virtual void sync() = 0; + virtual uint64_t get_memory_usage() const = 0; + virtual RenderingDevice *create_local_device() = 0; static RenderingDevice *get_singleton(); @@ -1063,7 +1066,6 @@ protected: RID _vertex_array_create(uint32_t p_vertex_count, VertexFormatID p_vertex_format, const TypedArray &p_src_buffers); Ref _shader_compile_from_source(const Ref &p_source, bool p_allow_cache = true); - RID _shader_create(const Ref &p_bytecode); RID _uniform_set_create(const Array &p_uniforms, RID p_shader, uint32_t p_shader_set); diff --git a/servers/rendering/rendering_device_binds.cpp b/servers/rendering/rendering_device_binds.cpp index 43121e2cb5..91076a538e 100644 --- a/servers/rendering/rendering_device_binds.cpp +++ b/servers/rendering/rendering_device_binds.cpp @@ -30,7 +30,7 @@ #include "rendering_device_binds.h" -Error RDShaderFile::parse_versions_from_text(const String &p_text, OpenIncludeFunction p_include_func, void *p_include_func_userdata) { +Error RDShaderFile::parse_versions_from_text(const String &p_text, const String p_defines, OpenIncludeFunction p_include_func, void *p_include_func_userdata) { Vector lines = p_text.split("\n"); @@ -56,6 +56,9 @@ Error RDShaderFile::parse_versions_from_text(const String &p_text, OpenIncludeFu { String ls = line.strip_edges(); + if (ls.begins_with("#[")) { //workaround for clang format + ls = ls.replace_first("#[", "["); + } if (ls.begins_with("[") && ls.ends_with("]")) { String section = ls.substr(1, ls.length() - 2).strip_edges(); if (section == "versions") { @@ -90,9 +93,17 @@ Error RDShaderFile::parse_versions_from_text(const String &p_text, OpenIncludeFu } } + if (stage == RD::SHADER_STAGE_MAX && line.strip_edges() != "") { + line = line.strip_edges(); + if (line.begins_with("//") || line.begins_with("/*")) { + continue; //assuming comment (single line) + } + } + if (reading_versions) { String l = line.strip_edges(); if (l != "") { + int eqpos = l.find("="); if (eqpos == -1) { base_error = "Version syntax is version=\"\"."; @@ -110,7 +121,7 @@ Error RDShaderFile::parse_versions_from_text(const String &p_text, OpenIncludeFu } define = "\n" + define.substr(1, define.length() - 2).c_unescape() + "\n"; //add newline before and after jsut in case - version_texts[version] = define; + version_texts[version] = define + "\n" + p_defines; } } else { if (stage == RD::SHADER_STAGE_MAX && line.strip_edges() != "") { diff --git a/servers/rendering/rendering_device_binds.h b/servers/rendering/rendering_device_binds.h index 2a5396a3e8..fe8d554594 100644 --- a/servers/rendering/rendering_device_binds.h +++ b/servers/rendering/rendering_device_binds.h @@ -322,8 +322,31 @@ public: return base_error; } + void print_errors(const String &p_file) { + if (base_error != "") { + ERR_PRINT("Error parsing shader '" + p_file + "':\n\n" + base_error); + } else { + for (Map>::Element *E = versions.front(); E; E = E->next()) { + for (int i = 0; i < RD::SHADER_STAGE_MAX; i++) { + String error = E->get()->get_stage_compile_error(RD::ShaderStage(i)); + if (error != String()) { + static const char *stage_str[RD::SHADER_STAGE_MAX] = { + "vertex", + "fragment", + "tesselation_control", + "tesselation_evaluation", + "compute" + }; + + ERR_PRINT("Error parsing shader '" + p_file + "', version '" + String(E->key()) + "', stage '" + stage_str[i] + "':\n\n" + error); + } + } + } + } + } + typedef String (*OpenIncludeFunction)(const String &, void *userdata); - Error parse_versions_from_text(const String &p_text, OpenIncludeFunction p_include_func = nullptr, void *p_include_func_userdata = nullptr); + Error parse_versions_from_text(const String &p_text, const String p_defines = String(), OpenIncludeFunction p_include_func = nullptr, void *p_include_func_userdata = nullptr); protected: Dictionary _get_versions() const { diff --git a/servers/rendering/rendering_server_raster.h b/servers/rendering/rendering_server_raster.h index f7b963a015..5dd146861d 100644 --- a/servers/rendering/rendering_server_raster.h +++ b/servers/rendering/rendering_server_raster.h @@ -108,8 +108,12 @@ public: m_r m_name(m_type1 arg1, m_type2 arg2) { return BINDBASE->m_name(arg1, arg2); } #define BIND2RC(m_r, m_name, m_type1, m_type2) \ m_r m_name(m_type1 arg1, m_type2 arg2) const { return BINDBASE->m_name(arg1, arg2); } +#define BIND3R(m_r, m_name, m_type1, m_type2, m_type3) \ + m_r m_name(m_type1 arg1, m_type2 arg2, m_type3 arg3) { return BINDBASE->m_name(arg1, arg2, arg3); } #define BIND3RC(m_r, m_name, m_type1, m_type2, m_type3) \ m_r m_name(m_type1 arg1, m_type2 arg2, m_type3 arg3) const { return BINDBASE->m_name(arg1, arg2, arg3); } +#define BIND4R(m_r, m_name, m_type1, m_type2, m_type3, m_type4) \ + m_r m_name(m_type1 arg1, m_type2 arg2, m_type3 arg3, m_type4 arg4) { return BINDBASE->m_name(arg1, arg2, arg3, arg4); } #define BIND4RC(m_r, m_name, m_type1, m_type2, m_type3, m_type4) \ m_r m_name(m_type1 arg1, m_type2 arg2, m_type3 arg3, m_type4 arg4) const { return BINDBASE->m_name(arg1, arg2, arg3, arg4); } @@ -170,7 +174,7 @@ public: //these also go pass-through BIND0R(RID, texture_2d_placeholder_create) - BIND0R(RID, texture_2d_layered_placeholder_create) + BIND1R(RID, texture_2d_layered_placeholder_create, TextureLayeredType) BIND0R(RID, texture_3d_placeholder_create) BIND1RC(Ref, texture_2d_get, RID) @@ -404,23 +408,19 @@ public: BIND2(gi_probe_set_anisotropy_strength, RID, float) BIND1RC(float, gi_probe_get_anisotropy_strength, RID) - /* LIGHTMAP CAPTURE */ + /* LIGHTMAP */ - BIND0R(RID, lightmap_capture_create) + BIND0R(RID, lightmap_create) - BIND2(lightmap_capture_set_bounds, RID, const AABB &) - BIND1RC(AABB, lightmap_capture_get_bounds, RID) - - BIND2(lightmap_capture_set_octree, RID, const Vector &) - BIND1RC(Vector, lightmap_capture_get_octree, RID) - - BIND2(lightmap_capture_set_octree_cell_transform, RID, const Transform &) - BIND1RC(Transform, lightmap_capture_get_octree_cell_transform, RID) - BIND2(lightmap_capture_set_octree_cell_subdiv, RID, int) - BIND1RC(int, lightmap_capture_get_octree_cell_subdiv, RID) - - BIND2(lightmap_capture_set_energy, RID, float) - BIND1RC(float, lightmap_capture_get_energy, RID) + BIND3(lightmap_set_textures, RID, RID, bool) + BIND2(lightmap_set_probe_bounds, RID, const AABB &) + BIND2(lightmap_set_probe_interior, RID, bool) + BIND5(lightmap_set_probe_capture_data, RID, const PackedVector3Array &, const PackedColorArray &, const PackedInt32Array &, const PackedInt32Array &) + BIND1RC(PackedVector3Array, lightmap_get_probe_capture_points, RID) + BIND1RC(PackedColorArray, lightmap_get_probe_capture_sh, RID) + BIND1RC(PackedInt32Array, lightmap_get_probe_capture_tetrahedra, RID) + BIND1RC(PackedInt32Array, lightmap_get_probe_capture_bsp_tree, RID) + BIND1(lightmap_set_probe_capture_update_speed, float) /* PARTICLES */ @@ -532,6 +532,7 @@ public: BIND2(sky_set_radiance_size, RID, int) BIND2(sky_set_mode, RID, SkyMode) BIND2(sky_set_material, RID, RID) + BIND4R(Ref, sky_bake_panorama, RID, float, bool, const Size2i &) BIND0R(RID, environment_create) @@ -565,6 +566,8 @@ public: BIND7(environment_set_fog_depth, RID, bool, float, float, float, bool, float) BIND5(environment_set_fog_height, RID, bool, float, float, float) + BIND3R(Ref, environment_bake_panorama, RID, bool, const Size2i &) + BIND2(screen_space_roughness_limiter_set_active, bool, float) BIND1(sub_surface_scattering_set_quality, SubSurfaceScatteringQuality) BIND2(sub_surface_scattering_set_scale, float, float) @@ -605,7 +608,6 @@ public: BIND3(instance_set_blend_shape_weight, RID, int, float) BIND3(instance_set_surface_material, RID, int, RID) BIND2(instance_set_visible, RID, bool) - BIND3(instance_set_use_lightmap, RID, RID, RID) BIND2(instance_set_custom_aabb, RID, AABB) @@ -625,12 +627,15 @@ public: BIND5(instance_geometry_set_draw_range, RID, float, float, float, float) BIND2(instance_geometry_set_as_instance_lod, RID, RID) + BIND4(instance_geometry_set_lightmap, RID, RID, const Rect2 &, int) BIND3(instance_geometry_set_shader_parameter, RID, const StringName &, const Variant &) BIND2RC(Variant, instance_geometry_get_shader_parameter, RID, const StringName &) BIND2RC(Variant, instance_geometry_get_shader_parameter_default_value, RID, const StringName &) BIND2C(instance_geometry_get_shader_parameter_list, RID, List *) + BIND3R(TypedArray, bake_render_uv2, RID, const Vector &, const Size2i &) + #undef BINDBASE //from now on, calls forwarded to this singleton #define BINDBASE RSG::canvas diff --git a/servers/rendering/rendering_server_scene.cpp b/servers/rendering/rendering_server_scene.cpp index 4f338ee2a5..95334ee102 100644 --- a/servers/rendering/rendering_server_scene.cpp +++ b/servers/rendering/rendering_server_scene.cpp @@ -169,19 +169,22 @@ void *RenderingServerScene::_instance_pair(void *p_self, OctreeElementID, Instan geom->decal_dirty = true; return E; //this element should make freeing faster - } else if (B->base_type == RS::INSTANCE_LIGHTMAP_CAPTURE && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { + } else if (B->base_type == RS::INSTANCE_LIGHTMAP && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { - InstanceLightmapCaptureData *lightmap_capture = static_cast(B->base_data); + InstanceLightmapData *lightmap_data = static_cast(B->base_data); InstanceGeometryData *geom = static_cast(A->base_data); - InstanceLightmapCaptureData::PairInfo pinfo; - pinfo.geometry = A; - pinfo.L = geom->lightmap_captures.push_back(B); + if (A->dynamic_gi) { + InstanceLightmapData::PairInfo pinfo; + pinfo.geometry = A; + pinfo.L = geom->lightmap_captures.push_back(B); + List::Element *E = lightmap_data->geometries.push_back(pinfo); + ((RenderingServerScene *)p_self)->_instance_queue_update(A, false, false); //need to update capture + return E; //this element should make freeing faster + } else { + return nullptr; + } - List::Element *E = lightmap_capture->geometries.push_back(pinfo); - ((RenderingServerScene *)p_self)->_instance_queue_update(A, false, false); //need to update capture - - return E; //this element should make freeing faster } else if (B->base_type == RS::INSTANCE_GI_PROBE && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { InstanceGIProbeData *gi_probe = static_cast(B->base_data); @@ -258,16 +261,18 @@ void RenderingServerScene::_instance_unpair(void *p_self, OctreeElementID, Insta decal->geometries.erase(E); geom->decal_dirty = true; - } else if (B->base_type == RS::INSTANCE_LIGHTMAP_CAPTURE && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { + } else if (B->base_type == RS::INSTANCE_LIGHTMAP && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { - InstanceLightmapCaptureData *lightmap_capture = static_cast(B->base_data); - InstanceGeometryData *geom = static_cast(A->base_data); + if (udata) { //only for dynamic geometries + InstanceLightmapData *lightmap_data = static_cast(B->base_data); + InstanceGeometryData *geom = static_cast(A->base_data); - List::Element *E = reinterpret_cast::Element *>(udata); + List::Element *E = reinterpret_cast::Element *>(udata); - geom->lightmap_captures.erase(E->get().L); - lightmap_capture->geometries.erase(E); - ((RenderingServerScene *)p_self)->_instance_queue_update(A, false, false); //need to update capture + geom->lightmap_captures.erase(E->get().L); + lightmap_data->geometries.erase(E); + ((RenderingServerScene *)p_self)->_instance_queue_update(A, false, false); //need to update capture + } } else if (B->base_type == RS::INSTANCE_GI_PROBE && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { @@ -418,12 +423,12 @@ void RenderingServerScene::instance_set_base(RID p_instance, RID p_base) { RSG::scene_render->free(decal->instance); } break; - case RS::INSTANCE_LIGHTMAP_CAPTURE: { + case RS::INSTANCE_LIGHTMAP: { - InstanceLightmapCaptureData *lightmap_capture = static_cast(instance->base_data); + InstanceLightmapData *lightmap_data = static_cast(instance->base_data); //erase dependencies, since no longer a lightmap - while (lightmap_capture->users.front()) { - instance_set_use_lightmap(lightmap_capture->users.front()->get()->self, RID(), RID()); + while (lightmap_data->users.front()) { + instance_geometry_set_lightmap(lightmap_data->users.front()->get()->self, RID(), Rect2(), 0); } } break; case RS::INSTANCE_GI_PROBE: { @@ -443,14 +448,6 @@ void RenderingServerScene::instance_set_base(RID p_instance, RID p_base) { gi_probe_update_list.remove(&gi_probe->update_element); } - if (instance->lightmap_capture) { - Instance *capture = (Instance *)instance->lightmap_capture; - InstanceLightmapCaptureData *lightmap_capture = static_cast(capture->base_data); - lightmap_capture->users.erase(instance); - instance->lightmap_capture = nullptr; - instance->lightmap = RID(); - } - RSG::scene_render->free(gi_probe->probe_instance); } break; @@ -515,11 +512,11 @@ void RenderingServerScene::instance_set_base(RID p_instance, RID p_base) { decal->instance = RSG::scene_render->decal_instance_create(p_base); } break; - case RS::INSTANCE_LIGHTMAP_CAPTURE: { + case RS::INSTANCE_LIGHTMAP: { - InstanceLightmapCaptureData *lightmap_capture = memnew(InstanceLightmapCaptureData); - instance->base_data = lightmap_capture; - //lightmap_capture->instance = RSG::scene_render->lightmap_capture_instance_create(p_base); + InstanceLightmapData *lightmap_data = memnew(InstanceLightmapData); + instance->base_data = lightmap_data; + //lightmap_data->instance = RSG::scene_render->lightmap_data_instance_create(p_base); } break; case RS::INSTANCE_GI_PROBE: { @@ -736,9 +733,9 @@ void RenderingServerScene::instance_set_visible(RID p_instance, bool p_visible) } } break; - case RS::INSTANCE_LIGHTMAP_CAPTURE: { + case RS::INSTANCE_LIGHTMAP: { if (instance->octree_id && instance->scenario) { - instance->scenario->octree.set_pairable(instance->octree_id, p_visible, 1 << RS::INSTANCE_LIGHTMAP_CAPTURE, p_visible ? RS::INSTANCE_GEOMETRY_MASK : 0); + instance->scenario->octree.set_pairable(instance->octree_id, p_visible, 1 << RS::INSTANCE_LIGHTMAP, p_visible ? RS::INSTANCE_GEOMETRY_MASK : 0); } } break; @@ -756,30 +753,6 @@ inline bool is_geometry_instance(RenderingServer::InstanceType p_type) { return p_type == RS::INSTANCE_MESH || p_type == RS::INSTANCE_MULTIMESH || p_type == RS::INSTANCE_PARTICLES || p_type == RS::INSTANCE_IMMEDIATE; } -void RenderingServerScene::instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap) { - - Instance *instance = instance_owner.getornull(p_instance); - ERR_FAIL_COND(!instance); - - if (instance->lightmap_capture) { - InstanceLightmapCaptureData *lightmap_capture = static_cast(((Instance *)instance->lightmap_capture)->base_data); - lightmap_capture->users.erase(instance); - instance->lightmap = RID(); - instance->lightmap_capture = nullptr; - } - - if (p_lightmap_instance.is_valid()) { - Instance *lightmap_instance = instance_owner.getornull(p_lightmap_instance); - ERR_FAIL_COND(!lightmap_instance); - ERR_FAIL_COND(lightmap_instance->base_type != RS::INSTANCE_LIGHTMAP_CAPTURE); - instance->lightmap_capture = lightmap_instance; - - InstanceLightmapCaptureData *lightmap_capture = static_cast(((Instance *)instance->lightmap_capture)->base_data); - lightmap_capture->users.insert(instance); - instance->lightmap = p_lightmap; - } -} - void RenderingServerScene::instance_set_custom_aabb(RID p_instance, AABB p_aabb) { Instance *instance = instance_owner.getornull(p_instance); @@ -968,6 +941,29 @@ void RenderingServerScene::instance_geometry_set_draw_range(RID p_instance, floa void RenderingServerScene::instance_geometry_set_as_instance_lod(RID p_instance, RID p_as_lod_of_instance) { } +void RenderingServerScene::instance_geometry_set_lightmap(RID p_instance, RID p_lightmap, const Rect2 &p_lightmap_uv_scale, int p_slice_index) { + + Instance *instance = instance_owner.getornull(p_instance); + ERR_FAIL_COND(!instance); + + if (instance->lightmap) { + InstanceLightmapData *lightmap_data = static_cast(((Instance *)instance->lightmap)->base_data); + lightmap_data->users.erase(instance); + instance->lightmap = nullptr; + } + + Instance *lightmap_instance = instance_owner.getornull(p_lightmap); + + instance->lightmap = lightmap_instance; + instance->lightmap_uv_scale = p_lightmap_uv_scale; + instance->lightmap_slice_index = p_slice_index; + + if (lightmap_instance) { + InstanceLightmapData *lightmap_data = static_cast(lightmap_instance->base_data); + lightmap_data->users.insert(instance); + } +} + void RenderingServerScene::instance_geometry_set_shader_parameter(RID p_instance, const StringName &p_parameter, const Variant &p_value) { Instance *instance = instance_owner.getornull(p_instance); @@ -1084,16 +1080,29 @@ void RenderingServerScene::_update_instance(Instance *p_instance) { } } - if (!p_instance->lightmap_capture && geom->lightmap_captures.size()) { + if (!p_instance->lightmap && geom->lightmap_captures.size()) { //affected by lightmap captures, must update capture info! _update_instance_lightmap_captures(p_instance); } else { - if (!p_instance->lightmap_capture_data.empty()) { - p_instance->lightmap_capture_data.resize(0); //not in use, clear capture data + if (!p_instance->lightmap_sh.empty()) { + p_instance->lightmap_sh.clear(); //don't need SH + p_instance->lightmap_target_sh.clear(); //don't need SH } } } + if (p_instance->base_type == RS::INSTANCE_LIGHTMAP) { + + //if this moved, update the captured objects + InstanceLightmapData *lightmap_data = static_cast(p_instance->base_data); + //erase dependencies, since no longer a lightmap + + for (List::Element *E = lightmap_data->geometries.front(); E; E = E->next()) { + Instance *geom = E->get().geometry; + _instance_queue_update(geom, true, false); + } + } + p_instance->mirror = p_instance->transform.basis.determinant() < 0.0; AABB new_aabb; @@ -1113,7 +1122,7 @@ void RenderingServerScene::_update_instance(Instance *p_instance) { uint32_t pairable_mask = 0; bool pairable = false; - if (p_instance->base_type == RS::INSTANCE_LIGHT || p_instance->base_type == RS::INSTANCE_REFLECTION_PROBE || p_instance->base_type == RS::INSTANCE_DECAL || p_instance->base_type == RS::INSTANCE_LIGHTMAP_CAPTURE) { + if (p_instance->base_type == RS::INSTANCE_LIGHT || p_instance->base_type == RS::INSTANCE_REFLECTION_PROBE || p_instance->base_type == RS::INSTANCE_DECAL || p_instance->base_type == RS::INSTANCE_LIGHTMAP) { pairable_mask = p_instance->visible ? RS::INSTANCE_GEOMETRY_MASK : 0; pairable = true; @@ -1203,9 +1212,9 @@ void RenderingServerScene::_update_instance_aabb(Instance *p_instance) { new_aabb = RSG::storage->gi_probe_get_bounds(p_instance->base); } break; - case RenderingServer::INSTANCE_LIGHTMAP_CAPTURE: { + case RenderingServer::INSTANCE_LIGHTMAP: { - new_aabb = RSG::storage->lightmap_capture_get_bounds(p_instance->base); + new_aabb = RSG::storage->lightmap_get_aabb(p_instance->base); } break; default: { @@ -1219,235 +1228,82 @@ void RenderingServerScene::_update_instance_aabb(Instance *p_instance) { p_instance->aabb = new_aabb; } -_FORCE_INLINE_ static void _light_capture_sample_octree(const RasterizerStorage::LightmapCaptureOctree *p_octree, int p_cell_subdiv, const Vector3 &p_pos, const Vector3 &p_dir, float p_level, Vector3 &r_color, float &r_alpha) { +void RenderingServerScene::_update_instance_lightmap_captures(Instance *p_instance) { - static const Vector3 aniso_normal[6] = { - Vector3(-1, 0, 0), - Vector3(1, 0, 0), - Vector3(0, -1, 0), - Vector3(0, 1, 0), - Vector3(0, 0, -1), - Vector3(0, 0, 1) - }; + bool first_set = p_instance->lightmap_sh.size() == 0; + p_instance->lightmap_sh.resize(9); //using SH + p_instance->lightmap_target_sh.resize(9); //using SH + Color *instance_sh = p_instance->lightmap_target_sh.ptrw(); + bool inside = false; + Color accum_sh[9]; + float accum_blend = 0.0; - int size = 1 << (p_cell_subdiv - 1); + InstanceGeometryData *geom = static_cast(p_instance->base_data); + for (List::Element *E = geom->lightmap_captures.front(); E; E = E->next()) { + Instance *lightmap = E->get(); - int clamp_v = size - 1; - //first of all, clamp - Vector3 pos; - pos.x = CLAMP(p_pos.x, 0, clamp_v); - pos.y = CLAMP(p_pos.y, 0, clamp_v); - pos.z = CLAMP(p_pos.z, 0, clamp_v); + bool interior = RSG::storage->lightmap_is_interior(lightmap->base); - float level = (p_cell_subdiv - 1) - p_level; + if (inside && !interior) { + continue; //we are inside, ignore exteriors + } - int target_level; - float level_filter; - if (level <= 0.0) { - level_filter = 0; - target_level = 0; - } else { - target_level = Math::ceil(level); - level_filter = target_level - level; - } + Transform to_bounds = lightmap->transform.affine_inverse(); + Vector3 center = p_instance->transform.xform(p_instance->aabb.position + p_instance->aabb.size * 0.5); //use aabb center - Vector3 color[2][8]; - float alpha[2][8]; - zeromem(alpha, sizeof(float) * 2 * 8); + Vector3 lm_pos = to_bounds.xform(center); - //find cell at given level first + AABB bounds = RSG::storage->lightmap_get_aabb(lightmap->base); + if (!bounds.has_point(lm_pos)) { + continue; //not in this lightmap + } - for (int c = 0; c < 2; c++) { + Color sh[9]; + RSG::storage->lightmap_tap_sh_light(lightmap->base, lm_pos, sh); - int current_level = MAX(0, target_level - c); - int level_cell_size = (1 << (p_cell_subdiv - 1)) >> current_level; - - for (int n = 0; n < 8; n++) { - - int x = int(pos.x); - int y = int(pos.y); - int z = int(pos.z); - - if (n & 1) - x += level_cell_size; - if (n & 2) - y += level_cell_size; - if (n & 4) - z += level_cell_size; - - int ofs_x = 0; - int ofs_y = 0; - int ofs_z = 0; - - x = CLAMP(x, 0, clamp_v); - y = CLAMP(y, 0, clamp_v); - z = CLAMP(z, 0, clamp_v); - - int half = size / 2; - uint32_t cell = 0; - for (int i = 0; i < current_level; i++) { - - const RasterizerStorage::LightmapCaptureOctree *bc = &p_octree[cell]; - - int child = 0; - if (x >= ofs_x + half) { - child |= 1; - ofs_x += half; - } - if (y >= ofs_y + half) { - child |= 2; - ofs_y += half; - } - if (z >= ofs_z + half) { - child |= 4; - ofs_z += half; - } - - cell = bc->children[child]; - if (cell == RasterizerStorage::LightmapCaptureOctree::CHILD_EMPTY) - break; - - half >>= 1; + //rotate it + Basis rot = lightmap->transform.basis.orthonormalized(); + for (int i = 0; i < 3; i++) { + float csh[9]; + for (int j = 0; j < 9; j++) { + csh[j] = sh[j][i]; } - - if (cell == RasterizerStorage::LightmapCaptureOctree::CHILD_EMPTY) { - alpha[c][n] = 0; - } else { - alpha[c][n] = p_octree[cell].alpha; - - for (int i = 0; i < 6; i++) { - //anisotropic read light - float amount = p_dir.dot(aniso_normal[i]); - if (amount < 0) - amount = 0; - color[c][n].x += p_octree[cell].light[i][0] / 1024.0 * amount; - color[c][n].y += p_octree[cell].light[i][1] / 1024.0 * amount; - color[c][n].z += p_octree[cell].light[i][2] / 1024.0 * amount; - } + rot.rotate_sh(csh); + for (int j = 0; j < 9; j++) { + sh[j][i] = csh[j]; } + } - //print_line("\tlev " + itos(c) + " - " + itos(n) + " alpha: " + rtos(cells[test_cell].alpha) + " col: " + color[c][n]); + Vector3 inner_pos = ((lm_pos - bounds.position) / bounds.size) * 2.0 - Vector3(1.0, 1.0, 1.0); + + float blend = MAX(inner_pos.x, MAX(inner_pos.y, inner_pos.z)); + //make blend more rounded + blend = Math::lerp(inner_pos.length(), blend, blend); + blend *= blend; + blend = MAX(0.0, 1.0 - blend); + + if (interior && !inside) { + //do not blend, just replace + for (int j = 0; j < 9; j++) { + accum_sh[j] = sh[j] * blend; + } + accum_blend = blend; + inside = true; + } else { + for (int j = 0; j < 9; j++) { + accum_sh[j] += sh[j] * blend; + } + accum_blend += blend; } } - float target_level_size = size >> target_level; - Vector3 pos_fract[2]; + if (accum_blend > 0.0) { + for (int j = 0; j < 9; j++) { - pos_fract[0].x = Math::fmod(pos.x, target_level_size) / target_level_size; - pos_fract[0].y = Math::fmod(pos.y, target_level_size) / target_level_size; - pos_fract[0].z = Math::fmod(pos.z, target_level_size) / target_level_size; - - target_level_size = size >> MAX(0, target_level - 1); - - pos_fract[1].x = Math::fmod(pos.x, target_level_size) / target_level_size; - pos_fract[1].y = Math::fmod(pos.y, target_level_size) / target_level_size; - pos_fract[1].z = Math::fmod(pos.z, target_level_size) / target_level_size; - - float alpha_interp[2]; - Vector3 color_interp[2]; - - for (int i = 0; i < 2; i++) { - - Vector3 color_x00 = color[i][0].lerp(color[i][1], pos_fract[i].x); - Vector3 color_xy0 = color[i][2].lerp(color[i][3], pos_fract[i].x); - Vector3 blend_z0 = color_x00.lerp(color_xy0, pos_fract[i].y); - - Vector3 color_x0z = color[i][4].lerp(color[i][5], pos_fract[i].x); - Vector3 color_xyz = color[i][6].lerp(color[i][7], pos_fract[i].x); - Vector3 blend_z1 = color_x0z.lerp(color_xyz, pos_fract[i].y); - - color_interp[i] = blend_z0.lerp(blend_z1, pos_fract[i].z); - - float alpha_x00 = Math::lerp(alpha[i][0], alpha[i][1], pos_fract[i].x); - float alpha_xy0 = Math::lerp(alpha[i][2], alpha[i][3], pos_fract[i].x); - float alpha_z0 = Math::lerp(alpha_x00, alpha_xy0, pos_fract[i].y); - - float alpha_x0z = Math::lerp(alpha[i][4], alpha[i][5], pos_fract[i].x); - float alpha_xyz = Math::lerp(alpha[i][6], alpha[i][7], pos_fract[i].x); - float alpha_z1 = Math::lerp(alpha_x0z, alpha_xyz, pos_fract[i].y); - - alpha_interp[i] = Math::lerp(alpha_z0, alpha_z1, pos_fract[i].z); - } - - r_color = color_interp[0].lerp(color_interp[1], level_filter); - r_alpha = Math::lerp(alpha_interp[0], alpha_interp[1], level_filter); - - //print_line("pos: " + p_posf + " level " + rtos(p_level) + " down to " + itos(target_level) + "." + rtos(level_filter) + " color " + r_color + " alpha " + rtos(r_alpha)); -} - -_FORCE_INLINE_ static Color _light_capture_voxel_cone_trace(const RasterizerStorage::LightmapCaptureOctree *p_octree, const Vector3 &p_pos, const Vector3 &p_dir, float p_aperture, int p_cell_subdiv) { - - float bias = 0.0; //no need for bias here - float max_distance = (Vector3(1, 1, 1) * (1 << (p_cell_subdiv - 1))).length(); - - float dist = bias; - float alpha = 0.0; - Vector3 color; - - Vector3 scolor; - float salpha; - - while (dist < max_distance && alpha < 0.95) { - float diameter = MAX(1.0, 2.0 * p_aperture * dist); - _light_capture_sample_octree(p_octree, p_cell_subdiv, p_pos + dist * p_dir, p_dir, log2(diameter), scolor, salpha); - float a = (1.0 - alpha); - color += scolor * a; - alpha += a * salpha; - dist += diameter * 0.5; - } - - return Color(color.x, color.y, color.z, alpha); -} - -void RenderingServerScene::_update_instance_lightmap_captures(Instance *p_instance) { - - InstanceGeometryData *geom = static_cast(p_instance->base_data); - - static const Vector3 cone_traces[12] = { - Vector3(0, 0, 1), - Vector3(0.866025, 0, 0.5), - Vector3(0.267617, 0.823639, 0.5), - Vector3(-0.700629, 0.509037, 0.5), - Vector3(-0.700629, -0.509037, 0.5), - Vector3(0.267617, -0.823639, 0.5), - Vector3(0, 0, -1), - Vector3(0.866025, 0, -0.5), - Vector3(0.267617, 0.823639, -0.5), - Vector3(-0.700629, 0.509037, -0.5), - Vector3(-0.700629, -0.509037, -0.5), - Vector3(0.267617, -0.823639, -0.5) - }; - - float cone_aperture = 0.577; // tan(angle) 60 degrees - - if (p_instance->lightmap_capture_data.empty()) { - p_instance->lightmap_capture_data.resize(12); - } - - //print_line("update captures for pos: " + p_instance->transform.origin); - - for (int i = 0; i < 12; i++) - new (&p_instance->lightmap_capture_data.ptrw()[i]) Color; - - //this could use some sort of blending.. - for (List::Element *E = geom->lightmap_captures.front(); E; E = E->next()) { - const Vector *octree = RSG::storage->lightmap_capture_get_octree_ptr(E->get()->base); - //print_line("octree size: " + itos(octree->size())); - if (octree->size() == 0) - continue; - Transform to_cell_xform = RSG::storage->lightmap_capture_get_octree_cell_transform(E->get()->base); - int cell_subdiv = RSG::storage->lightmap_capture_get_octree_cell_subdiv(E->get()->base); - to_cell_xform = to_cell_xform * E->get()->transform.affine_inverse(); - - const RasterizerStorage::LightmapCaptureOctree *octree_r = octree->ptr(); - - Vector3 pos = to_cell_xform.xform(p_instance->transform.origin); - - for (int i = 0; i < 12; i++) { - - Vector3 dir = to_cell_xform.basis.xform(cone_traces[i]).normalized(); - Color capture = _light_capture_voxel_cone_trace(octree_r, pos, dir, cone_aperture, cell_subdiv); - p_instance->lightmap_capture_data.write[i] += capture; + instance_sh[j] = accum_sh[j] / accum_blend; + if (first_set) { + p_instance->lightmap_sh.write[j] = instance_sh[j]; + } } } } @@ -1762,10 +1618,6 @@ bool RenderingServerScene::_light_instance_update_shadow(Instance *p_instance, c } else { camera_matrix_square.set_frustum(vp_he.y * 2.0, 1.0, Vector2(), distances[(i == 0 || !overlap) ? i : i - 1], distances[i + 1], false); } - - if (i == 0) { - //print_line("prev he: " + vp_he + " new he: " + camera_matrix_square.get_viewport_half_extents()); - } } Vector3 endpoints_square[8]; // frustum plane endpoints @@ -2147,6 +1999,7 @@ void RenderingServerScene::_prepare_scene(const Transform p_cam_transform, const reflection_probe_cull_count = 0; decal_cull_count = 0; gi_probe_cull_count = 0; + lightmap_cull_count = 0; //light_samplers_culled=0; @@ -2161,6 +2014,8 @@ void RenderingServerScene::_prepare_scene(const Transform p_cam_transform, const //removed, will replace with culling /* STEP 4 - REMOVE FURTHER CULLED OBJECTS, ADD LIGHTS */ + uint64_t frame_number = RSG::rasterizer->get_frame_number(); + float lightmap_probe_update_speed = RSG::storage->lightmap_get_probe_capture_update_speed() * RSG::rasterizer->get_frame_delta_time(); for (int i = 0; i < instance_cull_count; i++) { @@ -2239,6 +2094,12 @@ void RenderingServerScene::_prepare_scene(const Transform p_cam_transform, const gi_probe_instance_cull_result[gi_probe_cull_count] = gi_probe->probe_instance; gi_probe_cull_count++; } + } else if (ins->base_type == RS::INSTANCE_LIGHTMAP && ins->visible) { + + if (lightmap_cull_count < MAX_LIGHTMAPS_CULLED) { + lightmap_cull_result[lightmap_cull_count] = ins; + lightmap_cull_count++; + } } else if (((1 << ins->base_type) & RS::INSTANCE_GEOMETRY_MASK) && ins->visible && ins->cast_shadows != RS::SHADOW_CASTING_SETTING_SHADOWS_ONLY) { @@ -2307,6 +2168,14 @@ void RenderingServerScene::_prepare_scene(const Transform p_cam_transform, const geom->gi_probes_dirty = false; } + if (ins->last_frame_pass != frame_number && !ins->lightmap_target_sh.empty() && !ins->lightmap_sh.empty()) { + Color *sh = ins->lightmap_sh.ptrw(); + const Color *target_sh = ins->lightmap_target_sh.ptr(); + for (uint32_t j = 0; j < 9; j++) { + sh[j] = sh[j].lerp(target_sh[j], MIN(1.0, lightmap_probe_update_speed)); + } + } + ins->depth = near_plane.distance_to(ins->transform.origin); ins->depth_layer = CLAMP(int(ins->depth * 16 / z_far), 0, 15); } @@ -2321,6 +2190,7 @@ void RenderingServerScene::_prepare_scene(const Transform p_cam_transform, const ins->last_render_pass = render_pass; } + ins->last_frame_pass = frame_number; } /* STEP 5 - PROCESS LIGHTS */ @@ -2494,7 +2364,7 @@ void RenderingServerScene::_render_scene(RID p_render_buffers, const Transform p /* PROCESS GEOMETRY AND DRAW SCENE */ RENDER_TIMESTAMP("Render Scene "); - RSG::scene_render->render_scene(p_render_buffers, p_cam_transform, p_cam_projection, p_cam_orthogonal, (RasterizerScene::InstanceBase **)instance_cull_result, instance_cull_count, light_instance_cull_result, light_cull_count + directional_light_count, reflection_probe_instance_cull_result, reflection_probe_cull_count, gi_probe_instance_cull_result, gi_probe_cull_count, decal_instance_cull_result, decal_cull_count, environment, camera_effects, p_shadow_atlas, p_reflection_probe.is_valid() ? RID() : scenario->reflection_atlas, p_reflection_probe, p_reflection_probe_pass); + RSG::scene_render->render_scene(p_render_buffers, p_cam_transform, p_cam_projection, p_cam_orthogonal, (RasterizerScene::InstanceBase **)instance_cull_result, instance_cull_count, light_instance_cull_result, light_cull_count + directional_light_count, reflection_probe_instance_cull_result, reflection_probe_cull_count, gi_probe_instance_cull_result, gi_probe_cull_count, decal_instance_cull_result, decal_cull_count, (RasterizerScene::InstanceBase **)lightmap_cull_result, lightmap_cull_count, environment, camera_effects, p_shadow_atlas, p_reflection_probe.is_valid() ? RID() : scenario->reflection_atlas, p_reflection_probe, p_reflection_probe_pass); } void RenderingServerScene::render_empty_scene(RID p_render_buffers, RID p_scenario, RID p_shadow_atlas) { @@ -2509,7 +2379,7 @@ void RenderingServerScene::render_empty_scene(RID p_render_buffers, RID p_scenar else environment = scenario->fallback_environment; RENDER_TIMESTAMP("Render Empty Scene "); - RSG::scene_render->render_scene(p_render_buffers, Transform(), CameraMatrix(), true, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, environment, RID(), p_shadow_atlas, scenario->reflection_atlas, RID(), 0); + RSG::scene_render->render_scene(p_render_buffers, Transform(), CameraMatrix(), true, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, environment, RID(), p_shadow_atlas, scenario->reflection_atlas, RID(), 0); #endif } @@ -3123,7 +2993,7 @@ bool RenderingServerScene::free(RID p_rid) { Instance *instance = instance_owner.getornull(p_rid); - instance_set_use_lightmap(p_rid, RID(), RID()); + instance_geometry_set_lightmap(p_rid, RID(), Rect2(), 0); instance_set_scenario(p_rid, RID()); instance_set_base(p_rid, RID()); instance_geometry_set_material_override(p_rid, RID()); @@ -3144,6 +3014,10 @@ bool RenderingServerScene::free(RID p_rid) { return true; } +TypedArray RenderingServerScene::bake_render_uv2(RID p_base, const Vector &p_material_overrides, const Size2i &p_image_size) { + return RSG::scene_render->bake_render_uv2(p_base, p_material_overrides, p_image_size); +} + RenderingServerScene *RenderingServerScene::singleton = nullptr; RenderingServerScene::RenderingServerScene() { diff --git a/servers/rendering/rendering_server_scene.h b/servers/rendering/rendering_server_scene.h index eb66cea3aa..df9e650ac7 100644 --- a/servers/rendering/rendering_server_scene.h +++ b/servers/rendering/rendering_server_scene.h @@ -51,6 +51,7 @@ public: MAX_DECALS_CULLED = 4096, MAX_GI_PROBES_CULLED = 4096, MAX_ROOM_CULL = 32, + MAX_LIGHTMAPS_CULLED = 4096, MAX_EXTERIOR_PORTALS = 128, }; @@ -171,6 +172,8 @@ public: float lod_end_hysteresis; RID lod_instance; + Vector lightmap_target_sh; //target is used for incrementally changing the SH over time, this avoids pops in some corner cases and when going interior <-> exterior + uint64_t last_render_pass; uint64_t last_frame_pass; @@ -374,7 +377,7 @@ public: SelfList::List gi_probe_update_list; - struct InstanceLightmapCaptureData : public InstanceBaseData { + struct InstanceLightmapData : public InstanceBaseData { struct PairInfo { List::Element *L; //iterator in geometry @@ -384,7 +387,7 @@ public: Set users; - InstanceLightmapCaptureData() { + InstanceLightmapData() { } }; @@ -401,6 +404,8 @@ public: int decal_cull_count; RID gi_probe_instance_cull_result[MAX_GI_PROBES_CULLED]; int gi_probe_cull_count; + Instance *lightmap_cull_result[MAX_LIGHTS_CULLED]; + int lightmap_cull_count; RID_PtrOwner instance_owner; @@ -414,7 +419,6 @@ public: virtual void instance_set_blend_shape_weight(RID p_instance, int p_shape, float p_weight); virtual void instance_set_surface_material(RID p_instance, int p_surface, RID p_material); virtual void instance_set_visible(RID p_instance, bool p_visible); - virtual void instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap); virtual void instance_set_custom_aabb(RID p_instance, AABB p_aabb); @@ -434,6 +438,7 @@ public: virtual void instance_geometry_set_draw_range(RID p_instance, float p_min, float p_max, float p_min_margin, float p_max_margin); virtual void instance_geometry_set_as_instance_lod(RID p_instance, RID p_as_lod_of_instance); + virtual void instance_geometry_set_lightmap(RID p_instance, RID p_lightmap, const Rect2 &p_lightmap_uv_scale, int p_slice_index); void _update_instance_shader_parameters_from_material(Map &isparams, const Map &existing_isparams, RID p_material); @@ -460,6 +465,8 @@ public: void render_probes(); + TypedArray bake_render_uv2(RID p_base, const Vector &p_material_overrides, const Size2i &p_image_size); + bool free(RID p_rid); RenderingServerScene(); diff --git a/servers/rendering/rendering_server_wrap_mt.cpp b/servers/rendering/rendering_server_wrap_mt.cpp index 4ca13dbef9..a7a166bec6 100644 --- a/servers/rendering/rendering_server_wrap_mt.cpp +++ b/servers/rendering/rendering_server_wrap_mt.cpp @@ -138,7 +138,7 @@ void RenderingServerWrapMT::finish() { spot_light_free_cached_ids(); reflection_probe_free_cached_ids(); gi_probe_free_cached_ids(); - lightmap_capture_free_cached_ids(); + lightmap_free_cached_ids(); particles_free_cached_ids(); camera_free_cached_ids(); viewport_free_cached_ids(); diff --git a/servers/rendering/rendering_server_wrap_mt.h b/servers/rendering/rendering_server_wrap_mt.h index d4e58485b8..6580b71508 100644 --- a/servers/rendering/rendering_server_wrap_mt.h +++ b/servers/rendering/rendering_server_wrap_mt.h @@ -92,7 +92,7 @@ public: //these also go pass-through virtual RID texture_2d_placeholder_create() { return rendering_server->texture_2d_placeholder_create(); } - virtual RID texture_2d_layered_placeholder_create() { return rendering_server->texture_2d_layered_placeholder_create(); } + virtual RID texture_2d_layered_placeholder_create(TextureLayeredType p_type) { return rendering_server->texture_2d_layered_placeholder_create(p_type); } virtual RID texture_3d_placeholder_create() { return rendering_server->texture_3d_placeholder_create(); } FUNC1RC(Ref, texture_2d_get, RID) @@ -324,19 +324,17 @@ public: /* LIGHTMAP CAPTURE */ - FUNCRID(lightmap_capture) + FUNCRID(lightmap) + FUNC3(lightmap_set_textures, RID, RID, bool) + FUNC2(lightmap_set_probe_bounds, RID, const AABB &) + FUNC2(lightmap_set_probe_interior, RID, bool) + FUNC5(lightmap_set_probe_capture_data, RID, const PackedVector3Array &, const PackedColorArray &, const PackedInt32Array &, const PackedInt32Array &) + FUNC1RC(PackedVector3Array, lightmap_get_probe_capture_points, RID) + FUNC1RC(PackedColorArray, lightmap_get_probe_capture_sh, RID) + FUNC1RC(PackedInt32Array, lightmap_get_probe_capture_tetrahedra, RID) + FUNC1RC(PackedInt32Array, lightmap_get_probe_capture_bsp_tree, RID) - FUNC2(lightmap_capture_set_bounds, RID, const AABB &) - FUNC1RC(AABB, lightmap_capture_get_bounds, RID) - - FUNC2(lightmap_capture_set_octree, RID, const Vector &) - FUNC1RC(Vector, lightmap_capture_get_octree, RID) - FUNC2(lightmap_capture_set_octree_cell_transform, RID, const Transform &) - FUNC1RC(Transform, lightmap_capture_get_octree_cell_transform, RID) - FUNC2(lightmap_capture_set_octree_cell_subdiv, RID, int) - FUNC1RC(int, lightmap_capture_get_octree_cell_subdiv, RID) - FUNC2(lightmap_capture_set_energy, RID, float) - FUNC1RC(float, lightmap_capture_get_energy, RID) + FUNC1(lightmap_set_probe_capture_update_speed, float) /* PARTICLES */ @@ -442,6 +440,7 @@ public: FUNC2(sky_set_radiance_size, RID, int) FUNC2(sky_set_mode, RID, SkyMode) FUNC2(sky_set_material, RID, RID) + FUNC4R(Ref, sky_bake_panorama, RID, float, bool, const Size2i &) /* ENVIRONMENT API */ @@ -478,6 +477,8 @@ public: FUNC7(environment_set_fog_depth, RID, bool, float, float, float, bool, float) FUNC5(environment_set_fog_height, RID, bool, float, float, float) + FUNC3R(Ref, environment_bake_panorama, RID, bool, const Size2i &) + FUNC2(screen_space_roughness_limiter_set_active, bool, float) FUNC1(sub_surface_scattering_set_quality, SubSurfaceScatteringQuality) FUNC2(sub_surface_scattering_set_scale, float, float) @@ -511,7 +512,6 @@ public: FUNC3(instance_set_blend_shape_weight, RID, int, float) FUNC3(instance_set_surface_material, RID, int, RID) FUNC2(instance_set_visible, RID, bool) - FUNC3(instance_set_use_lightmap, RID, RID, RID) FUNC2(instance_set_custom_aabb, RID, AABB) @@ -531,12 +531,17 @@ public: FUNC5(instance_geometry_set_draw_range, RID, float, float, float, float) FUNC2(instance_geometry_set_as_instance_lod, RID, RID) + FUNC4(instance_geometry_set_lightmap, RID, RID, const Rect2 &, int) FUNC3(instance_geometry_set_shader_parameter, RID, const StringName &, const Variant &) FUNC2RC(Variant, instance_geometry_get_shader_parameter, RID, const StringName &) FUNC2RC(Variant, instance_geometry_get_shader_parameter_default_value, RID, const StringName &) FUNC2SC(instance_geometry_get_shader_parameter_list, RID, List *) + /* BAKE */ + + FUNC3R(TypedArray, bake_render_uv2, RID, const Vector &, const Size2i &) + /* CANVAS (2D) */ FUNCRID(canvas) diff --git a/servers/rendering/shader_language.cpp b/servers/rendering/shader_language.cpp index 2a5492d93f..e3725043d9 100644 --- a/servers/rendering/shader_language.cpp +++ b/servers/rendering/shader_language.cpp @@ -132,6 +132,7 @@ const char *ShaderLanguage::token_names[TK_MAX] = { "TYPE_ISAMPLER3D", "TYPE_USAMPLER3D", "TYPE_SAMPLERCUBE", + "TYPE_SAMPLERCUBEARRAY", "INTERPOLATION_FLAT", "INTERPOLATION_SMOOTH", "CONST", @@ -283,6 +284,7 @@ const ShaderLanguage::KeyWord ShaderLanguage::keyword_list[] = { { TK_TYPE_ISAMPLER3D, "isampler3D" }, { TK_TYPE_USAMPLER3D, "usampler3D" }, { TK_TYPE_SAMPLERCUBE, "samplerCube" }, + { TK_TYPE_SAMPLERCUBEARRAY, "samplerCubeArray" }, { TK_INTERPOLATION_FLAT, "flat" }, { TK_INTERPOLATION_SMOOTH, "smooth" }, { TK_CONST, "const" }, @@ -783,7 +785,8 @@ bool ShaderLanguage::is_token_datatype(TokenType p_type) { p_type == TK_TYPE_SAMPLER3D || p_type == TK_TYPE_ISAMPLER3D || p_type == TK_TYPE_USAMPLER3D || - p_type == TK_TYPE_SAMPLERCUBE); + p_type == TK_TYPE_SAMPLERCUBE || + p_type == TK_TYPE_SAMPLERCUBEARRAY); } ShaderLanguage::DataType ShaderLanguage::get_token_datatype(TokenType p_type) { @@ -902,6 +905,8 @@ String ShaderLanguage::get_datatype_name(DataType p_type) { return "usampler3D"; case TYPE_SAMPLERCUBE: return "samplerCube"; + case TYPE_SAMPLERCUBEARRAY: + return "samplerCubeArray"; case TYPE_STRUCT: return "struct"; case TYPE_MAX: @@ -2046,6 +2051,7 @@ const ShaderLanguage::BuiltinFuncDef ShaderLanguage::builtin_func_defs[] = { { "textureSize", TYPE_IVEC3, { TYPE_ISAMPLER3D, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, { "textureSize", TYPE_IVEC3, { TYPE_USAMPLER3D, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, { "textureSize", TYPE_IVEC2, { TYPE_SAMPLERCUBE, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, + { "textureSize", TYPE_IVEC2, { TYPE_SAMPLERCUBEARRAY, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, { "texture", TYPE_VEC4, { TYPE_SAMPLER2D, TYPE_VEC2, TYPE_VOID }, TAG_GLOBAL, false }, { "texture", TYPE_VEC4, { TYPE_SAMPLER2D, TYPE_VEC2, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, false }, @@ -2067,6 +2073,8 @@ const ShaderLanguage::BuiltinFuncDef ShaderLanguage::builtin_func_defs[] = { { "texture", TYPE_IVEC4, { TYPE_ISAMPLER3D, TYPE_VEC3, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, true }, { "texture", TYPE_VEC4, { TYPE_SAMPLERCUBE, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, false }, { "texture", TYPE_VEC4, { TYPE_SAMPLERCUBE, TYPE_VEC3, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, false }, + { "texture", TYPE_VEC4, { TYPE_SAMPLERCUBEARRAY, TYPE_VEC4, TYPE_VOID }, TAG_GLOBAL, false }, + { "texture", TYPE_VEC4, { TYPE_SAMPLERCUBEARRAY, TYPE_VEC4, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, false }, { "textureProj", TYPE_VEC4, { TYPE_SAMPLER2D, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, true }, { "textureProj", TYPE_VEC4, { TYPE_SAMPLER2D, TYPE_VEC4, TYPE_VOID }, TAG_GLOBAL, true }, @@ -2097,6 +2105,7 @@ const ShaderLanguage::BuiltinFuncDef ShaderLanguage::builtin_func_defs[] = { { "textureLod", TYPE_IVEC4, { TYPE_ISAMPLER3D, TYPE_VEC3, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, true }, { "textureLod", TYPE_UVEC4, { TYPE_USAMPLER3D, TYPE_VEC3, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, true }, { "textureLod", TYPE_VEC4, { TYPE_SAMPLERCUBE, TYPE_VEC3, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, false }, + { "textureLod", TYPE_VEC4, { TYPE_SAMPLERCUBEARRAY, TYPE_VEC4, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, false }, { "texelFetch", TYPE_VEC4, { TYPE_SAMPLER2D, TYPE_IVEC2, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, { "texelFetch", TYPE_IVEC4, { TYPE_ISAMPLER2D, TYPE_IVEC2, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, @@ -2128,6 +2137,7 @@ const ShaderLanguage::BuiltinFuncDef ShaderLanguage::builtin_func_defs[] = { { "textureGrad", TYPE_IVEC4, { TYPE_ISAMPLER3D, TYPE_VEC3, TYPE_VEC3, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, true }, { "textureGrad", TYPE_UVEC4, { TYPE_USAMPLER3D, TYPE_VEC3, TYPE_VEC3, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, true }, { "textureGrad", TYPE_VEC4, { TYPE_SAMPLERCUBE, TYPE_VEC3, TYPE_VEC3, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, true }, + { "textureGrad", TYPE_VEC4, { TYPE_SAMPLERCUBEARRAY, TYPE_VEC4, TYPE_VEC3, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, true }, { "dFdx", TYPE_FLOAT, { TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, true }, { "dFdx", TYPE_VEC2, { TYPE_VEC2, TYPE_VOID }, TAG_GLOBAL, true }, @@ -2618,7 +2628,8 @@ bool ShaderLanguage::is_sampler_type(DataType p_type) { p_type == TYPE_SAMPLER3D || p_type == TYPE_ISAMPLER3D || p_type == TYPE_USAMPLER3D || - p_type == TYPE_SAMPLERCUBE; + p_type == TYPE_SAMPLERCUBE || + p_type == TYPE_SAMPLERCUBEARRAY; } Variant ShaderLanguage::constant_value_to_variant(const Vector &p_value, DataType p_type, ShaderLanguage::ShaderNode::Uniform::Hint p_hint) { @@ -2712,7 +2723,9 @@ Variant ShaderLanguage::constant_value_to_variant(const Vector &p_funct return ERR_PARSE_ERROR; } } else { - if (uniform_scope == ShaderNode::Uniform::SCOPE_LOCAL && (type == TYPE_MAT2 || type == TYPE_MAT3 || type == TYPE_MAT4)) { + if (uniform_scope == ShaderNode::Uniform::SCOPE_INSTANCE && (type == TYPE_MAT2 || type == TYPE_MAT3 || type == TYPE_MAT4)) { _set_error("Uniforms with 'instance' qualifiers can't be of matrix type."); return ERR_PARSE_ERROR; } diff --git a/servers/rendering/shader_language.h b/servers/rendering/shader_language.h index 973e1c4937..314e4a5fba 100644 --- a/servers/rendering/shader_language.h +++ b/servers/rendering/shader_language.h @@ -79,6 +79,7 @@ public: TK_TYPE_ISAMPLER3D, TK_TYPE_USAMPLER3D, TK_TYPE_SAMPLERCUBE, + TK_TYPE_SAMPLERCUBEARRAY, TK_INTERPOLATION_FLAT, TK_INTERPOLATION_SMOOTH, TK_CONST, @@ -218,6 +219,7 @@ public: TYPE_ISAMPLER3D, TYPE_USAMPLER3D, TYPE_SAMPLERCUBE, + TYPE_SAMPLERCUBEARRAY, TYPE_STRUCT, TYPE_MAX }; @@ -682,6 +684,7 @@ public: texture_order(0), type(TYPE_VOID), precision(PRECISION_DEFAULT), + scope(SCOPE_LOCAL), hint(HINT_NONE), filter(FILTER_DEFAULT), repeat(REPEAT_DEFAULT), diff --git a/servers/rendering_server.cpp b/servers/rendering_server.cpp index e170b66562..3dac846357 100644 --- a/servers/rendering_server.cpp +++ b/servers/rendering_server.cpp @@ -1784,8 +1784,8 @@ void RenderingServer::_bind_methods() { ClassDB::bind_method(D_METHOD("gi_probe_set_compress", "probe", "enable"), &RenderingServer::gi_probe_set_compress); ClassDB::bind_method(D_METHOD("gi_probe_is_compressed", "probe"), &RenderingServer::gi_probe_is_compressed); #endif - - ClassDB::bind_method(D_METHOD("lightmap_capture_create"), &RenderingServer::lightmap_capture_create); +/* + ClassDB::bind_method(D_METHOD("lightmap_create()"), &RenderingServer::lightmap_capture_create); ClassDB::bind_method(D_METHOD("lightmap_capture_set_bounds", "capture", "bounds"), &RenderingServer::lightmap_capture_set_bounds); ClassDB::bind_method(D_METHOD("lightmap_capture_get_bounds", "capture"), &RenderingServer::lightmap_capture_get_bounds); ClassDB::bind_method(D_METHOD("lightmap_capture_set_octree", "capture", "octree"), &RenderingServer::lightmap_capture_set_octree); @@ -1796,6 +1796,7 @@ void RenderingServer::_bind_methods() { ClassDB::bind_method(D_METHOD("lightmap_capture_get_octree", "capture"), &RenderingServer::lightmap_capture_get_octree); ClassDB::bind_method(D_METHOD("lightmap_capture_set_energy", "capture", "energy"), &RenderingServer::lightmap_capture_set_energy); ClassDB::bind_method(D_METHOD("lightmap_capture_get_energy", "capture"), &RenderingServer::lightmap_capture_get_energy); +*/ #endif ClassDB::bind_method(D_METHOD("particles_create"), &RenderingServer::particles_create); ClassDB::bind_method(D_METHOD("particles_set_emitting", "particles", "emitting"), &RenderingServer::particles_set_emitting); @@ -1895,7 +1896,7 @@ void RenderingServer::_bind_methods() { ClassDB::bind_method(D_METHOD("instance_set_blend_shape_weight", "instance", "shape", "weight"), &RenderingServer::instance_set_blend_shape_weight); ClassDB::bind_method(D_METHOD("instance_set_surface_material", "instance", "surface", "material"), &RenderingServer::instance_set_surface_material); ClassDB::bind_method(D_METHOD("instance_set_visible", "instance", "visible"), &RenderingServer::instance_set_visible); - ClassDB::bind_method(D_METHOD("instance_set_use_lightmap", "instance", "lightmap_instance", "lightmap"), &RenderingServer::instance_set_use_lightmap); + // ClassDB::bind_method(D_METHOD("instance_set_use_lightmap", "instance", "lightmap_instance", "lightmap"), &RenderingServer::instance_set_use_lightmap); ClassDB::bind_method(D_METHOD("instance_set_custom_aabb", "instance", "aabb"), &RenderingServer::instance_set_custom_aabb); ClassDB::bind_method(D_METHOD("instance_attach_skeleton", "instance", "skeleton"), &RenderingServer::instance_attach_skeleton); ClassDB::bind_method(D_METHOD("instance_set_exterior", "instance", "enabled"), &RenderingServer::instance_set_exterior); @@ -2266,7 +2267,7 @@ void RenderingServer::_bind_methods() { BIND_ENUM_CONSTANT(INSTANCE_REFLECTION_PROBE); BIND_ENUM_CONSTANT(INSTANCE_DECAL); BIND_ENUM_CONSTANT(INSTANCE_GI_PROBE); - BIND_ENUM_CONSTANT(INSTANCE_LIGHTMAP_CAPTURE); + BIND_ENUM_CONSTANT(INSTANCE_LIGHTMAP); BIND_ENUM_CONSTANT(INSTANCE_MAX); BIND_ENUM_CONSTANT(INSTANCE_GEOMETRY_MASK); @@ -2513,6 +2514,9 @@ RenderingServer::RenderingServer() { ProjectSettings::get_singleton()->set_custom_property_info("rendering/quality/subsurface_scattering/subsurface_scattering_depth_scale", PropertyInfo(Variant::FLOAT, "rendering/quality/subsurface_scattering/subsurface_scattering_depth_scale", PROPERTY_HINT_RANGE, "0.001,1,0.001")); GLOBAL_DEF("rendering/high_end/global_shader_variables_buffer_size", 65536); + + GLOBAL_DEF("rendering/lightmapper/probe_capture_update_speed", 15); + ProjectSettings::get_singleton()->set_custom_property_info("rendering/lightmapper/probe_capture_update_speed", PropertyInfo(Variant::FLOAT, "rendering/lightmapper/probe_capture_update_speed", PROPERTY_HINT_RANGE, "0.001,256,0.001")); } RenderingServer::~RenderingServer() { diff --git a/servers/rendering_server.h b/servers/rendering_server.h index 8ca070b4a9..d426f205d0 100644 --- a/servers/rendering_server.h +++ b/servers/rendering_server.h @@ -36,6 +36,7 @@ #include "core/math/transform_2d.h" #include "core/object.h" #include "core/rid.h" +#include "core/typed_array.h" #include "core/variant.h" #include "servers/display_server.h" #include "servers/rendering/shader_language.h" @@ -106,7 +107,7 @@ public: //these two APIs can be used together or in combination with the others. virtual RID texture_2d_placeholder_create() = 0; - virtual RID texture_2d_layered_placeholder_create() = 0; + virtual RID texture_2d_layered_placeholder_create(TextureLayeredType p_layered_type) = 0; virtual RID texture_3d_placeholder_create() = 0; virtual Ref texture_2d_get(RID p_texture) const = 0; @@ -522,19 +523,20 @@ public: virtual void gi_probe_set_anisotropy_strength(RID p_gi_probe, float p_strength) = 0; virtual float gi_probe_get_anisotropy_strength(RID p_gi_probe) const = 0; - /* LIGHTMAP CAPTURE */ + /* LIGHTMAP */ - virtual RID lightmap_capture_create() = 0; - virtual void lightmap_capture_set_bounds(RID p_capture, const AABB &p_bounds) = 0; - virtual AABB lightmap_capture_get_bounds(RID p_capture) const = 0; - virtual void lightmap_capture_set_octree(RID p_capture, const Vector &p_octree) = 0; - virtual void lightmap_capture_set_octree_cell_transform(RID p_capture, const Transform &p_xform) = 0; - virtual Transform lightmap_capture_get_octree_cell_transform(RID p_capture) const = 0; - virtual void lightmap_capture_set_octree_cell_subdiv(RID p_capture, int p_subdiv) = 0; - virtual int lightmap_capture_get_octree_cell_subdiv(RID p_capture) const = 0; - virtual Vector lightmap_capture_get_octree(RID p_capture) const = 0; - virtual void lightmap_capture_set_energy(RID p_capture, float p_energy) = 0; - virtual float lightmap_capture_get_energy(RID p_capture) const = 0; + virtual RID lightmap_create() = 0; + + virtual void lightmap_set_textures(RID p_lightmap, RID p_light, bool p_uses_spherical_haromics) = 0; + virtual void lightmap_set_probe_bounds(RID p_lightmap, const AABB &p_bounds) = 0; + virtual void lightmap_set_probe_interior(RID p_lightmap, bool p_interior) = 0; + virtual void lightmap_set_probe_capture_data(RID p_lightmap, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree) = 0; + virtual PackedVector3Array lightmap_get_probe_capture_points(RID p_lightmap) const = 0; + virtual PackedColorArray lightmap_get_probe_capture_sh(RID p_lightmap) const = 0; + virtual PackedInt32Array lightmap_get_probe_capture_tetrahedra(RID p_lightmap) const = 0; + virtual PackedInt32Array lightmap_get_probe_capture_bsp_tree(RID p_lightmap) const = 0; + + virtual void lightmap_set_probe_capture_update_speed(float p_speed) = 0; /* PARTICLES API */ @@ -713,6 +715,7 @@ public: virtual void sky_set_radiance_size(RID p_sky, int p_radiance_size) = 0; virtual void sky_set_mode(RID p_sky, SkyMode p_mode) = 0; virtual void sky_set_material(RID p_sky, RID p_material) = 0; + virtual Ref sky_bake_panorama(RID p_sky, float p_energy, bool p_bake_irradiance, const Size2i &p_size) = 0; /* ENVIRONMENT API */ @@ -809,6 +812,8 @@ public: virtual void environment_set_fog_depth(RID p_env, bool p_enable, float p_depth_begin, float p_depth_end, float p_depth_curve, bool p_transmit, float p_transmit_curve) = 0; virtual void environment_set_fog_height(RID p_env, bool p_enable, float p_min_height, float p_max_height, float p_height_curve) = 0; + virtual Ref environment_bake_panorama(RID p_env, bool p_bake_irradiance, const Size2i &p_size) = 0; + virtual void screen_space_roughness_limiter_set_active(bool p_enable, float p_curve) = 0; enum SubSurfaceScatteringQuality { @@ -885,7 +890,7 @@ public: INSTANCE_REFLECTION_PROBE, INSTANCE_DECAL, INSTANCE_GI_PROBE, - INSTANCE_LIGHTMAP_CAPTURE, + INSTANCE_LIGHTMAP, INSTANCE_MAX, INSTANCE_GEOMETRY_MASK = (1 << INSTANCE_MESH) | (1 << INSTANCE_MULTIMESH) | (1 << INSTANCE_IMMEDIATE) | (1 << INSTANCE_PARTICLES) @@ -904,8 +909,6 @@ public: virtual void instance_set_surface_material(RID p_instance, int p_surface, RID p_material) = 0; virtual void instance_set_visible(RID p_instance, bool p_visible) = 0; - virtual void instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap) = 0; - virtual void instance_set_custom_aabb(RID p_instance, AABB aabb) = 0; virtual void instance_attach_skeleton(RID p_instance, RID p_skeleton) = 0; @@ -942,12 +945,24 @@ public: virtual void instance_geometry_set_draw_range(RID p_instance, float p_min, float p_max, float p_min_margin, float p_max_margin) = 0; virtual void instance_geometry_set_as_instance_lod(RID p_instance, RID p_as_lod_of_instance) = 0; + virtual void instance_geometry_set_lightmap(RID p_instance, RID p_lightmap, const Rect2 &p_lightmap_uv_scale, int p_lightmap_slice) = 0; virtual void instance_geometry_set_shader_parameter(RID p_instance, const StringName &, const Variant &p_value) = 0; virtual Variant instance_geometry_get_shader_parameter(RID p_instance, const StringName &) const = 0; virtual Variant instance_geometry_get_shader_parameter_default_value(RID p_instance, const StringName &) const = 0; virtual void instance_geometry_get_shader_parameter_list(RID p_instance, List *p_parameters) const = 0; + /* Bake 3D objects */ + + enum BakeChannels { + BAKE_CHANNEL_ALBEDO_ALPHA, + BAKE_CHANNEL_NORMAL, + BAKE_CHANNEL_ORM, + BAKE_CHANNEL_EMISSION + }; + + virtual TypedArray bake_render_uv2(RID p_base, const Vector &p_material_overrides, const Size2i &p_image_size) = 0; + /* CANVAS (2D) */ virtual RID canvas_create() = 0; @@ -1186,8 +1201,6 @@ public: virtual Vector get_frame_profile() = 0; virtual uint64_t get_frame_profile_frame() = 0; - /* Materials for 2D on 3D */ - /* TESTING */ virtual RID get_test_cube() = 0; diff --git a/thirdparty/oidn/.gitignore b/thirdparty/oidn/.gitignore new file mode 100644 index 0000000000..6be206fc29 --- /dev/null +++ b/thirdparty/oidn/.gitignore @@ -0,0 +1 @@ +weights/rtlightmap_hdr.cpp diff --git a/thirdparty/oidn/common/barrier.h b/thirdparty/oidn/common/barrier.h new file mode 100644 index 0000000000..b20f670053 --- /dev/null +++ b/thirdparty/oidn/common/barrier.h @@ -0,0 +1,52 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" +#include +#include + +namespace oidn { + + class Barrier + { + private: + std::mutex m; + std::condition_variable cv; + volatile int count; + + public: + Barrier(int count) : count(count) {} + + void wait() + { + std::unique_lock lk(m); + count--; + + if (count == 0) + { + lk.unlock(); + cv.notify_all(); + } + else + { + cv.wait(lk, [&]{ return count == 0; }); + } + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/common/exception.h b/thirdparty/oidn/common/exception.h new file mode 100644 index 0000000000..18069c6a7d --- /dev/null +++ b/thirdparty/oidn/common/exception.h @@ -0,0 +1,45 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include +#include "platform.h" + +namespace oidn { + + class Exception : public std::exception + { + private: + Error error; + const char* message; + + public: + Exception(Error error, const char* message) + : error(error), message(message) {} + + Error code() const noexcept + { + return error; + } + + const char* what() const noexcept override + { + return message; + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/common/platform.cpp b/thirdparty/oidn/common/platform.cpp new file mode 100644 index 0000000000..59a14ff47c --- /dev/null +++ b/thirdparty/oidn/common/platform.cpp @@ -0,0 +1,114 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "platform.h" + +namespace oidn { + + // ---------------------------------------------------------------------------- + // Common functions + // ---------------------------------------------------------------------------- + + void* alignedMalloc(size_t size, size_t alignment) + { + if (size == 0) + return nullptr; + + assert((alignment & (alignment-1)) == 0); + void* ptr = _mm_malloc(size, alignment); + + if (ptr == nullptr) + throw std::bad_alloc(); + + return ptr; + } + + void alignedFree(void* ptr) + { + if (ptr) + _mm_free(ptr); + } + + // ---------------------------------------------------------------------------- + // System information + // ---------------------------------------------------------------------------- + + std::string getPlatformName() + { + std::string name; + + #if defined(__linux__) + name = "Linux"; + #elif defined(__FreeBSD__) + name = "FreeBSD"; + #elif defined(__CYGWIN__) + name = "Cygwin"; + #elif defined(_WIN32) + name = "Windows"; + #elif defined(__APPLE__) + name = "macOS"; + #elif defined(__unix__) + name = "Unix"; + #else + return "Unknown"; + #endif + + #if defined(__x86_64__) || defined(_M_X64) || defined(__ia64__) || defined(__aarch64__) + name += " (64-bit)"; + #else + name += " (32-bit)"; + #endif + + return name; + } + + std::string getCompilerName() + { + #if defined(__INTEL_COMPILER) + int mayor = __INTEL_COMPILER / 100 % 100; + int minor = __INTEL_COMPILER % 100; + std::string version = "Intel Compiler "; + version += toString(mayor); + version += "." + toString(minor); + #if defined(__INTEL_COMPILER_UPDATE) + version += "." + toString(__INTEL_COMPILER_UPDATE); + #endif + return version; + #elif defined(__clang__) + return "Clang " __clang_version__; + #elif defined(__GNUC__) + return "GCC " __VERSION__; + #elif defined(_MSC_VER) + std::string version = toString(_MSC_FULL_VER); + version.insert(4, "."); + version.insert(9, "."); + version.insert(2, "."); + return "Visual C++ Compiler " + version; + #else + return "Unknown"; + #endif + } + + std::string getBuildName() + { + #if defined(NDEBUG) + return "Release"; + #else + return "Debug"; + #endif + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/platform.h b/thirdparty/oidn/common/platform.h new file mode 100644 index 0000000000..205ac8981d --- /dev/null +++ b/thirdparty/oidn/common/platform.h @@ -0,0 +1,131 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #define NOMINMAX + #include +#elif defined(__APPLE__) + #include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "include/OpenImageDenoise/oidn.hpp" + +namespace oidn { + + // ---------------------------------------------------------------------------- + // Macros + // ---------------------------------------------------------------------------- + + #if defined(_WIN32) + // Windows + #if !defined(__noinline) + #define __noinline __declspec(noinline) + #endif + #else + // Unix + #if !defined(__forceinline) + #define __forceinline inline __attribute__((always_inline)) + #endif + #if !defined(__noinline) + #define __noinline __attribute__((noinline)) + #endif + #endif + + #ifndef UNUSED + #define UNUSED(x) ((void)x) + #endif + #ifndef MAYBE_UNUSED + #define MAYBE_UNUSED(x) UNUSED(x) + #endif + + // ---------------------------------------------------------------------------- + // Error handling and debugging + // ---------------------------------------------------------------------------- + + struct Verbose + { + int verbose; + + Verbose(int v = 0) : verbose(v) {} + __forceinline bool isVerbose(int v = 1) const { return v <= verbose; } + }; + + #define OIDN_WARNING(message) { if (isVerbose()) std::cerr << "Warning: " << message << std::endl; } + #define OIDN_FATAL(message) throw std::runtime_error(message); + + // ---------------------------------------------------------------------------- + // Common functions + // ---------------------------------------------------------------------------- + + using std::min; + using std::max; + + template + __forceinline T clamp(const T& value, const T& minValue, const T& maxValue) + { + return min(max(value, minValue), maxValue); + } + + void* alignedMalloc(size_t size, size_t alignment); + void alignedFree(void* ptr); + + template + inline std::string toString(const T& a) + { + std::stringstream sm; + sm << a; + return sm.str(); + } + +#if defined(__APPLE__) + template + bool getSysctl(const char* name, T& value) + { + int64_t result = 0; + size_t size = sizeof(result); + + if (sysctlbyname(name, &result, &size, nullptr, 0) != 0) + return false; + + value = T(result); + return true; + } +#endif + + // ---------------------------------------------------------------------------- + // System information + // ---------------------------------------------------------------------------- + + std::string getPlatformName(); + std::string getCompilerName(); + std::string getBuildName(); + +} // namespace oidn diff --git a/thirdparty/oidn/common/ref.h b/thirdparty/oidn/common/ref.h new file mode 100644 index 0000000000..de44603af2 --- /dev/null +++ b/thirdparty/oidn/common/ref.h @@ -0,0 +1,163 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" + +namespace oidn { + + class RefCount + { + private: + std::atomic count; + + public: + __forceinline RefCount(int count = 0) noexcept : count(count) {} + + __forceinline size_t incRef() noexcept + { + return count.fetch_add(1) + 1; + } + + __forceinline size_t decRef() + { + const size_t newCount = decRefKeep(); + if (newCount == 0) + destroy(); + return newCount; + } + + __forceinline size_t decRefKeep() noexcept + { + return count.fetch_add(-1) - 1; + } + + __forceinline void destroy() + { + delete this; + } + + protected: + // Disable copying + RefCount(const RefCount&) = delete; + RefCount& operator =(const RefCount&) = delete; + + virtual ~RefCount() noexcept = default; + }; + + template + class Ref + { + private: + T* ptr; + + public: + __forceinline Ref() noexcept : ptr(nullptr) {} + __forceinline Ref(std::nullptr_t) noexcept : ptr(nullptr) {} + __forceinline Ref(const Ref& other) noexcept : ptr(other.ptr) { if (ptr) ptr->incRef(); } + __forceinline Ref(Ref&& other) noexcept : ptr(other.ptr) { other.ptr = nullptr; } + __forceinline Ref(T* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); } + + template + __forceinline Ref(const Ref& other) noexcept : ptr(other.get()) { if (ptr) ptr->incRef(); } + + template + __forceinline explicit Ref(Y* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); } + + __forceinline ~Ref() { if (ptr) ptr->decRef(); } + + __forceinline Ref& operator =(const Ref& other) + { + if (other.ptr) + other.ptr->incRef(); + if (ptr) + ptr->decRef(); + ptr = other.ptr; + return *this; + } + + __forceinline Ref& operator =(Ref&& other) + { + if (ptr) + ptr->decRef(); + ptr = other.ptr; + other.ptr = nullptr; + return *this; + } + + __forceinline Ref& operator =(T* other) + { + if (other) + other->incRef(); + if (ptr) + ptr->decRef(); + ptr = other; + return *this; + } + + __forceinline Ref& operator =(std::nullptr_t) + { + if (ptr) + ptr->decRef(); + ptr = nullptr; + return *this; + } + + __forceinline operator bool() const noexcept { return ptr != nullptr; } + + __forceinline T& operator *() const noexcept { return *ptr; } + __forceinline T* operator ->() const noexcept { return ptr; } + + __forceinline T* get() const noexcept { return ptr; } + + __forceinline T* detach() noexcept + { + T* res = ptr; + ptr = nullptr; + return res; + } + }; + + template __forceinline bool operator < (const Ref& a, const Ref& b) noexcept { return a.ptr < b.ptr; } + + template __forceinline bool operator ==(const Ref& a, std::nullptr_t) noexcept { return a.ptr == nullptr; } + template __forceinline bool operator ==(std::nullptr_t, const Ref& b) noexcept { return nullptr == b.ptr; } + template __forceinline bool operator ==(const Ref& a, const Ref& b) noexcept { return a.ptr == b.ptr; } + + template __forceinline bool operator !=(const Ref& a, std::nullptr_t) noexcept { return a.ptr != nullptr; } + template __forceinline bool operator !=(std::nullptr_t, const Ref& b) noexcept { return nullptr != b.ptr; } + template __forceinline bool operator !=(const Ref& a, const Ref& b) noexcept { return a.ptr != b.ptr; } + + template + __forceinline Ref makeRef(Args&&... args) + { + return Ref(new T(std::forward(args)...)); + } + + template + __forceinline Ref staticRefCast(const Ref& a) + { + return Ref(static_cast(a.get())); + } + + template + __forceinline Ref dynamicRefCast(const Ref& a) + { + return Ref(dynamic_cast(a.get())); + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/tensor.cpp b/thirdparty/oidn/common/tensor.cpp new file mode 100644 index 0000000000..0249f2e141 --- /dev/null +++ b/thirdparty/oidn/common/tensor.cpp @@ -0,0 +1,83 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "exception.h" +#include "tensor.h" + +namespace oidn { + + std::map parseTensors(void* buffer) + { + char* input = (char*)buffer; + + // Parse the magic value + const int magic = *(unsigned short*)input; + if (magic != 0x41D7) + throw Exception(Error::InvalidOperation, "invalid tensor archive"); + input += sizeof(unsigned short); + + // Parse the version + const int majorVersion = *(unsigned char*)input++; + const int minorVersion = *(unsigned char*)input++; + UNUSED(minorVersion); + if (majorVersion > 1) + throw Exception(Error::InvalidOperation, "unsupported tensor archive version"); + + // Parse the number of tensors + const int numTensors = *(int*)input; + input += sizeof(int); + + // Parse the tensors + std::map tensorMap; + for (int i = 0; i < numTensors; ++i) + { + Tensor tensor; + + // Parse the name + const int nameLen = *(unsigned char*)input++; + std::string name(input, nameLen); + input += nameLen; + + // Parse the number of dimensions + const int ndims = *(unsigned char*)input++; + + // Parse the shape of the tensor + tensor.dims.resize(ndims); + for (int i = 0; i < ndims; ++i) + tensor.dims[i] = ((int*)input)[i]; + input += ndims * sizeof(int); + + // Parse the format of the tensor + tensor.format = std::string(input, input + ndims); + input += ndims; + + // Parse the data type of the tensor + const char type = *(unsigned char*)input++; + if (type != 'f') // only float32 is supported + throw Exception(Error::InvalidOperation, "unsupported tensor data type"); + + // Skip the data + tensor.data = (float*)input; + input += tensor.size() * sizeof(float); + + // Add the tensor to the map + tensorMap.emplace(name, std::move(tensor)); + } + + return tensorMap; + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/tensor.h b/thirdparty/oidn/common/tensor.h new file mode 100644 index 0000000000..48e7d1123d --- /dev/null +++ b/thirdparty/oidn/common/tensor.h @@ -0,0 +1,66 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" +#include +#include + +namespace oidn { + + template + using shared_vector = std::shared_ptr>; + + // Generic tensor + struct Tensor + { + float* data; + std::vector dims; + std::string format; + shared_vector buffer; // optional, only for reference counting + + __forceinline Tensor() : data(nullptr) {} + + __forceinline Tensor(const std::vector& dims, const std::string& format) + : dims(dims), + format(format) + { + buffer = std::make_shared>(size() * sizeof(float)); + data = (float*)buffer->data(); + } + + __forceinline operator bool() const { return data != nullptr; } + + __forceinline int ndims() const { return (int)dims.size(); } + + // Returns the number of values + __forceinline size_t size() const + { + size_t size = 1; + for (int i = 0; i < ndims(); ++i) + size *= dims[i]; + return size; + } + + __forceinline float& operator [](size_t i) { return data[i]; } + __forceinline const float& operator [](size_t i) const { return data[i]; } + }; + + // Parses tensors from a buffer + std::map parseTensors(void* buffer); + +} // namespace oidn diff --git a/thirdparty/oidn/common/thread.cpp b/thirdparty/oidn/common/thread.cpp new file mode 100644 index 0000000000..48c489c57b --- /dev/null +++ b/thirdparty/oidn/common/thread.cpp @@ -0,0 +1,297 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#if defined(_MSC_VER) + #pragma warning (disable : 4146) // unary minus operator applied to unsigned type, result still unsigned +#endif + +#if defined(__APPLE__) + #include + #include +#endif + +#include "thread.h" +#include + +namespace oidn { + +#if defined(_WIN32) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Windows + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + HMODULE hLib = GetModuleHandle(TEXT("kernel32")); + pGetLogicalProcessorInformationEx = (GetLogicalProcessorInformationExFunc)GetProcAddress(hLib, "GetLogicalProcessorInformationEx"); + pSetThreadGroupAffinity = (SetThreadGroupAffinityFunc)GetProcAddress(hLib, "SetThreadGroupAffinity"); + + if (pGetLogicalProcessorInformationEx && pSetThreadGroupAffinity) + { + // Get logical processor information + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX buffer = nullptr; + DWORD bufferSize = 0; + + // First call the function with an empty buffer to get the required buffer size + BOOL result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize); + if (result || GetLastError() != ERROR_INSUFFICIENT_BUFFER) + { + OIDN_WARNING("GetLogicalProcessorInformationEx failed"); + return; + } + + // Allocate the buffer + buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)malloc(bufferSize); + if (!buffer) + { + OIDN_WARNING("SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX allocation failed"); + return; + } + + // Call again the function but now with the properly sized buffer + result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize); + if (!result) + { + OIDN_WARNING("GetLogicalProcessorInformationEx failed"); + free(buffer); + return; + } + + // Iterate over the logical processor information structures + // There should be one structure for each physical core + char* ptr = (char*)buffer; + while (ptr < (char*)buffer + bufferSize) + { + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX item = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)ptr; + if (item->Relationship == RelationProcessorCore && item->Processor.GroupCount > 0) + { + // Iterate over the groups + int numThreads = 0; + for (int group = 0; (group < item->Processor.GroupCount) && (numThreads < numThreadsPerCore); ++group) + { + GROUP_AFFINITY coreAffinity = item->Processor.GroupMask[group]; + while ((coreAffinity.Mask != 0) && (numThreads < numThreadsPerCore)) + { + // Extract the next set bit/thread from the mask + GROUP_AFFINITY threadAffinity = coreAffinity; + threadAffinity.Mask = threadAffinity.Mask & -threadAffinity.Mask; + + // Push the affinity for this thread + affinities.push_back(threadAffinity); + oldAffinities.push_back(threadAffinity); + numThreads++; + + // Remove this bit/thread from the mask + coreAffinity.Mask ^= threadAffinity.Mask; + } + } + } + + // Next structure + ptr += item->Size; + } + + // Free the buffer + free(buffer); + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + // Save the current affinity and set the new one + const HANDLE thread = GetCurrentThread(); + if (!pSetThreadGroupAffinity(thread, &affinities[threadIndex], &oldAffinities[threadIndex])) + OIDN_WARNING("SetThreadGroupAffinity failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + // Restore the original affinity + const HANDLE thread = GetCurrentThread(); + if (!pSetThreadGroupAffinity(thread, &oldAffinities[threadIndex], nullptr)) + OIDN_WARNING("SetThreadGroupAffinity failed"); + } + +#elif defined(__linux__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Linux + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + std::vector threadIds; + + // Parse the thread/CPU topology + for (int cpuId = 0; ; cpuId++) + { + std::fstream fs; + std::string cpu = std::string("/sys/devices/system/cpu/cpu") + std::to_string(cpuId) + std::string("/topology/thread_siblings_list"); + fs.open(cpu.c_str(), std::fstream::in); + if (fs.fail()) break; + + int i; + int j = 0; + while ((j < numThreadsPerCore) && (fs >> i)) + { + if (std::none_of(threadIds.begin(), threadIds.end(), [&](int id) { return id == i; })) + threadIds.push_back(i); + + if (fs.peek() == ',') + fs.ignore(); + j++; + } + + fs.close(); + } + + #if 0 + for (size_t i = 0; i < thread_ids.size(); ++i) + std::cout << "thread " << i << " -> " << thread_ids[i] << std::endl; + #endif + + // Create the affinity structures + affinities.resize(threadIds.size()); + oldAffinities.resize(threadIds.size()); + + for (size_t i = 0; i < threadIds.size(); ++i) + { + cpu_set_t affinity; + CPU_ZERO(&affinity); + CPU_SET(threadIds[i], &affinity); + + affinities[i] = affinity; + oldAffinities[i] = affinity; + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const pthread_t thread = pthread_self(); + + // Save the current affinity + if (pthread_getaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0) + { + OIDN_WARNING("pthread_getaffinity_np failed"); + oldAffinities[threadIndex] = affinities[threadIndex]; + return; + } + + // Set the new affinity + if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &affinities[threadIndex]) != 0) + OIDN_WARNING("pthread_setaffinity_np failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const pthread_t thread = pthread_self(); + + // Restore the original affinity + if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0) + OIDN_WARNING("pthread_setaffinity_np failed"); + } + +#elif defined(__APPLE__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - macOS + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + // Query the thread/CPU topology + int numPhysicalCpus; + int numLogicalCpus; + + if (!getSysctl("hw.physicalcpu", numPhysicalCpus) || !getSysctl("hw.logicalcpu", numLogicalCpus)) + { + OIDN_WARNING("sysctlbyname failed"); + return; + } + + if ((numLogicalCpus % numPhysicalCpus != 0) && (numThreadsPerCore > 1)) + return; // this shouldn't happen + const int maxThreadsPerCore = numLogicalCpus / numPhysicalCpus; + + // Create the affinity structures + // macOS doesn't support binding a thread to a specific core, but we can at least group threads which + // should be on the same core together + for (int core = 1; core <= numPhysicalCpus; ++core) // tags start from 1! + { + thread_affinity_policy affinity; + affinity.affinity_tag = core; + + for (int thread = 0; thread < min(numThreadsPerCore, maxThreadsPerCore); ++thread) + { + affinities.push_back(affinity); + oldAffinities.push_back(affinity); + } + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const auto thread = mach_thread_self(); + + // Save the current affinity + mach_msg_type_number_t policyCount = THREAD_AFFINITY_POLICY_COUNT; + boolean_t getDefault = FALSE; + if (thread_policy_get(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], &policyCount, &getDefault) != KERN_SUCCESS) + { + OIDN_WARNING("thread_policy_get failed"); + oldAffinities[threadIndex] = affinities[threadIndex]; + return; + } + + // Set the new affinity + if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&affinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS) + OIDN_WARNING("thread_policy_set failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const auto thread = mach_thread_self(); + + // Restore the original affinity + if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS) + OIDN_WARNING("thread_policy_set failed"); + } + +#endif + +} // namespace oidn diff --git a/thirdparty/oidn/common/thread.h b/thirdparty/oidn/common/thread.h new file mode 100644 index 0000000000..2c731367da --- /dev/null +++ b/thirdparty/oidn/common/thread.h @@ -0,0 +1,202 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" + +#if !defined(_WIN32) + #include + #include + #if defined(__APPLE__) + #include + #endif +#endif + +#include +#include + +namespace oidn { + + // -------------------------------------------------------------------------- + // ThreadLocal + // -------------------------------------------------------------------------- + + // Wrapper which makes any variable thread-local + template + class ThreadLocal : public Verbose + { + private: + #if defined(_WIN32) + DWORD key; + #else + pthread_key_t key; + #endif + + std::vector instances; + std::mutex mutex; + + public: + ThreadLocal(int verbose = 0) + : Verbose(verbose) + { + #if defined(_WIN32) + key = TlsAlloc(); + if (key == TLS_OUT_OF_INDEXES) + OIDN_FATAL("TlsAlloc failed"); + #else + if (pthread_key_create(&key, nullptr) != 0) + OIDN_FATAL("pthread_key_create failed"); + #endif + } + + ~ThreadLocal() + { + std::lock_guard lock(mutex); + for (T* ptr : instances) + delete ptr; + + #if defined(_WIN32) + if (!TlsFree(key)) + OIDN_WARNING("TlsFree failed"); + #else + if (pthread_key_delete(key) != 0) + OIDN_WARNING("pthread_key_delete failed"); + #endif + } + + T& get() + { + #if defined(_WIN32) + T* ptr = (T*)TlsGetValue(key); + #else + T* ptr = (T*)pthread_getspecific(key); + #endif + + if (ptr) + return *ptr; + + ptr = new T; + std::lock_guard lock(mutex); + instances.push_back(ptr); + + #if defined(_WIN32) + if (!TlsSetValue(key, ptr)) + OIDN_FATAL("TlsSetValue failed"); + #else + if (pthread_setspecific(key, ptr) != 0) + OIDN_FATAL("pthread_setspecific failed"); + #endif + + return *ptr; + } + }; + +#if defined(_WIN32) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Windows + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + typedef BOOL (WINAPI *GetLogicalProcessorInformationExFunc)(LOGICAL_PROCESSOR_RELATIONSHIP, + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, + PDWORD); + + typedef BOOL (WINAPI *SetThreadGroupAffinityFunc)(HANDLE, + CONST GROUP_AFFINITY*, + PGROUP_AFFINITY); + + GetLogicalProcessorInformationExFunc pGetLogicalProcessorInformationEx = nullptr; + SetThreadGroupAffinityFunc pSetThreadGroupAffinity = nullptr; + + std::vector affinities; // thread affinities + std::vector oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int threadIndex); + }; + +#elif defined(__linux__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Linux + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + std::vector affinities; // thread affinities + std::vector oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int threadIndex); + }; + +#elif defined(__APPLE__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - macOS + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + std::vector affinities; // thread affinities + std::vector oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int threadIndex); + }; + +#endif + +} // namespace oidn diff --git a/thirdparty/oidn/common/timer.h b/thirdparty/oidn/common/timer.h new file mode 100644 index 0000000000..62aaaa1c33 --- /dev/null +++ b/thirdparty/oidn/common/timer.h @@ -0,0 +1,49 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" +#include + +namespace oidn { + + class Timer + { + private: + using clock = std::chrono::high_resolution_clock; + + std::chrono::time_point start; + + public: + Timer() + { + reset(); + } + + void reset() + { + start = clock::now(); + } + + double query() const + { + auto end = clock::now(); + return std::chrono::duration_cast>(end - start).count(); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/api.cpp b/thirdparty/oidn/core/api.cpp new file mode 100644 index 0000000000..7353fe4e25 --- /dev/null +++ b/thirdparty/oidn/core/api.cpp @@ -0,0 +1,408 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#ifdef _WIN32 +# define OIDN_API extern "C" __declspec(dllexport) +#else +# define OIDN_API extern "C" __attribute__ ((visibility ("default"))) +#endif + +// Locks the device that owns the specified object +// Use *only* inside OIDN_TRY/CATCH! +#define OIDN_LOCK(obj) \ + std::lock_guard lock(obj->getDevice()->getMutex()); + +// Try/catch for converting exceptions to errors +#define OIDN_TRY \ + try { + +#define OIDN_CATCH(obj) \ + } catch (Exception& e) { \ + Device::setError(obj ? obj->getDevice() : nullptr, e.code(), e.what()); \ + } catch (std::bad_alloc&) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \ + } catch (mkldnn::error& e) { \ + if (e.status == mkldnn_out_of_memory) \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \ + else \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.message); \ + } catch (std::exception& e) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.what()); \ + } catch (...) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \ + } + +#include "device.h" +#include "filter.h" +#include + +namespace oidn { + + namespace + { + __forceinline void checkHandle(void* handle) + { + if (handle == nullptr) + throw Exception(Error::InvalidArgument, "invalid handle"); + } + + template + __forceinline void retainObject(T* obj) + { + if (obj) + { + obj->incRef(); + } + else + { + OIDN_TRY + checkHandle(obj); + OIDN_CATCH(obj) + } + } + + template + __forceinline void releaseObject(T* obj) + { + if (obj == nullptr || obj->decRefKeep() == 0) + { + OIDN_TRY + checkHandle(obj); + OIDN_LOCK(obj); + obj->destroy(); + OIDN_CATCH(obj) + } + } + + template<> + __forceinline void releaseObject(Device* obj) + { + if (obj == nullptr || obj->decRefKeep() == 0) + { + OIDN_TRY + checkHandle(obj); + // Do NOT lock the device because it owns the mutex + obj->destroy(); + OIDN_CATCH(obj) + } + } + } + + OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type) + { + Ref device = nullptr; + OIDN_TRY + if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT) + device = makeRef(); + else + throw Exception(Error::InvalidArgument, "invalid device type"); + OIDN_CATCH(device) + return (OIDNDevice)device.detach(); + } + + OIDN_API void oidnRetainDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + retainObject(device); + } + + OIDN_API void oidnReleaseDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + releaseObject(device); + } + + OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->set1i(name, value); + OIDN_CATCH(device) + } + + OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->set1i(name, value); + OIDN_CATCH(device) + } + + OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + return device->get1i(name); + OIDN_CATCH(device) + return false; + } + + OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + return device->get1i(name); + OIDN_CATCH(device) + return 0; + } + + OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->setErrorFunction((ErrorFunction)func, userPtr); + OIDN_CATCH(device) + } + + OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage) + { + Device* device = (Device*)hDevice; + OIDN_TRY + return (OIDNError)Device::getError(device, outMessage); + OIDN_CATCH(device) + if (outMessage) *outMessage = ""; + return OIDN_ERROR_UNKNOWN; + } + + OIDN_API void oidnCommitDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->commit(); + OIDN_CATCH(device) + } + + OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref buffer = device->newBuffer(byteSize); + return (OIDNBuffer)buffer.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref buffer = device->newBuffer(ptr, byteSize); + return (OIDNBuffer)buffer.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer) + { + Buffer* buffer = (Buffer*)hBuffer; + retainObject(buffer); + } + + OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer) + { + Buffer* buffer = (Buffer*)hBuffer; + releaseObject(buffer); + } + + OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize) + { + Buffer* buffer = (Buffer*)hBuffer; + OIDN_TRY + checkHandle(hBuffer); + OIDN_LOCK(buffer); + return buffer->map(byteOffset, byteSize); + OIDN_CATCH(buffer) + return nullptr; + } + + OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr) + { + Buffer* buffer = (Buffer*)hBuffer; + OIDN_TRY + checkHandle(hBuffer); + OIDN_LOCK(buffer); + return buffer->unmap(mappedPtr); + OIDN_CATCH(buffer) + } + + OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref filter = device->newFilter(type); + return (OIDNFilter)filter.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API void oidnRetainFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + retainObject(filter); + } + + OIDN_API void oidnReleaseFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + releaseObject(filter); + } + + OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name, + OIDNBuffer hBuffer, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + checkHandle(hBuffer); + OIDN_LOCK(filter); + Ref buffer = (Buffer*)hBuffer; + if (buffer->getDevice() != filter->getDevice()) + throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices"); + Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); + filter->setImage(name, data); + OIDN_CATCH(filter) + } + + OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name, + void* ptr, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); + filter->setImage(name, data); + OIDN_CATCH(filter) + } + + OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1i(name, int(value)); + OIDN_CATCH(filter) + } + + OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1i(name); + OIDN_CATCH(filter) + return false; + } + + OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1i(name, value); + OIDN_CATCH(filter) + } + + OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1i(name); + OIDN_CATCH(filter) + return 0; + } + + OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1f(name, value); + OIDN_CATCH(filter) + } + + OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1f(name); + OIDN_CATCH(filter) + return 0; + } + + OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->setProgressMonitorFunction(func, userPtr); + OIDN_CATCH(filter) + } + + OIDN_API void oidnCommitFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->commit(); + OIDN_CATCH(filter) + } + + OIDN_API void oidnExecuteFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->execute(); + OIDN_CATCH(filter) + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/autoencoder.cpp b/thirdparty/oidn/core/autoencoder.cpp new file mode 100644 index 0000000000..8ae2421fa6 --- /dev/null +++ b/thirdparty/oidn/core/autoencoder.cpp @@ -0,0 +1,519 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "autoencoder.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // AutoencoderFilter + // -------------------------------------------------------------------------- + + AutoencoderFilter::AutoencoderFilter(const Ref& device) + : Filter(device) + { + } + + void AutoencoderFilter::setImage(const std::string& name, const Image& data) + { + if (name == "color") + color = data; + else if (name == "albedo") + albedo = data; + else if (name == "normal") + normal = data; + else if (name == "output") + output = data; + + dirty = true; + } + + void AutoencoderFilter::set1i(const std::string& name, int value) + { + if (name == "hdr") + hdr = value; + else if (name == "srgb") + srgb = value; + else if (name == "maxMemoryMB") + maxMemoryMB = value; + + dirty = true; + } + + int AutoencoderFilter::get1i(const std::string& name) + { + if (name == "hdr") + return hdr; + else if (name == "srgb") + return srgb; + else if (name == "maxMemoryMB") + return maxMemoryMB; + else if (name == "alignment") + return alignment; + else if (name == "overlap") + return overlap; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void AutoencoderFilter::set1f(const std::string& name, float value) + { + if (name == "hdrScale") + hdrScale = value; + + dirty = true; + } + + float AutoencoderFilter::get1f(const std::string& name) + { + if (name == "hdrScale") + return hdrScale; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void AutoencoderFilter::commit() + { + if (!dirty) + return; + + { + if (mayiuse(avx512_common)) + net = buildNet<16>(); + else + net = buildNet<8>(); + } + + dirty = false; + } + + void AutoencoderFilter::execute() + { + if (dirty) + throw Exception(Error::InvalidOperation, "changes to the filter are not committed"); + + if (!net) + return; + + { + Progress progress; + progress.func = progressFunc; + progress.userPtr = progressUserPtr; + progress.taskCount = tileCountH * tileCountW; + + // Iterate over the tiles + int tileIndex = 0; + + for (int i = 0; i < tileCountH; ++i) + { + const int h = i * (tileH - 2*overlap); // input tile position (including overlap) + const int overlapBeginH = i > 0 ? overlap : 0; // overlap on the top + const int overlapEndH = i < tileCountH-1 ? overlap : 0; // overlap on the bottom + const int tileH1 = min(H - h, tileH); // input tile size (including overlap) + const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size + const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer + + for (int j = 0; j < tileCountW; ++j) + { + const int w = j * (tileW - 2*overlap); // input tile position (including overlap) + const int overlapBeginW = j > 0 ? overlap : 0; // overlap on the left + const int overlapEndW = j < tileCountW-1 ? overlap : 0; // overlap on the right + const int tileW1 = min(W - w, tileW); // input tile size (including overlap) + const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size + const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer + + // Set the input tile + inputReorder->setTile(h, w, + alignOffsetH, alignOffsetW, + tileH1, tileW1); + + // Set the output tile + outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW, + h + overlapBeginH, w + overlapBeginW, + tileH2, tileW2); + + //printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2); + + // Denoise the tile + net->execute(progress, tileIndex); + + // Next tile + tileIndex++; + } + } + } + } + + void AutoencoderFilter::computeTileSize() + { + const int minTileSize = 3*overlap; + const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8; + const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel; + + tileCountH = 1; + tileCountW = 1; + tileH = roundUp(H, alignment); + tileW = roundUp(W, alignment); + + // Divide the image into tiles until the tile size gets below the threshold + while (int64_t(tileH) * tileW > maxTilePixels) + { + if (tileH > minTileSize && tileH > tileW) + { + tileCountH++; + tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize); + } + else if (tileW > minTileSize) + { + tileCountW++; + tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize); + } + else + break; + } + + // Compute the final number of tiles + tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1; + tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1; + + if (device->isVerbose(2)) + { + std::cout << "Tile size : " << tileW << "x" << tileH << std::endl; + std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl; + } + } + + template + std::shared_ptr AutoencoderFilter::buildNet() + { + H = color.height; + W = color.width; + + // Configure the network + int inputC; + void* weightPtr; + + if (srgb && hdr) + throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time"); + + if (color && !albedo && !normal && weightData.hdr) + { + inputC = 3; + weightPtr = hdr ? weightData.hdr : weightData.ldr; + } + else if (color && albedo && !normal && weightData.hdr_alb) + { + inputC = 6; + weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb; + } + else if (color && albedo && normal && weightData.hdr_alb_nrm) + { + inputC = 9; + weightPtr = hdr ? weightData.hdr_alb_nrm : weightData.ldr_alb_nrm; + } + else + { + throw Exception(Error::InvalidOperation, "unsupported combination of input features"); + } + + if (!output) + throw Exception(Error::InvalidOperation, "output image not specified"); + + if ((color.format != Format::Float3) + || (albedo && albedo.format != Format::Float3) + || (normal && normal.format != Format::Float3) + || (output.format != Format::Float3)) + throw Exception(Error::InvalidOperation, "unsupported image format"); + + if ((albedo && (albedo.width != W || albedo.height != H)) + || (normal && (normal.width != W || normal.height != H)) + || (output.width != W || output.height != H)) + throw Exception(Error::InvalidOperation, "image size mismatch"); + + // Compute the tile size + computeTileSize(); + + // If the image size is zero, there is nothing else to do + if (H <= 0 || W <= 0) + return nullptr; + + // Parse the weights + const auto weightMap = parseTensors(weightPtr); + + // Create the network + std::shared_ptr> net = std::make_shared>(device, weightMap); + + // Compute the tensor sizes + const auto inputDims = memory::dims({1, inputC, tileH, tileW}); + const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0 + + const auto conv1Dims = net->getConvDims("conv1", inputReorderDims); //-> temp0 + const auto conv1bDims = net->getConvDims("conv1b", conv1Dims); //-> temp1 + const auto pool1Dims = net->getPoolDims(conv1bDims); //-> concat1 + const auto conv2Dims = net->getConvDims("conv2", pool1Dims); //-> temp0 + const auto pool2Dims = net->getPoolDims(conv2Dims); //-> concat2 + const auto conv3Dims = net->getConvDims("conv3", pool2Dims); //-> temp0 + const auto pool3Dims = net->getPoolDims(conv3Dims); //-> concat3 + const auto conv4Dims = net->getConvDims("conv4", pool3Dims); //-> temp0 + const auto pool4Dims = net->getPoolDims(conv4Dims); //-> concat4 + const auto conv5Dims = net->getConvDims("conv5", pool4Dims); //-> temp0 + const auto pool5Dims = net->getPoolDims(conv5Dims); //-> temp1 + const auto upsample4Dims = net->getUpsampleDims(pool5Dims); //-> concat4 + const auto concat4Dims = net->getConcatDims(upsample4Dims, pool4Dims); + const auto conv6Dims = net->getConvDims("conv6", concat4Dims); //-> temp0 + const auto conv6bDims = net->getConvDims("conv6b", conv6Dims); //-> temp1 + const auto upsample3Dims = net->getUpsampleDims(conv6bDims); //-> concat3 + const auto concat3Dims = net->getConcatDims(upsample3Dims, pool3Dims); + const auto conv7Dims = net->getConvDims("conv7", concat3Dims); //-> temp0 + const auto conv7bDims = net->getConvDims("conv7b", conv7Dims); //-> temp1 + const auto upsample2Dims = net->getUpsampleDims(conv7bDims); //-> concat2 + const auto concat2Dims = net->getConcatDims(upsample2Dims, pool2Dims); + const auto conv8Dims = net->getConvDims("conv8", concat2Dims); //-> temp0 + const auto conv8bDims = net->getConvDims("conv8b", conv8Dims); //-> temp1 + const auto upsample1Dims = net->getUpsampleDims(conv8bDims); //-> concat1 + const auto concat1Dims = net->getConcatDims(upsample1Dims, pool1Dims); + const auto conv9Dims = net->getConvDims("conv9", concat1Dims); //-> temp0 + const auto conv9bDims = net->getConvDims("conv9b", conv9Dims); //-> temp1 + const auto upsample0Dims = net->getUpsampleDims(conv9bDims); //-> concat0 + const auto concat0Dims = net->getConcatDims(upsample0Dims, inputReorderDims); + const auto conv10Dims = net->getConvDims("conv10", concat0Dims); //-> temp0 + const auto conv10bDims = net->getConvDims("conv10b", conv10Dims); //-> temp1 + const auto conv11Dims = net->getConvDims("conv11", conv10bDims); //-> temp0 + + const auto outputDims = memory::dims({1, 3, tileH, tileW}); + + // Allocate two temporary ping-pong buffers to decrease memory usage + const auto temp0Dims = getMaxTensorDims({ + conv1Dims, + conv2Dims, + conv3Dims, + conv4Dims, + conv5Dims, + conv6Dims, + conv7Dims, + conv8Dims, + conv9Dims, + conv10Dims, + conv11Dims + }); + + const auto temp1Dims = getMaxTensorDims({ + conv1bDims, + pool5Dims, + conv6bDims, + conv7bDims, + conv8bDims, + conv9bDims, + conv10bDims, + }); + + auto temp0 = net->allocTensor(temp0Dims); + auto temp1 = net->allocTensor(temp1Dims); + + // Allocate enough memory to hold the concat outputs. Then use the first + // half to hold the previous conv output and the second half to hold the + // pool/orig image output. This works because everything is C dimension + // outermost, padded to K floats, and all the concats are on the C dimension. + auto concat0Dst = net->allocTensor(concat0Dims); + auto concat1Dst = net->allocTensor(concat1Dims); + auto concat2Dst = net->allocTensor(concat2Dims); + auto concat3Dst = net->allocTensor(concat3Dims); + auto concat4Dst = net->allocTensor(concat4Dims); + + // Transfer function + std::shared_ptr transferFunc = makeTransferFunc(); + + // Autoexposure + if (auto tf = std::dynamic_pointer_cast(transferFunc)) + { + if (isnan(hdrScale)) + net->addAutoexposure(color, tf); + else + tf->setExposure(hdrScale); + } + + // Input reorder + auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims); + inputReorder = net->addInputReorder(color, albedo, normal, + transferFunc, + alignment, inputReorderDst); + + // conv1 + auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0); + + // conv1b + auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1); + + // pool1 + // Adjust pointer for pool1 to eliminate concat1 + auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims); + auto pool1 = net->addPool(conv1b->getDst(), pool1Dst); + + // conv2 + auto conv2 = net->addConv("conv2", pool1->getDst(), temp0); + + // pool2 + // Adjust pointer for pool2 to eliminate concat2 + auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims); + auto pool2 = net->addPool(conv2->getDst(), pool2Dst); + + // conv3 + auto conv3 = net->addConv("conv3", pool2->getDst(), temp0); + + // pool3 + // Adjust pointer for pool3 to eliminate concat3 + auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims); + auto pool3 = net->addPool(conv3->getDst(), pool3Dst); + + // conv4 + auto conv4 = net->addConv("conv4", pool3->getDst(), temp0); + + // pool4 + // Adjust pointer for pool4 to eliminate concat4 + auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims); + auto pool4 = net->addPool(conv4->getDst(), pool4Dst); + + // conv5 + auto conv5 = net->addConv("conv5", pool4->getDst(), temp0); + + // pool5 + auto pool5 = net->addPool(conv5->getDst(), temp1); + + // upsample4 + auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst); + auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst); + + // conv6 + auto conv6 = net->addConv("conv6", concat4Dst, temp0); + + // conv6b + auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1); + + // upsample3 + auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst); + auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst); + + // conv7 + auto conv7 = net->addConv("conv7", concat3Dst, temp0); + + // conv7b + auto conv7b = net->addConv("conv7b", conv7->getDst(), temp1); + + // upsample2 + auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst); + auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst); + + // conv8 + auto conv8 = net->addConv("conv8", concat2Dst, temp0); + + // conv8b + auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1); + + // upsample1 + auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst); + auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst); + + // conv9 + auto conv9 = net->addConv("conv9", concat1Dst, temp0); + + // conv9b + auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1); + + // upsample0 + auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst); + auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst); + + // conv10 + auto conv10 = net->addConv("conv10", concat0Dst, temp0); + + // conv10b + auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1); + + // conv11 + auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */); + + // Output reorder + outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output); + + net->finalize(); + return net; + } + + std::shared_ptr AutoencoderFilter::makeTransferFunc() + { + if (hdr) + return std::make_shared(); + else if (srgb) + return std::make_shared(); + else + return std::make_shared(); + } + +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#if 0 + // -------------------------------------------------------------------------- + // RTFilter + // -------------------------------------------------------------------------- + + namespace weights + { + // LDR + extern unsigned char rt_ldr[]; // color + extern unsigned char rt_ldr_alb[]; // color, albedo + extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal + + // HDR + extern unsigned char rt_hdr[]; // color + extern unsigned char rt_hdr_alb[]; // color, albedo + extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal + } + + RTFilter::RTFilter(const Ref& device) + : AutoencoderFilter(device) + { + weightData.ldr = weights::rt_ldr; + weightData.ldr_alb = weights::rt_ldr_alb; + weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm; + weightData.hdr = weights::rt_hdr; + weightData.hdr_alb = weights::rt_hdr_alb; + weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm; + } +#endif + + // -------------------------------------------------------------------------- + // RTLightmapFilter + // -------------------------------------------------------------------------- + + namespace weights + { + // HDR + extern unsigned char rtlightmap_hdr[]; // color + } + + RTLightmapFilter::RTLightmapFilter(const Ref& device) + : AutoencoderFilter(device) + { + weightData.hdr = weights::rtlightmap_hdr; + + hdr = true; + } + + std::shared_ptr RTLightmapFilter::makeTransferFunc() + { + return std::make_shared(); + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/autoencoder.h b/thirdparty/oidn/core/autoencoder.h new file mode 100644 index 0000000000..97432f2bbd --- /dev/null +++ b/thirdparty/oidn/core/autoencoder.h @@ -0,0 +1,116 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "filter.h" +#include "network.h" +#include "transfer_function.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // AutoencoderFilter - Direct-predicting autoencoder + // -------------------------------------------------------------------------- + + class AutoencoderFilter : public Filter + { + protected: + static constexpr int alignment = 32; // required spatial alignment in pixels (padding may be necessary) + static constexpr int receptiveField = 222; // receptive field in pixels + static constexpr int overlap = roundUp(receptiveField / 2, alignment); // required spatial overlap between tiles in pixels + + static constexpr int estimatedBytesBase = 16*1024*1024; // estimated base memory usage + static constexpr int estimatedBytesPerPixel8 = 889; // estimated memory usage per pixel for K=8 + static constexpr int estimatedBytesPerPixel16 = 2185; // estimated memory usage per pixel for K=16 + + Image color; + Image albedo; + Image normal; + Image output; + bool hdr = false; + float hdrScale = std::numeric_limits::quiet_NaN(); + bool srgb = false; + int maxMemoryMB = 6000; // approximate maximum memory usage in MBs + + int H = 0; // image height + int W = 0; // image width + int tileH = 0; // tile height + int tileW = 0; // tile width + int tileCountH = 1; // number of tiles in H dimension + int tileCountW = 1; // number of tiles in W dimension + + std::shared_ptr net; + std::shared_ptr inputReorder; + std::shared_ptr outputReorder; + + struct + { + void* ldr = nullptr; + void* ldr_alb = nullptr; + void* ldr_alb_nrm = nullptr; + void* hdr = nullptr; + void* hdr_alb = nullptr; + void* hdr_alb_nrm = nullptr; + } weightData; + + explicit AutoencoderFilter(const Ref& device); + virtual std::shared_ptr makeTransferFunc(); + + public: + void setImage(const std::string& name, const Image& data) override; + void set1i(const std::string& name, int value) override; + int get1i(const std::string& name) override; + void set1f(const std::string& name, float value) override; + float get1f(const std::string& name) override; + + void commit() override; + void execute() override; + + private: + void computeTileSize(); + + template + std::shared_ptr buildNet(); + + bool isCommitted() const { return bool(net); } + }; + + // -------------------------------------------------------------------------- + // RTFilter - Generic ray tracing denoiser + // -------------------------------------------------------------------------- + +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#if 0 + class RTFilter : public AutoencoderFilter + { + public: + explicit RTFilter(const Ref& device); + }; +#endif + + // -------------------------------------------------------------------------- + // RTLightmapFilter - Ray traced lightmap denoiser + // -------------------------------------------------------------------------- + + class RTLightmapFilter : public AutoencoderFilter + { + public: + explicit RTLightmapFilter(const Ref& device); + std::shared_ptr makeTransferFunc() override; + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/buffer.h b/thirdparty/oidn/core/buffer.h new file mode 100644 index 0000000000..b95109152e --- /dev/null +++ b/thirdparty/oidn/core/buffer.h @@ -0,0 +1,75 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "device.h" + +namespace oidn { + + class Device; + + // Buffer which may or may not own its data + class Buffer : public RefCount + { + private: + char* ptr; + size_t byteSize; + bool shared; + Ref device; + + public: + __forceinline Buffer(const Ref& device, size_t size) + : ptr((char*)alignedMalloc(size, 64)), + byteSize(size), + shared(false), + device(device) {} + + __forceinline Buffer(const Ref& device, void* data, size_t size) + : ptr((char*)data), + byteSize(size), + shared(true), + device(device) + { + if (data == nullptr) + throw Exception(Error::InvalidArgument, "buffer pointer null"); + } + + __forceinline ~Buffer() + { + if (!shared) + alignedFree(ptr); + } + + __forceinline char* data() { return ptr; } + __forceinline const char* data() const { return ptr; } + __forceinline size_t size() const { return byteSize; } + + void* map(size_t offset, size_t size) + { + if (offset + size > byteSize) + throw Exception(Error::InvalidArgument, "buffer region out of range"); + + return ptr + offset; + } + + void unmap(void* mappedPtr) {} + + Device* getDevice() { return device.get(); } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/common.h b/thirdparty/oidn/core/common.h new file mode 100644 index 0000000000..6c87f377bc --- /dev/null +++ b/thirdparty/oidn/core/common.h @@ -0,0 +1,133 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common/platform.h" + +#include "mkl-dnn/include/mkldnn.hpp" +#include "mkl-dnn/include/mkldnn_debug.h" +#include "mkl-dnn/src/common/mkldnn_thread.hpp" +#include "mkl-dnn/src/common/type_helpers.hpp" +#include "mkl-dnn/src/cpu/jit_generator.hpp" + +#include "common/ref.h" +#include "common/exception.h" +#include "common/thread.h" +#include "math.h" + +namespace oidn { + + using namespace mkldnn; + using namespace mkldnn::impl::cpu; + using mkldnn::impl::parallel_nd; + using mkldnn::impl::memory_desc_matches_tag; + + + inline size_t getFormatBytes(Format format) + { + switch (format) + { + case Format::Undefined: return 1; + case Format::Float: return sizeof(float); + case Format::Float2: return sizeof(float)*2; + case Format::Float3: return sizeof(float)*3; + case Format::Float4: return sizeof(float)*4; + } + assert(0); + return 0; + } + + + inline memory::dims getTensorDims(const std::shared_ptr& mem) + { + const mkldnn_memory_desc_t& desc = mem->get_desc().data; + return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]); + } + + inline memory::data_type getTensorType(const std::shared_ptr& mem) + { + const mkldnn_memory_desc_t& desc = mem->get_desc().data; + return memory::data_type(desc.data_type); + } + + // Returns the number of values in a tensor + inline size_t getTensorSize(const memory::dims& dims) + { + size_t res = 1; + for (int i = 0; i < (int)dims.size(); ++i) + res *= dims[i]; + return res; + } + + inline memory::dims getMaxTensorDims(const std::vector& dims) + { + memory::dims result; + size_t maxSize = 0; + + for (const auto& d : dims) + { + const size_t size = getTensorSize(d); + if (size > maxSize) + { + result = d; + maxSize = size; + } + } + + return result; + } + + inline size_t getTensorSize(const std::shared_ptr& mem) + { + return getTensorSize(getTensorDims(mem)); + } + + + template + inline int getPadded(int dim) + { + return (dim + (K-1)) & ~(K-1); + } + + template + inline memory::dims getPadded_nchw(const memory::dims& dims) + { + assert(dims.size() == 4); + memory::dims padDims = dims; + padDims[1] = getPadded(dims[1]); // pad C + return padDims; + } + + + template + struct BlockedFormat; + + template<> + struct BlockedFormat<8> + { + static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c; + static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o; + }; + + template<> + struct BlockedFormat<16> + { + static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c; + static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o; + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/device.cpp b/thirdparty/oidn/core/device.cpp new file mode 100644 index 0000000000..0812624bb5 --- /dev/null +++ b/thirdparty/oidn/core/device.cpp @@ -0,0 +1,205 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "device.h" +#include "autoencoder.h" + +namespace oidn { + + thread_local Device::ErrorState Device::globalError; + + Device::Device() + { + if (!mayiuse(sse41)) + throw Exception(Error::UnsupportedHardware, "SSE4.1 support is required at minimum"); + } + + Device::~Device() + { + } + + void Device::setError(Device* device, Error code, const std::string& message) + { + // Update the stored error only if the previous error was queried + if (device) + { + ErrorState& curError = device->error.get(); + + if (curError.code == Error::None) + { + curError.code = code; + curError.message = message; + } + + // Print the error message in verbose mode + if (device->isVerbose()) + std::cerr << "Error: " << message << std::endl; + + // Call the error callback function + ErrorFunction errorFunc; + void* errorUserPtr; + + { + std::lock_guard lock(device->mutex); + errorFunc = device->errorFunc; + errorUserPtr = device->errorUserPtr; + } + + if (errorFunc) + errorFunc(errorUserPtr, code, (code == Error::None) ? nullptr : message.c_str()); + } + else + { + if (globalError.code == Error::None) + { + globalError.code = code; + globalError.message = message; + } + } + } + + Error Device::getError(Device* device, const char** outMessage) + { + // Return and clear the stored error code, but keep the error message so pointers to it will + // remain valid until the next getError call + if (device) + { + ErrorState& curError = device->error.get(); + const Error code = curError.code; + if (outMessage) + *outMessage = (code == Error::None) ? nullptr : curError.message.c_str(); + curError.code = Error::None; + return code; + } + else + { + const Error code = globalError.code; + if (outMessage) + *outMessage = (code == Error::None) ? nullptr : globalError.message.c_str(); + globalError.code = Error::None; + return code; + } + } + + void Device::setErrorFunction(ErrorFunction func, void* userPtr) + { + errorFunc = func; + errorUserPtr = userPtr; + } + + int Device::get1i(const std::string& name) + { + if (name == "numThreads") + return numThreads; + else if (name == "setAffinity") + return setAffinity; + else if (name == "verbose") + return verbose; + else if (name == "version") + return OIDN_VERSION; + else if (name == "versionMajor") + return OIDN_VERSION_MAJOR; + else if (name == "versionMinor") + return OIDN_VERSION_MINOR; + else if (name == "versionPatch") + return OIDN_VERSION_PATCH; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void Device::set1i(const std::string& name, int value) + { + if (name == "numThreads") + numThreads = value; + else if (name == "setAffinity") + setAffinity = value; + else if (name == "verbose") + { + verbose = value; + error.verbose = value; + } + + dirty = true; + } + + void Device::commit() + { + if (isCommitted()) + throw Exception(Error::InvalidOperation, "device can be committed only once"); + + // Create the task arena + const int maxNumThreads = 1; //affinity ? affinity->getNumThreads() : tbb::this_task_arena::max_concurrency(); + numThreads = (numThreads > 0) ? min(numThreads, maxNumThreads) : maxNumThreads; + + dirty = false; + + if (isVerbose()) + print(); + } + + void Device::checkCommitted() + { + if (dirty) + throw Exception(Error::InvalidOperation, "changes to the device are not committed"); + } + + Ref Device::newBuffer(size_t byteSize) + { + checkCommitted(); + return makeRef(Ref(this), byteSize); + } + + Ref Device::newBuffer(void* ptr, size_t byteSize) + { + checkCommitted(); + return makeRef(Ref(this), ptr, byteSize); + } + + Ref Device::newFilter(const std::string& type) + { + checkCommitted(); + + if (isVerbose()) + std::cout << "Filter: " << type << std::endl; + + Ref filter; + +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#if 0 + if (type == "RT") + filter = makeRef(Ref(this)); +#endif + if (type == "RTLightmap") + filter = makeRef(Ref(this)); + else + throw Exception(Error::InvalidArgument, "unknown filter type"); + + return filter; + } + + void Device::print() + { + std::cout << std::endl; + + std::cout << "Intel(R) Open Image Denoise " << OIDN_VERSION_STRING << std::endl; + std::cout << " Compiler: " << getCompilerName() << std::endl; + std::cout << " Build : " << getBuildName() << std::endl; + std::cout << " Platform: " << getPlatformName() << std::endl; + + std::cout << std::endl; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/device.h b/thirdparty/oidn/core/device.h new file mode 100644 index 0000000000..93a83eb731 --- /dev/null +++ b/thirdparty/oidn/core/device.h @@ -0,0 +1,78 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" + +namespace oidn { + + class Buffer; + class Filter; + + class Device : public RefCount, public Verbose + { + private: + // Thread-safety + std::mutex mutex; + + // Error handling + struct ErrorState + { + Error code = Error::None; + std::string message; + }; + + static thread_local ErrorState globalError; + ThreadLocal error; + ErrorFunction errorFunc = nullptr; + void* errorUserPtr = nullptr; + + // Parameters + int numThreads = 0; // autodetect by default + bool setAffinity = true; + + bool dirty = true; + + public: + Device(); + ~Device(); + + static void setError(Device* device, Error code, const std::string& message); + static Error getError(Device* device, const char** outMessage); + + void setErrorFunction(ErrorFunction func, void* userPtr); + + int get1i(const std::string& name); + void set1i(const std::string& name, int value); + + void commit(); + + Ref newBuffer(size_t byteSize); + Ref newBuffer(void* ptr, size_t byteSize); + Ref newFilter(const std::string& type); + + __forceinline Device* getDevice() { return this; } + __forceinline std::mutex& getMutex() { return mutex; } + + private: + bool isCommitted() const { return false; } + void checkCommitted(); + + void print(); + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/filter.cpp b/thirdparty/oidn/core/filter.cpp new file mode 100644 index 0000000000..ec1f10af87 --- /dev/null +++ b/thirdparty/oidn/core/filter.cpp @@ -0,0 +1,27 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "filter.h" + +namespace oidn { + + void Filter::setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr) + { + progressFunc = func; + progressUserPtr = userPtr; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/filter.h b/thirdparty/oidn/core/filter.h new file mode 100644 index 0000000000..935fa202f4 --- /dev/null +++ b/thirdparty/oidn/core/filter.h @@ -0,0 +1,52 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "device.h" +#include "image.h" + +namespace oidn { + + class Filter : public RefCount + { + protected: + Ref device; + + ProgressMonitorFunction progressFunc = nullptr; + void* progressUserPtr = nullptr; + + bool dirty = true; + + public: + explicit Filter(const Ref& device) : device(device) {} + + virtual void setImage(const std::string& name, const Image& data) = 0; + virtual void set1i(const std::string& name, int value) = 0; + virtual int get1i(const std::string& name) = 0; + virtual void set1f(const std::string& name, float value) = 0; + virtual float get1f(const std::string& name) = 0; + + void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr); + + virtual void commit() = 0; + virtual void execute() = 0; + + Device* getDevice() { return device.get(); } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/image.h b/thirdparty/oidn/core/image.h new file mode 100644 index 0000000000..748f49c4e5 --- /dev/null +++ b/thirdparty/oidn/core/image.h @@ -0,0 +1,111 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "buffer.h" + +namespace oidn { + + struct Image + { + static constexpr int maxSize = 65536; + + char* ptr; // pointer to the first pixel + int width; // width in number of pixels + int height; // height in number of pixels + size_t bytePixelStride; // pixel stride in number of *bytes* + size_t rowStride; // row stride in number of *pixel strides* + Format format; // pixel format + Ref buffer; // buffer containing the image data + + Image() : ptr(nullptr), width(0), height(0), bytePixelStride(0), rowStride(0), format(Format::Undefined) {} + + Image(void* ptr, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride) + { + if (ptr == nullptr) + throw Exception(Error::InvalidArgument, "buffer pointer null"); + + init((char*)ptr + byteOffset, format, width, height, inBytePixelStride, inByteRowStride); + } + + Image(const Ref& buffer, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride) + { + init(buffer->data() + byteOffset, format, width, height, inBytePixelStride, inByteRowStride); + + if (byteOffset + height * rowStride * bytePixelStride > buffer->size()) + throw Exception(Error::InvalidArgument, "buffer region out of range"); + } + + void init(char* ptr, Format format, int width, int height, size_t inBytePixelStride, size_t inByteRowStride) + { + assert(width >= 0); + assert(height >= 0); + if (width > maxSize || height > maxSize) + throw Exception(Error::InvalidArgument, "image size too large"); + + this->ptr = ptr; + this->width = width; + this->height = height; + + const size_t pixelSize = getFormatBytes(format); + if (inBytePixelStride != 0) + { + if (inBytePixelStride < pixelSize) + throw Exception(Error::InvalidArgument, "pixel stride smaller than pixel size"); + + this->bytePixelStride = inBytePixelStride; + } + else + { + this->bytePixelStride = pixelSize; + } + + if (inByteRowStride != 0) + { + if (inByteRowStride < width * this->bytePixelStride) + throw Exception(Error::InvalidArgument, "row stride smaller than width * pixel stride"); + if (inByteRowStride % this->bytePixelStride != 0) + throw Exception(Error::InvalidArgument, "row stride not integer multiple of pixel stride"); + + this->rowStride = inByteRowStride / this->bytePixelStride; + } + else + { + this->rowStride = width; + } + + this->format = format; + } + + __forceinline char* get(int y, int x) + { + return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride); + } + + __forceinline const char* get(int y, int x) const + { + return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride); + } + + operator bool() const + { + return ptr != nullptr; + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/input_reorder.h b/thirdparty/oidn/core/input_reorder.h new file mode 100644 index 0000000000..966856afe9 --- /dev/null +++ b/thirdparty/oidn/core/input_reorder.h @@ -0,0 +1,232 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" +#include "image.h" + +namespace oidn { + + // Input reorder node + template + class InputReorderNode : public Node + { + private: + // Source + Image color; + Image albedo; + Image normal; + + // Destination + std::shared_ptr dst; + float* dstPtr; + int C2; + int H2; + int W2; + + // Tile + int h1Begin; + int w1Begin; + int h2Begin; + int w2Begin; + int H; + int W; + + std::shared_ptr transferFunc; + + public: + InputReorderNode(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr& dst, + const std::shared_ptr& transferFunc) + : color(color), albedo(albedo), normal(normal), + dst(dst), + h1Begin(0), w1Begin(0), + H(color.height), W(color.width), + transferFunc(transferFunc) + { + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(dstDesc.ndims == 4); + assert(dstDesc.data_type == memory::data_type::f32); + assert(dstDesc.dims[0] == 1); + //assert(dstDesc.dims[1] >= getPadded(C1)); + + dstPtr = (float*)dst->get_data_handle(); + C2 = dstDesc.dims[1]; + H2 = dstDesc.dims[2]; + W2 = dstDesc.dims[3]; + } + + void setTile(int h1, int w1, int h2, int w2, int H, int W) override + { + h1Begin = h1; + w1Begin = w1; + h2Begin = h2; + w2Begin = w2; + this->H = H; + this->W = W; + } + + void execute(stream& sm) override + { + assert(H + h1Begin <= color.height); + assert(W + w1Begin <= color.width); + assert(H + h2Begin <= H2); + assert(W + w2Begin <= W2); + + parallel_nd(H2, [&](int h2) + { + const int h = h2 - h2Begin; + + if (h >= 0 && h < H) + { + const int h1 = h + h1Begin; + + // Zero pad + for (int w2 = 0; w2 < w2Begin; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + + // Reorder + for (int w = 0; w < W; ++w) + { + const int w1 = w + w1Begin; + const int w2 = w + w2Begin; + + int c = 0; + storeColor(h2, w2, c, (float*)color.get(h1, w1)); + if (albedo) + storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1)); + if (normal) + storeNormal(h2, w2, c, (float*)normal.get(h1, w1)); + while (c < C2) + store(h2, w2, c, 0.f); + } + + // Zero pad + for (int w2 = W + w2Begin; w2 < W2; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + } + else + { + // Zero pad + for (int w2 = 0; w2 < W2; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + } + }); + } + + std::shared_ptr getDst() const override { return dst; } + + private: + // Stores a single value + __forceinline void store(int h, int w, int& c, float value) + { + // Destination is in nChwKc format + float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K); + *dst_c = value; + c++; + } + + // Stores a color + __forceinline void storeColor(int h, int w, int& c, const float* values) + { + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = values[i]; + + // Sanitize the value + x = maxSafe(x, 0.f); + + // Apply the transfer function + x = transferFunc->forward(x); + + // Store the value + store(h, w, c, x); + } + } + + // Stores an albedo + __forceinline void storeAlbedo(int h, int w, int& c, const float* values) + { + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = values[i]; + + // Sanitize the value + x = clampSafe(x, 0.f, 1.f); + + // Store the value + store(h, w, c, x); + } + } + + // Stores a normal + __forceinline void storeNormal(int h, int w, int& c, const float* values) + { + // Load the normal + float x = values[0]; + float y = values[1]; + float z = values[2]; + + // Compute the length of the normal + const float lengthSqr = sqr(x) + sqr(y) + sqr(z); + + // Normalize the normal and transform it to [0..1] + if (isfinite(lengthSqr)) + { + const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f; + + const float scale = invLength * 0.5f; + const float offset = 0.5f; + + x = x * scale + offset; + y = y * scale + offset; + z = z * scale + offset; + } + else + { + x = 0.f; + y = 0.f; + z = 0.f; + } + + // Store the normal + store(h, w, c, x); + store(h, w, c, y); + store(h, w, c, z); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/math.h b/thirdparty/oidn/core/math.h new file mode 100644 index 0000000000..a844ef0d1d --- /dev/null +++ b/thirdparty/oidn/core/math.h @@ -0,0 +1,78 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common/platform.h" + +namespace oidn { + + constexpr float minVectorLength = 1e-10f; + constexpr float minVectorLengthSqr = minVectorLength * minVectorLength; + + using std::log; + using std::log2; + using std::exp; + using std::exp2; + using std::pow; + using std::isfinite; + using std::isnan; + + __forceinline float sqr(float x) + { + return x * x; + } + + __forceinline float rcp(float x) + { + __m128 r = _mm_rcp_ss(_mm_set_ss(x)); + return _mm_cvtss_f32(_mm_sub_ss(_mm_add_ss(r, r), _mm_mul_ss(_mm_mul_ss(r, r), _mm_set_ss(x)))); + } + + __forceinline float rsqrt(float x) + { + __m128 r = _mm_rsqrt_ss(_mm_set_ss(x)); + return _mm_cvtss_f32(_mm_add_ss(_mm_mul_ss(_mm_set_ss(1.5f), r), + _mm_mul_ss(_mm_mul_ss(_mm_mul_ss(_mm_set_ss(x), _mm_set_ss(-0.5f)), r), _mm_mul_ss(r, r)))); + } + + __forceinline float maxSafe(float value, float minValue) + { + return isfinite(value) ? max(value, minValue) : minValue; + } + + __forceinline float clampSafe(float value, float minValue, float maxValue) + { + return isfinite(value) ? clamp(value, minValue, maxValue) : minValue; + } + + // Returns ceil(a / b) for non-negative integers + template + __forceinline constexpr Int ceilDiv(Int a, Int b) + { + //assert(a >= 0); + //assert(b > 0); + return (a + b - 1) / b; + } + + // Returns a rounded up to multiple of b + template + __forceinline constexpr Int roundUp(Int a, Int b) + { + return ceilDiv(a, b) * b; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/network.cpp b/thirdparty/oidn/core/network.cpp new file mode 100644 index 0000000000..4da32073cd --- /dev/null +++ b/thirdparty/oidn/core/network.cpp @@ -0,0 +1,434 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "network.h" +#include "upsample.h" +#include "weights_reorder.h" +#include + +namespace oidn { + + template + Network::Network(const Ref& device, const std::map& weightMap) + : device(device), + eng(engine::cpu, 0), + sm(eng), + weightMap(weightMap) + { + } + + template + void Network::execute(const Progress& progress, int taskIndex) + { + if (progress.func) + { + const double value = double(taskIndex) / double(progress.taskCount); + if (!progress.func(progress.userPtr, value)) + throw Exception(Error::Cancelled, "execution was cancelled"); + } + + for (size_t i = 0; i < nodes.size(); ++i) + { + nodes[i]->execute(sm); + + if (progress.func) + { + const double value = (double(taskIndex) + double(i+1) / double(nodes.size())) / double(progress.taskCount); + if (!progress.func(progress.userPtr, value)) + throw Exception(Error::Cancelled, "execution was cancelled"); + } + } + } + + template + std::shared_ptr Network::allocTensor(const memory::dims& dims, + memory::format_tag format, + void* data) + { + if (format == memory::format_tag::any) + { + if (dims.size() == 4) + format = BlockedFormat::nChwKc; + else if (dims.size() == 1) + format = memory::format_tag::x; + else + assert(0); + } + memory::desc desc(dims, memory::data_type::f32, format); + if (data == nullptr) + { + const size_t bytes = getTensorSize(dims) * sizeof(float); + if (format == BlockedFormat::nChwKc) + activationAllocBytes += bytes; + totalAllocBytes += bytes; + + return std::make_shared(desc, eng); + } + else + { + return std::make_shared(desc, eng, data); + } + } + + template + std::shared_ptr Network::castTensor(const memory::dims& dims, + const std::shared_ptr& src, + size_t srcOffset, + memory::format_tag format) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + MAYBE_UNUSED(srcDesc); + assert(srcDesc.data_type == memory::data_type::f32); + assert(getTensorSize(src) >= srcOffset + getTensorSize(dims)); + + if (format == memory::format_tag::any) + { + if (dims.size() == 4) + format = BlockedFormat::nChwKc; + else if (dims.size() == 1) + format = memory::format_tag::x; + else + assert(0); + } + memory::desc desc(dims, memory::data_type::f32, format); + float* srcPtr = (float*)src->get_data_handle() + srcOffset; + return std::make_shared(desc, eng, srcPtr); + } + + template + std::shared_ptr Network::castTensor(const memory::dims& dims, + const std::shared_ptr& src, + const memory::dims& srcOffset) + { + return castTensor(dims, src, getTensorSize(srcOffset)); + } + + template + void Network::zeroTensor(const std::shared_ptr& dst) + { + assert(getTensorType(dst) == memory::data_type::f32); + memset(dst->get_data_handle(), 0, getTensorSize(dst)*sizeof(float)); + } + + template + memory::dims Network::getInputReorderDims(const memory::dims& srcDims, int alignment) + { + memory::dims dstDims = srcDims; + dstDims[1] = getPadded(srcDims[1]); // round up C + dstDims[2] = roundUp(srcDims[2], memory::dim(alignment)); // round up H + dstDims[3] = roundUp(srcDims[3], memory::dim(alignment)); // round up W + return dstDims; + } + + template + std::shared_ptr Network::addInputReorder(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr& transferFunc, + int alignment, + const std::shared_ptr& userDst) + { + assert(color); + int inputC = 3; + if (albedo) inputC += 3; + if (normal) inputC += 3; + + memory::dims srcDims = {1, inputC, color.height, color.width}; + memory::dims dstDims = getInputReorderDims(srcDims, alignment); + + // Allocate padded memory + auto dst = userDst; + if (!dst) + dst = allocTensor(dstDims); + + // Push node + std::shared_ptr node; + + if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else + assert(0); + + nodes.push_back(node); + return node; + } + + template + std::shared_ptr Network::addOutputReorder(const std::shared_ptr& src, + const std::shared_ptr& transferFunc, + const Image& output) + { + memory::dims srcDims = getTensorDims(src); + assert(srcDims[1] == K); + + // Push node + std::shared_ptr node; + + if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else + assert(0); + + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getConvDims(const std::string& name, const memory::dims& srcDims) + { + auto b = weightMap[name + "/b"]; + memory::dims dstDims = srcDims; + dstDims[1] = getPadded(b.dims[0]); // dstDims[C] = getPadded(OC) + return dstDims; + } + + template + std::shared_ptr Network::addConv(const std::string& name, + const std::shared_ptr& src, + const std::shared_ptr& userDst, + bool relu) + { + const memory::dims strides = {1, 1}; + const memory::dims padding = {1, 1}; + + memory::dims srcDims = getTensorDims(src); + + // Get the weights + const auto& W = weightMap[name + "/W"]; + if (W.ndims() != 4 || W.format != "oihw") + throw Exception(Error::InvalidOperation, "invalid convolution weights"); + memory::dims weightsDims = W.dims; + auto userWeights = allocTensor(weightsDims, memory::format_tag::oihw, W.data); + + // Pad the weights + memory::dims weightsPadDims = weightsDims; + weightsPadDims[1] = getPadded(weightsDims[1]); // IC + weightsPadDims[0] = getPadded(weightsDims[0]); // OC + assert(srcDims[1] == weightsPadDims[1]); // srcDims[C] == weightsPadDims[IC] + auto weightsPad = allocTensor(weightsPadDims, memory::format_tag::oihw); + WeightsReorderNode(userWeights, weightsPad).execute(sm); + + // Get the biases + const auto& b = weightMap[name + "/b"]; + if (b.ndims() != 1) + throw Exception(Error::InvalidOperation, "invalid convolution biases"); + memory::dims biasDims = b.dims; + + // Copy/pad the biases + memory::dims biasPadDims = {getPadded(biasDims[0])}; + auto bias = allocTensor(biasPadDims); + if (biasDims[0] != biasPadDims[0]) + memset(bias->get_data_handle(), 0, biasPadDims[0]*sizeof(float)); + memcpy(bias->get_data_handle(), b.data, biasDims[0]*sizeof(float)); + + // Allocate memory for destination + memory::dims dstDims = srcDims; + dstDims[1] = weightsPadDims[0]; // dstDims[C] = weightsPadDims[OC] + + std::shared_ptr dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + // Create a convolution + // Let the convolution primitive choose the weights format + auto weightsDesc = memory::desc({ weightsPadDims }, memory::data_type::f32, memory::format_tag::any); + + auto convAlgo = (K == 16) ? convolution_winograd : convolution_direct; + auto convDesc = convolution_forward::desc( + prop_kind::forward_inference, convAlgo, + src->get_desc(), + weightsDesc, + bias->get_desc(), + dst->get_desc(), + strides, padding, padding, padding_kind::zero); + + // Incorporate relu + mkldnn::primitive_attr convAttr; + if (relu) + { + mkldnn::post_ops ops; + ops.append_eltwise( + 1.f, // scale factor, not used + algorithm::eltwise_relu, + 0.f, // max with + 0.f // unused + ); + convAttr.set_post_ops(ops); + } + convAttr.set_scratchpad_mode(scratchpad_mode_user); + + auto convPrimDesc = convolution_forward::primitive_desc(convDesc, convAttr, eng); + + // Reorder the weights to the final format, if necessary + auto weights = weightsPad; + if (convPrimDesc.weights_desc() != weightsPad->get_desc()) + { + weights = std::make_shared(convPrimDesc.weights_desc(), eng); + ReorderNode(weightsPad, weights).execute(sm); + } + + // Create convolution node and add it to the net + auto node = std::make_shared(convPrimDesc, src, weights, bias, dst); + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getPoolDims(const memory::dims& srcDims) + { + memory::dims dstDims = srcDims; + dstDims[2] /= 2; // H/2 + dstDims[3] /= 2; // W/2 + return dstDims; + } + + template + std::shared_ptr Network::addPool(const std::shared_ptr& src, + const std::shared_ptr& userDst) + { + const memory::dims kernel = {2, 2}; + const memory::dims strides = {2, 2}; + const memory::dims padding = {0, 0}; + + memory::dims srcDims = getTensorDims(src); + memory::dims dstDims = getPoolDims(srcDims); + + std::shared_ptr dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + auto poolDesc = pooling_forward::desc( + prop_kind::forward_inference, pooling_max, + src->get_desc(), + dst->get_desc(), + strides, kernel, padding, padding, padding_kind::zero); + + mkldnn::primitive_attr poolAttr; + poolAttr.set_scratchpad_mode(scratchpad_mode_user); + + auto poolPrimDesc = pooling_forward::primitive_desc(poolDesc, poolAttr, eng); + + auto node = std::make_shared(poolPrimDesc, src, dst); + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getUpsampleDims(const memory::dims& srcDims) + { + memory::dims dstDims = srcDims; + dstDims[2] *= 2; // H*2 + dstDims[3] *= 2; // W*2 + return dstDims; + } + + template + std::shared_ptr Network::addUpsample(const std::shared_ptr& src, + const std::shared_ptr& userDst) + { + memory::dims srcDims = getTensorDims(src); + memory::dims dstDims = getUpsampleDims(srcDims); + + std::shared_ptr dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + // Create upsampling node and add it to net + auto node = std::make_shared>(src, dst); + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims) + { + assert(src1Dims[0] == src2Dims[0]); // N + assert(src1Dims[2] == src2Dims[2]); // H + assert(src1Dims[3] == src2Dims[3]); // W + + memory::dims dstDims = src1Dims; + dstDims[1] += src2Dims[1]; // C + return dstDims; + } + + template + std::shared_ptr Network::addAutoexposure(const Image& color, + const std::shared_ptr& transferFunc) + { + auto node = std::make_shared(color, transferFunc); + nodes.push_back(node); + return node; + } + + template + void Network::finalize() + { + // Compute the size of the scratchpad + size_t scratchpadSize = 0; + for (const auto& node : nodes) + scratchpadSize = max(scratchpadSize, node->getScratchpadSize()); + + // Allocate the scratchpad + memory::dims scratchpadDims = { memory::dim(scratchpadSize) }; + memory::desc scratchpadDesc(scratchpadDims, memory::data_type::u8, memory::format_tag::x); + auto scratchpad = std::make_shared(scratchpadDesc, eng); + activationAllocBytes += scratchpadSize; + totalAllocBytes += scratchpadSize; + + // Set the scratchpad for the nodes + for (auto& node : nodes) + node->setScratchpad(scratchpad); + + // Free the weights + weightMap.clear(); + + // Print statistics + if (device->isVerbose(2)) + { + std::cout << "Activation bytes: " << activationAllocBytes << std::endl; + std::cout << "Scratchpad bytes: " << scratchpadSize << std::endl; + std::cout << "Total bytes : " << totalAllocBytes << std::endl; + } + } + + template class Network<8>; + template class Network<16>; + +} // namespace oidn diff --git a/thirdparty/oidn/core/network.h b/thirdparty/oidn/core/network.h new file mode 100644 index 0000000000..7a696fd355 --- /dev/null +++ b/thirdparty/oidn/core/network.h @@ -0,0 +1,112 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "common/tensor.h" +#include "image.h" +#include "node.h" +#include "input_reorder.h" +#include "output_reorder.h" +#include "transfer_function.h" + +#pragma once + +namespace oidn { + + // Progress state + struct Progress + { + ProgressMonitorFunction func; + void* userPtr; + int taskCount; + }; + + class Executable + { + public: + virtual ~Executable() {} + virtual void execute(const Progress& progress, int taskIndex) = 0; + }; + + template + class Network : public Executable + { + public: + Network(const Ref& device, const std::map& weightMap); + + void execute(const Progress& progress, int taskIndex) override; + + std::shared_ptr allocTensor(const memory::dims& dims, + memory::format_tag format = memory::format_tag::any, + void* data = nullptr); + + std::shared_ptr castTensor(const memory::dims& dims, + const std::shared_ptr& src, + size_t srcOffset = 0, + memory::format_tag format = memory::format_tag::any); + + std::shared_ptr castTensor(const memory::dims& dims, + const std::shared_ptr& src, + const memory::dims& srcOffset); + + void zeroTensor(const std::shared_ptr& dst); + + memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment); + + std::shared_ptr addInputReorder(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr& transferFunc, + int alignment, + const std::shared_ptr& userDst = nullptr); + + std::shared_ptr addOutputReorder(const std::shared_ptr& src, + const std::shared_ptr& transferFunc, + const Image& output); + + memory::dims getConvDims(const std::string& name, const memory::dims& srcDims); + std::shared_ptr addConv(const std::string& name, + const std::shared_ptr& src, + const std::shared_ptr& userDst = nullptr, + bool relu = true); + + memory::dims getPoolDims(const memory::dims& srcDims); + std::shared_ptr addPool(const std::shared_ptr& src, + const std::shared_ptr& userDst = nullptr); + + memory::dims getUpsampleDims(const memory::dims& srcDims); + std::shared_ptr addUpsample(const std::shared_ptr& src, + const std::shared_ptr& userDst = nullptr); + + memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims); + + std::shared_ptr addAutoexposure(const Image& color, + const std::shared_ptr& transferFunc); + + void finalize(); + + private: + Ref device; + engine eng; + stream sm; + std::vector> nodes; + std::map weightMap; + + // Memory allocation statistics + size_t activationAllocBytes = 0; // number of allocated activation bytes + size_t totalAllocBytes = 0; // total number of allocated bytes + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/node.h b/thirdparty/oidn/core/node.h new file mode 100644 index 0000000000..b9ffe906df --- /dev/null +++ b/thirdparty/oidn/core/node.h @@ -0,0 +1,142 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include + +namespace oidn { + + class Node + { + public: + virtual ~Node() = default; + + virtual void execute(stream& sm) = 0; + + virtual std::shared_ptr getDst() const { return nullptr; } + + virtual size_t getScratchpadSize() const { return 0; } + virtual void setScratchpad(const std::shared_ptr& mem) {} + + virtual void setTile(int h1, int w1, int h2, int w2, int H, int W) + { + assert(0); // not supported + } + }; + + // Node wrapping an MKL-DNN primitive + class MklNode : public Node + { + private: + primitive prim; + std::unordered_map args; + std::shared_ptr scratchpad; + + public: + MklNode(const primitive& prim, const std::unordered_map& args) + : prim(prim), + args(args) + {} + + size_t getScratchpadSize() const override + { + const auto primDesc = prim.get_primitive_desc(); + const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0); + if (scratchpadDesc == nullptr) + return 0; + return mkldnn_memory_desc_get_size(scratchpadDesc); + } + + void setScratchpad(const std::shared_ptr& mem) override + { + scratchpad = mem; + args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad)); + } + + void execute(stream& sm) override + { + prim.execute(sm, args); + } + }; + + // Convolution node + class ConvNode : public MklNode + { + private: + std::shared_ptr src; + std::shared_ptr weights; + std::shared_ptr bias; + std::shared_ptr dst; + + public: + ConvNode(const convolution_forward::primitive_desc& desc, + const std::shared_ptr& src, + const std::shared_ptr& weights, + const std::shared_ptr& bias, + const std::shared_ptr& dst) + : MklNode(convolution_forward(desc), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_WEIGHTS, *weights }, + { MKLDNN_ARG_BIAS, *bias }, + { MKLDNN_ARG_DST, *dst } }), + src(src), weights(weights), bias(bias), dst(dst) + {} + + std::shared_ptr getDst() const override { return dst; } + }; + + // Pooling node + class PoolNode : public MklNode + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + PoolNode(const pooling_forward::primitive_desc& desc, + const std::shared_ptr& src, + const std::shared_ptr& dst) + : MklNode(pooling_forward(desc), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_DST, *dst } }), + src(src), dst(dst) + {} + + std::shared_ptr getDst() const override { return dst; } + }; + + // Reorder node + class ReorderNode : public MklNode + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + ReorderNode(const std::shared_ptr& src, + const std::shared_ptr& dst) + : MklNode(reorder(reorder::primitive_desc(*src, *dst)), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_DST, *dst } }), + src(src), dst(dst) + {} + + std::shared_ptr getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/output_reorder.h b/thirdparty/oidn/core/output_reorder.h new file mode 100644 index 0000000000..7918d48e15 --- /dev/null +++ b/thirdparty/oidn/core/output_reorder.h @@ -0,0 +1,126 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" +#include "image.h" + +namespace oidn { + + // Output reorder node + template + class OutputReorderNode : public Node + { + private: + // Source + std::shared_ptr src; + const float* srcPtr; + int H1; + int W1; + + // Destination + Image output; + + // Tile + int h1Begin; + int w1Begin; + int h2Begin; + int w2Begin; + int H; + int W; + + std::shared_ptr transferFunc; + + public: + OutputReorderNode(const std::shared_ptr& src, + const Image& output, + const std::shared_ptr& transferFunc) + : src(src), + output(output), + h1Begin(0), w1Begin(0), + h2Begin(0), w2Begin(0), + H(output.height), W(output.width), + transferFunc(transferFunc) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + MAYBE_UNUSED(srcDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(srcDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(srcDesc.dims[0] == 1); + // We assume output data is <= K OC + assert(srcDesc.dims[1] == K); + + srcPtr = (float*)src->get_data_handle(); + H1 = srcDesc.dims[2]; + W1 = srcDesc.dims[3]; + } + + void setTile(int h1, int w1, int h2, int w2, int H, int W) override + { + h1Begin = h1; + w1Begin = w1; + h2Begin = h2; + w2Begin = w2; + this->H = H; + this->W = W; + } + + void execute(stream& sm) override + { + assert(h1Begin + H <= H1); + assert(w1Begin + W <= W1); + assert(h2Begin + H <= output.height); + assert(w2Begin + W <= output.width); + + const int C1 = K; + + parallel_nd(H, [&](int h) + { + const int h1 = h + h1Begin; + const int h2 = h + h2Begin; + + for (int w = 0; w < W; ++w) + { + const int w1 = w + w1Begin; + const int w2 = w + w2Begin; + float* dstPtr_C = (float*)output.get(h2, w2); + + // Source is in nChwKc format. In this case C is 1 so this is really nhwc + const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1; + + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = srcPtr_C[i]; + + // The CNN output may contain negative values or even NaNs, so it must be sanitized + x = maxSafe(x, 0.f); + + // Apply the inverse transfer function + x = transferFunc->inverse(x); + + // Sanitize and store the final value + dstPtr_C[i] = max(x, 0.f); + } + } + }); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/transfer_function.cpp b/thirdparty/oidn/core/transfer_function.cpp new file mode 100644 index 0000000000..a33e3c84bc --- /dev/null +++ b/thirdparty/oidn/core/transfer_function.cpp @@ -0,0 +1,95 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "transfer_function.h" + +namespace oidn { + + const float LogTransferFunction::xScale = 1.f / log(LogTransferFunction::yMax + 1.f); + const float PQXTransferFunction::xScale = 1.f / PQXTransferFunction::pqxForward(PQXTransferFunction::yMax * PQXTransferFunction::yScale); + + float AutoexposureNode::autoexposure(const Image& color) + { + assert(color.format == Format::Float3); + return 1.0f; + + /*constexpr float key = 0.18f; + constexpr float eps = 1e-8f; + constexpr int K = 16; // downsampling amount + + // Downsample the image to minimize sensitivity to noise + const int H = color.height; // original height + const int W = color.width; // original width + const int HK = (H + K/2) / K; // downsampled height + const int WK = (W + K/2) / K; // downsampled width + + // Compute the average log luminance of the downsampled image + using Sum = std::pair; + + Sum sum = + tbb::parallel_reduce( + tbb::blocked_range2d(0, HK, 0, WK), + Sum(0.f, 0), + [&](const tbb::blocked_range2d& r, Sum sum) -> Sum + { + // Iterate over blocks + for (int i = r.rows().begin(); i != r.rows().end(); ++i) + { + for (int j = r.cols().begin(); j != r.cols().end(); ++j) + { + // Compute the average luminance in the current block + const int beginH = int(ptrdiff_t(i) * H / HK); + const int beginW = int(ptrdiff_t(j) * W / WK); + const int endH = int(ptrdiff_t(i+1) * H / HK); + const int endW = int(ptrdiff_t(j+1) * W / WK); + + float L = 0.f; + + for (int h = beginH; h < endH; ++h) + { + for (int w = beginW; w < endW; ++w) + { + const float* rgb = (const float*)color.get(h, w); + + const float r = maxSafe(rgb[0], 0.f); + const float g = maxSafe(rgb[1], 0.f); + const float b = maxSafe(rgb[2], 0.f); + + L += luminance(r, g, b); + } + } + + L /= (endH - beginH) * (endW - beginW); + + // Accumulate the log luminance + if (L > eps) + { + sum.first += log2(L); + sum.second++; + } + } + } + + return sum; + }, + [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); }, + tbb::static_partitioner() + ); + + return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f;*/ + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/transfer_function.h b/thirdparty/oidn/core/transfer_function.h new file mode 100644 index 0000000000..35f2833092 --- /dev/null +++ b/thirdparty/oidn/core/transfer_function.h @@ -0,0 +1,201 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "image.h" +#include "node.h" + +namespace oidn { + + __forceinline float luminance(float r, float g, float b) + { + return 0.212671f * r + 0.715160f * g + 0.072169f * b; + } + + // Color transfer function base class + class TransferFunction + { + public: + virtual ~TransferFunction() = default; + + virtual float forward(float y) const = 0; + virtual float inverse(float x) const = 0; + }; + + // HDR transfer function base class + class HDRTransferFunction : public TransferFunction + { + protected: + static constexpr float yMax = 65504.f; + + float exposure; + float rcpExposure; + + public: + HDRTransferFunction(float exposure = 1.f) + { + setExposure(exposure); + } + + void setExposure(float exposure) + { + this->exposure = exposure; + this->rcpExposure = (exposure != 0.f) ? (1.f / exposure) : 0.f; + } + }; + + // Linear transfer function (LDR) + class LinearTransferFunction : public TransferFunction + { + public: + __forceinline float forward(float y) const override + { + return min(y, 1.f); + } + + __forceinline float inverse(float x) const override + { + return min(x, 1.f); + } + }; + + // 2.2 gamma transfer function (LDR) + class GammaTransferFunction : public TransferFunction + { + public: + __forceinline float forward(float y) const override + { + return min(pow(y, 1.f/2.2f), 1.f); + } + + __forceinline float inverse(float x) const override + { + return min(pow(x, 2.2f), 1.f); + } + }; + + // Logarithmic transfer function (HDR) + // Compresses [0..65504] to [0..1] + class LogTransferFunction : public HDRTransferFunction + { + private: + static const float xScale; + + public: + LogTransferFunction(float exposure = 1.f) + : HDRTransferFunction(exposure) + { + } + + __forceinline float forward(float y) const override + { + return log(y * exposure + 1.f) * xScale; + } + + __forceinline float inverse(float x) const override + { + return (exp(x * (1.f/xScale)) - 1.f) * rcpExposure; + } + }; + + // PQX transfer function (HDR) + // Compresses [0..65504] to [0..1] + class PQXTransferFunction : public HDRTransferFunction + { + private: + static constexpr float m1 = 2610.f / 4096.f / 4.f; + static constexpr float m2 = 2523.f / 4096.f * 128.f; + static constexpr float c1 = 3424.f / 4096.f; + static constexpr float c2 = 2413.f / 4096.f * 32.f; + static constexpr float c3 = 2392.f / 4096.f * 32.f; + static constexpr float a = 3711.f / 4096.f / 8.f; + + static constexpr float yScale = 100.f / 10000.f; + static const float xScale; + + public: + PQXTransferFunction(float exposure = 1.f) + : HDRTransferFunction(exposure) + { + } + + __forceinline float forward(float y) const override + { + return pqxForward(y * exposure * yScale) * xScale; + } + + __forceinline float inverse(float x) const override + { + return pqxInverse(x * (1.f/xScale)) * (1.f/yScale) * rcpExposure; + } + + private: + static __forceinline float pqForward(float y) + { + const float yp = pow(y, m1); + return pow((c1 + c2 * yp) * rcp(1.f + c3 * yp), m2); + } + + static __forceinline float pqxForward(float y) + { + if (y <= 1.f) + return pqForward(y); + else + return a * log(y) + 1.f; + } + + static __forceinline float pqInverse(float x) + { + const float xp = pow(x, 1.f/m2); + return pow(max((xp - c1) * rcp(c2 - c3 * xp), 0.f), 1.f/m1); + } + + static __forceinline float pqxInverse(float x) + { + if (x <= 1.f) + return pqInverse(x); + else + return exp((x - 1.f) * (1.f/a)); + } + }; + + // Autoexposure node + class AutoexposureNode : public Node + { + private: + Image color; + std::shared_ptr transferFunc; + + public: + AutoexposureNode(const Image& color, + const std::shared_ptr& transferFunc) + : color(color), + transferFunc(transferFunc) + {} + + void execute(stream& sm) override + { + const float exposure = autoexposure(color); + //printf("exposure = %f\n", exposure); + transferFunc->setExposure(exposure); + } + + private: + static float autoexposure(const Image& color); + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/upsample.h b/thirdparty/oidn/core/upsample.h new file mode 100644 index 0000000000..f6cace44cd --- /dev/null +++ b/thirdparty/oidn/core/upsample.h @@ -0,0 +1,92 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" + +namespace oidn { + + // 2x2 nearest-neighbor upsampling node + template + class UpsampleNode : public Node + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + UpsampleNode(const std::shared_ptr& src, + const std::shared_ptr& dst) + : src(src), + dst(dst) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + MAYBE_UNUSED(srcDesc); + MAYBE_UNUSED(dstDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(srcDesc.ndims == 4); + assert(dstDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(dstDesc.data_type == memory::data_type::f32); + assert(srcDesc.dims[0] == 1); + assert(dstDesc.dims[0] == 1); + // 2x2 upsampling + assert(dstDesc.dims[2] == srcDesc.dims[2] * 2); + assert(dstDesc.dims[3] == srcDesc.dims[3] * 2); + } + + void execute(stream& sm) override + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + + const float* srcPtr = (float*)src->get_data_handle(); + float* dstPtr = (float*)dst->get_data_handle(); + + const int C = srcDesc.dims[1]; + const int H = srcDesc.dims[2]; + const int W = srcDesc.dims[3]; + const int CK = C / K; + + parallel_nd(CK, H, [&](int ck, int h) + { + const size_t offset = ck*H*W*K + h*W*K; + const float* srcPtr_line = srcPtr + offset; + float* dstPtr_line0 = dstPtr + offset * 4; + float* dstPtr_line1 = dstPtr_line0 + W*2*K; // next line + + for (int w = 0; w < W; ++w) + { + #pragma unroll + for (int k = 0; k < K; k += 4) + { + const __m128 m = _mm_load_ps(&srcPtr_line[w*K + k]); + + _mm_stream_ps(&dstPtr_line0[w*2*K + k], m); + _mm_stream_ps(&dstPtr_line0[w*2*K+K + k], m); + _mm_stream_ps(&dstPtr_line1[w*2*K + k], m); + _mm_stream_ps(&dstPtr_line1[w*2*K+K + k], m); + } + } + }); + } + + std::shared_ptr getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/weights_reorder.h b/thirdparty/oidn/core/weights_reorder.h new file mode 100644 index 0000000000..6c5dacb8aa --- /dev/null +++ b/thirdparty/oidn/core/weights_reorder.h @@ -0,0 +1,99 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" + +namespace oidn { + + // Reorders weights from oihw to padded oihw format + template + class WeightsReorderNode : public Node + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + WeightsReorderNode(const std::shared_ptr& src, + const std::shared_ptr& dst) + : src(src), + dst(dst) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + MAYBE_UNUSED(srcDesc); + MAYBE_UNUSED(dstDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(memory::format_tag::oihw))); + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(memory::format_tag::oihw))); + assert(srcDesc.ndims == 4); + assert(dstDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(dstDesc.data_type == memory::data_type::f32); + assert(getPadded(srcDesc.dims[0]) == dstDesc.dims[0]); // OC + assert(getPadded(srcDesc.dims[1]) == dstDesc.dims[1]); // IC + assert(srcDesc.dims[2] == dstDesc.dims[2]); + assert(srcDesc.dims[3] == dstDesc.dims[3]); + } + + void execute(stream& sm) override + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + + const float* srcPtr = (float*)src->get_data_handle(); + float* dstPtr = (float*)dst->get_data_handle(); + + const int OC1 = srcDesc.dims[0]; + const int OC2 = dstDesc.dims[0]; + const int IC1 = srcDesc.dims[1]; + const int IC2 = dstDesc.dims[1]; + const int H = dstDesc.dims[2]; + const int W = dstDesc.dims[3]; + + for (int oc = 0; oc < OC2; ++oc) + { + for (int ic = 0; ic < IC2; ++ic) + { + for (int h = 0; h < H; ++h) + { + for (int w = 0; w < W; ++w) + { + // Output is in oihw format + float* dstPtr_c = dstPtr + oc*IC2*H*W + ic*H*W + h*W + w; + + if (oc < OC1 && ic < IC1) + { + // Input is in oihw format + const float* srcPtr_c = srcPtr + oc*IC1*H*W + ic*H*W + h*W + w; + *dstPtr_c = *srcPtr_c; + } + else + { + // padding + *dstPtr_c = 0; + } + } + } + } + } + } + + std::shared_ptr getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.h b/thirdparty/oidn/include/OpenImageDenoise/oidn.h new file mode 100644 index 0000000000..57ba6baa21 --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/oidn.h @@ -0,0 +1,214 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include +#include +#include + +#include "version.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +#ifndef OIDN_API +#if defined(_WIN32) && !defined(OIDN_STATIC_LIB) +# define OIDN_API __declspec(dllimport) +#else +# define OIDN_API +#endif +#endif + +// ---------------------------------------------------------------------------- +// Device +// ---------------------------------------------------------------------------- + +// Device types +typedef enum +{ + OIDN_DEVICE_TYPE_DEFAULT = 0, // select device automatically + + OIDN_DEVICE_TYPE_CPU = 1, // CPU device +} OIDNDeviceType; + +// Error codes +typedef enum +{ + OIDN_ERROR_NONE = 0, // no error occurred + OIDN_ERROR_UNKNOWN = 1, // an unknown error occurred + OIDN_ERROR_INVALID_ARGUMENT = 2, // an invalid argument was specified + OIDN_ERROR_INVALID_OPERATION = 3, // the operation is not allowed + OIDN_ERROR_OUT_OF_MEMORY = 4, // not enough memory to execute the operation + OIDN_ERROR_UNSUPPORTED_HARDWARE = 5, // the hardware (e.g. CPU) is not supported + OIDN_ERROR_CANCELLED = 6, // the operation was cancelled by the user +} OIDNError; + +// Error callback function +typedef void (*OIDNErrorFunction)(void* userPtr, OIDNError code, const char* message); + +// Device handle +typedef struct OIDNDeviceImpl* OIDNDevice; + +// Creates a new device. +OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type); + +// Retains the device (increments the reference count). +OIDN_API void oidnRetainDevice(OIDNDevice device); + +// Releases the device (decrements the reference count). +OIDN_API void oidnReleaseDevice(OIDNDevice device); + +// Sets a boolean parameter of the device. +OIDN_API void oidnSetDevice1b(OIDNDevice device, const char* name, bool value); + +// Sets an integer parameter of the device. +OIDN_API void oidnSetDevice1i(OIDNDevice device, const char* name, int value); + +// Gets a boolean parameter of the device. +OIDN_API bool oidnGetDevice1b(OIDNDevice device, const char* name); + +// Gets an integer parameter of the device (e.g. "version"). +OIDN_API int oidnGetDevice1i(OIDNDevice device, const char* name); + +// Sets the error callback function of the device. +OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice device, OIDNErrorFunction func, void* userPtr); + +// Returns the first unqueried error code stored in the device for the current +// thread, optionally also returning a string message (if not NULL), and clears +// the stored error. Can be called with a NULL device as well to check why a +// device creation failed. +OIDN_API OIDNError oidnGetDeviceError(OIDNDevice device, const char** outMessage); + +// Commits all previous changes to the device. +// Must be called before first using the device (e.g. creating filters). +OIDN_API void oidnCommitDevice(OIDNDevice device); + +// ---------------------------------------------------------------------------- +// Buffer +// ---------------------------------------------------------------------------- + +// Formats for images and other data stored in buffers +typedef enum +{ + OIDN_FORMAT_UNDEFINED = 0, + + // 32-bit single-precision floating point scalar and vector formats + OIDN_FORMAT_FLOAT = 1, + OIDN_FORMAT_FLOAT2 = 2, + OIDN_FORMAT_FLOAT3 = 3, + OIDN_FORMAT_FLOAT4 = 4, +} OIDNFormat; + +// Access modes for mapping buffers +typedef enum +{ + OIDN_ACCESS_READ = 0, // read-only access + OIDN_ACCESS_WRITE = 1, // write-only access + OIDN_ACCESS_READ_WRITE = 2, // read and write access + OIDN_ACCESS_WRITE_DISCARD = 3, // write-only access, previous contents discarded +} OIDNAccess; + +// Buffer handle +typedef struct OIDNBufferImpl* OIDNBuffer; + +// Creates a new buffer (data allocated and owned by the device). +OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice device, size_t byteSize); + +// Creates a new shared buffer (data allocated and owned by the user). +OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice device, void* ptr, size_t byteSize); + +// Maps a region of the buffer to host memory. +// If byteSize is 0, the maximum available amount of memory will be mapped. +OIDN_API void* oidnMapBuffer(OIDNBuffer buffer, OIDNAccess access, size_t byteOffset, size_t byteSize); + +// Unmaps a region of the buffer. +// mappedPtr must be a pointer returned by a previous call to oidnMapBuffer. +OIDN_API void oidnUnmapBuffer(OIDNBuffer buffer, void* mappedPtr); + +// Retains the buffer (increments the reference count). +OIDN_API void oidnRetainBuffer(OIDNBuffer buffer); + +// Releases the buffer (decrements the reference count). +OIDN_API void oidnReleaseBuffer(OIDNBuffer buffer); + +// ---------------------------------------------------------------------------- +// Filter +// ---------------------------------------------------------------------------- + +// Progress monitor callback function +typedef bool (*OIDNProgressMonitorFunction)(void* userPtr, double n); + +// Filter handle +typedef struct OIDNFilterImpl* OIDNFilter; + +// Creates a new filter of the specified type (e.g. "RT"). +OIDN_API OIDNFilter oidnNewFilter(OIDNDevice device, const char* type); + +// Retains the filter (increments the reference count). +OIDN_API void oidnRetainFilter(OIDNFilter filter); + +// Releases the filter (decrements the reference count). +OIDN_API void oidnReleaseFilter(OIDNFilter filter); + +// Sets an image parameter of the filter (stored in a buffer). +// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically. +OIDN_API void oidnSetFilterImage(OIDNFilter filter, const char* name, + OIDNBuffer buffer, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride); + +// Sets an image parameter of the filter (owned by the user). +// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically. +OIDN_API void oidnSetSharedFilterImage(OIDNFilter filter, const char* name, + void* ptr, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride); + +// Sets a boolean parameter of the filter. +OIDN_API void oidnSetFilter1b(OIDNFilter filter, const char* name, bool value); + +// Gets a boolean parameter of the filter. +OIDN_API bool oidnGetFilter1b(OIDNFilter filter, const char* name); + +// Sets an integer parameter of the filter. +OIDN_API void oidnSetFilter1i(OIDNFilter filter, const char* name, int value); + +// Gets an integer parameter of the filter. +OIDN_API int oidnGetFilter1i(OIDNFilter filter, const char* name); + +// Sets a float parameter of the filter. +OIDN_API void oidnSetFilter1f(OIDNFilter filter, const char* name, float value); + +// Gets a float parameter of the filter. +OIDN_API float oidnGetFilter1f(OIDNFilter filter, const char* name); + +// Sets the progress monitor callback function of the filter. +OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter filter, OIDNProgressMonitorFunction func, void* userPtr); + +// Commits all previous changes to the filter. +// Must be called before first executing the filter. +OIDN_API void oidnCommitFilter(OIDNFilter filter); + +// Executes the filter. +OIDN_API void oidnExecuteFilter(OIDNFilter filter); + +#if defined(__cplusplus) +} +#endif diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp b/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp new file mode 100644 index 0000000000..9f95a56fe1 --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp @@ -0,0 +1,468 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include +#include "oidn.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // Buffer + // -------------------------------------------------------------------------- + + // Formats for images and other data stored in buffers + enum class Format + { + Undefined = OIDN_FORMAT_UNDEFINED, + + // 32-bit single-precision floating point scalar and vector formats + Float = OIDN_FORMAT_FLOAT, + Float2 = OIDN_FORMAT_FLOAT2, + Float3 = OIDN_FORMAT_FLOAT3, + Float4 = OIDN_FORMAT_FLOAT4, + }; + + // Access modes for mapping buffers + enum class Access + { + Read = OIDN_ACCESS_READ, // read-only access + Write = OIDN_ACCESS_WRITE, // write-only access + ReadWrite = OIDN_ACCESS_READ_WRITE, // read and write access + WriteDiscard = OIDN_ACCESS_WRITE_DISCARD, // write-only access, previous contents discarded + }; + + // Buffer object with automatic reference counting + class BufferRef + { + private: + OIDNBuffer handle; + + public: + BufferRef() : handle(nullptr) {} + BufferRef(OIDNBuffer handle) : handle(handle) {} + + BufferRef(const BufferRef& other) : handle(other.handle) + { + if (handle) + oidnRetainBuffer(handle); + } + + BufferRef(BufferRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + BufferRef& operator =(const BufferRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainBuffer(other.handle); + if (handle) + oidnReleaseBuffer(handle); + handle = other.handle; + } + return *this; + } + + BufferRef& operator =(BufferRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + BufferRef& operator =(OIDNBuffer other) + { + if (other) + oidnRetainBuffer(other); + if (handle) + oidnReleaseBuffer(handle); + handle = other; + return *this; + } + + ~BufferRef() + { + if (handle) + oidnReleaseBuffer(handle); + } + + OIDNBuffer getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Maps a region of the buffer to host memory. + // If byteSize is 0, the maximum available amount of memory will be mapped. + void* map(Access access = Access::ReadWrite, size_t byteOffset = 0, size_t byteSize = 0) + { + return oidnMapBuffer(handle, (OIDNAccess)access, byteOffset, byteSize); + } + + // Unmaps a region of the buffer. + // mappedPtr must be a pointer returned by a previous call to map. + void unmap(void* mappedPtr) + { + oidnUnmapBuffer(handle, mappedPtr); + } + }; + + // -------------------------------------------------------------------------- + // Filter + // -------------------------------------------------------------------------- + + // Progress monitor callback function + typedef bool (*ProgressMonitorFunction)(void* userPtr, double n); + + // Filter object with automatic reference counting + class FilterRef + { + private: + OIDNFilter handle; + + public: + FilterRef() : handle(nullptr) {} + FilterRef(OIDNFilter handle) : handle(handle) {} + + FilterRef(const FilterRef& other) : handle(other.handle) + { + if (handle) + oidnRetainFilter(handle); + } + + FilterRef(FilterRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + FilterRef& operator =(const FilterRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainFilter(other.handle); + if (handle) + oidnReleaseFilter(handle); + handle = other.handle; + } + return *this; + } + + FilterRef& operator =(FilterRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + FilterRef& operator =(OIDNFilter other) + { + if (other) + oidnRetainFilter(other); + if (handle) + oidnReleaseFilter(handle); + handle = other; + return *this; + } + + ~FilterRef() + { + if (handle) + oidnReleaseFilter(handle); + } + + OIDNFilter getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Sets an image parameter of the filter (stored in a buffer). + void setImage(const char* name, + const BufferRef& buffer, Format format, + size_t width, size_t height, + size_t byteOffset = 0, + size_t bytePixelStride = 0, size_t byteRowStride = 0) + { + oidnSetFilterImage(handle, name, + buffer.getHandle(), (OIDNFormat)format, + width, height, + byteOffset, + bytePixelStride, byteRowStride); + } + + // Sets an image parameter of the filter (owned by the user). + void setImage(const char* name, + void* ptr, Format format, + size_t width, size_t height, + size_t byteOffset = 0, + size_t bytePixelStride = 0, size_t byteRowStride = 0) + { + oidnSetSharedFilterImage(handle, name, + ptr, (OIDNFormat)format, + width, height, + byteOffset, + bytePixelStride, byteRowStride); + } + + // Sets a boolean parameter of the filter. + void set(const char* name, bool value) + { + oidnSetFilter1b(handle, name, value); + } + + // Sets an integer parameter of the filter. + void set(const char* name, int value) + { + oidnSetFilter1i(handle, name, value); + } + + // Sets a float parameter of the filter. + void set(const char* name, float value) + { + oidnSetFilter1f(handle, name, value); + } + + // Gets a parameter of the filter. + template + T get(const char* name); + + // Sets the progress monitor callback function of the filter. + void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr = nullptr) + { + oidnSetFilterProgressMonitorFunction(handle, (OIDNProgressMonitorFunction)func, userPtr); + } + + // Commits all previous changes to the filter. + void commit() + { + oidnCommitFilter(handle); + } + + // Executes the filter. + void execute() + { + oidnExecuteFilter(handle); + } + }; + + // Gets a boolean parameter of the filter. + template<> + inline bool FilterRef::get(const char* name) + { + return oidnGetFilter1b(handle, name); + } + + // Gets an integer parameter of the filter. + template<> + inline int FilterRef::get(const char* name) + { + return oidnGetFilter1i(handle, name); + } + + // Gets a float parameter of the filter. + template<> + inline float FilterRef::get(const char* name) + { + return oidnGetFilter1f(handle, name); + } + + // -------------------------------------------------------------------------- + // Device + // -------------------------------------------------------------------------- + + // Device types + enum class DeviceType + { + Default = OIDN_DEVICE_TYPE_DEFAULT, // select device automatically + + CPU = OIDN_DEVICE_TYPE_CPU, // CPU device + }; + + // Error codes + enum class Error + { + None = OIDN_ERROR_NONE, // no error occurred + Unknown = OIDN_ERROR_UNKNOWN, // an unknown error occurred + InvalidArgument = OIDN_ERROR_INVALID_ARGUMENT, // an invalid argument was specified + InvalidOperation = OIDN_ERROR_INVALID_OPERATION, // the operation is not allowed + OutOfMemory = OIDN_ERROR_OUT_OF_MEMORY, // not enough memory to execute the operation + UnsupportedHardware = OIDN_ERROR_UNSUPPORTED_HARDWARE, // the hardware (e.g. CPU) is not supported + Cancelled = OIDN_ERROR_CANCELLED, // the operation was cancelled by the user + }; + + // Error callback function + typedef void (*ErrorFunction)(void* userPtr, Error code, const char* message); + + // Device object with automatic reference counting + class DeviceRef + { + private: + OIDNDevice handle; + + public: + DeviceRef() : handle(nullptr) {} + DeviceRef(OIDNDevice handle) : handle(handle) {} + + DeviceRef(const DeviceRef& other) : handle(other.handle) + { + if (handle) + oidnRetainDevice(handle); + } + + DeviceRef(DeviceRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + DeviceRef& operator =(const DeviceRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainDevice(other.handle); + if (handle) + oidnReleaseDevice(handle); + handle = other.handle; + } + return *this; + } + + DeviceRef& operator =(DeviceRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + DeviceRef& operator =(OIDNDevice other) + { + if (other) + oidnRetainDevice(other); + if (handle) + oidnReleaseDevice(handle); + handle = other; + return *this; + } + + ~DeviceRef() + { + if (handle) + oidnReleaseDevice(handle); + } + + OIDNDevice getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Sets a boolean parameter of the device. + void set(const char* name, bool value) + { + oidnSetDevice1b(handle, name, value); + } + + // Sets an integer parameter of the device. + void set(const char* name, int value) + { + oidnSetDevice1i(handle, name, value); + } + + // Gets a parameter of the device. + template + T get(const char* name); + + // Sets the error callback function of the device. + void setErrorFunction(ErrorFunction func, void* userPtr = nullptr) + { + oidnSetDeviceErrorFunction(handle, (OIDNErrorFunction)func, userPtr); + } + + // Returns the first unqueried error code and clears the stored error. + // Can be called for a null device as well to check why a device creation failed. + Error getError() + { + return (Error)oidnGetDeviceError(handle, nullptr); + } + + // Returns the first unqueried error code and string message, and clears the stored error. + // Can be called for a null device as well to check why a device creation failed. + Error getError(const char*& outMessage) + { + return (Error)oidnGetDeviceError(handle, &outMessage); + } + + // Commits all previous changes to the device. + // Must be called before first using the device (e.g. creating filters). + void commit() + { + oidnCommitDevice(handle); + } + + // Creates a new buffer (data allocated and owned by the device). + BufferRef newBuffer(size_t byteSize) + { + return oidnNewBuffer(handle, byteSize); + } + + // Creates a new shared buffer (data allocated and owned by the user). + BufferRef newBuffer(void* ptr, size_t byteSize) + { + return oidnNewSharedBuffer(handle, ptr, byteSize); + } + + // Creates a new filter of the specified type (e.g. "RT"). + FilterRef newFilter(const char* type) + { + return oidnNewFilter(handle, type); + } + }; + + // Gets a boolean parameter of the device. + template<> + inline bool DeviceRef::get(const char* name) + { + return oidnGetDevice1b(handle, name); + } + + // Gets an integer parameter of the device (e.g. "version"). + template<> + inline int DeviceRef::get(const char* name) + { + return oidnGetDevice1i(handle, name); + } + + // Creates a new device. + inline DeviceRef newDevice(DeviceType type = DeviceType::Default) + { + return DeviceRef(oidnNewDevice((OIDNDeviceType)type)); + } + +} // namespace oidn diff --git a/thirdparty/oidn/include/OpenImageDenoise/version.h b/thirdparty/oidn/include/OpenImageDenoise/version.h new file mode 100644 index 0000000000..66b347c992 --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/version.h @@ -0,0 +1,23 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#define OIDN_VERSION_MAJOR 1 +#define OIDN_VERSION_MINOR 1 +#define OIDN_VERSION_PATCH 0 +#define OIDN_VERSION 10100 +#define OIDN_VERSION_STRING "1.1.0" diff --git a/thirdparty/oidn/mkl-dnn/LICENSE b/thirdparty/oidn/mkl-dnn/LICENSE new file mode 100644 index 0000000000..d13f7b7ca0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/LICENSE @@ -0,0 +1,214 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ============================================================================ + + Intel MKL-DNN includes components with separate copyright + notices and license terms. + + XByak, 3-clause BSD license + Copyright (c) 2007 MITSUNARI Shigeo + See full copyright notice and license text in src/cpu/xbyak/COPYRIGHT + + gtest, 3-clause BSD license + Copyright 2008, Google Inc. + See full copyright notice and license text in tests/gtests/gtest/LICENSE diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.h b/thirdparty/oidn/mkl-dnn/include/mkldnn.h new file mode 100644 index 0000000000..9b64994922 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn.h @@ -0,0 +1,1771 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_H +#define MKLDNN_H + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +/* All symbols shall be internal unless marked as MKLDNN_API */ +#if defined _WIN32 || defined __CYGWIN__ +# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport) +# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport) +#else +# if __GNUC__ >= 4 +# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default"))) +# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default"))) +# else +# define MKLDNN_HELPER_DLL_IMPORT +# define MKLDNN_HELPER_DLL_EXPORT +# endif +#endif + +#ifdef MKLDNN_DLL +# ifdef MKLDNN_DLL_EXPORTS +# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT +# else +# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT +# endif +#else +# define MKLDNN_API +#endif + +#if defined (__GNUC__) +# define MKLDNN_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +# define MKLDNN_DEPRECATED __declspec(deprecated) +#else +# define MKLDNN_DEPRECATED +#endif + +#include "mkldnn_types.h" +#include "mkldnn_version.h" +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +#ifdef __cplusplus +extern "C" { +#endif + +/** @addtogroup c_api C API + * @{ */ + +/** @addtogroup c_api_primitive Primitive operations + * @{ */ + +/** @addtogroup c_api_primitive_common Common primitive operations + * @{ */ + +/** Creates a primitive descriptor @p iterator for given @p op_desc, @p attr, + * @p engine, and optionally a hint primitive descriptor from forward + * propagation (required for backward propagation). Pass @c NULL for forward + * propagation. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create( + mkldnn_primitive_desc_iterator_t *iterator, + const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine, + const_mkldnn_primitive_desc_t hint_forward_primitive_desc); + +/** Iterates over primitive descriptors. Returns #mkldnn_iterator_ends if no + * more primitive descriptors are available. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next( + mkldnn_primitive_desc_iterator_t iterator); + +/** Fetches the current primitive descriptor. + * + * @note + * The user should delete the fetched primitive descriptor using + * mkldnn_primitive_desc_destroy() once it is no longer needed. */ +mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch( + const_mkldnn_primitive_desc_iterator_t iterator); + +/** Deletes a primitive descriptor @p iterator */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy( + mkldnn_primitive_desc_iterator_t iterator); + +/** Creates a @p primitive_desc using @p op_desc, @p attr, @p engine, and + * optionally a hint primitive descriptor from forward propagation. The call is + * equivalent to creating a primitive descriptor iterator, immediately fetching + * a primitive descriptor, and then destroying the iterator. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create( + mkldnn_primitive_desc_t *primitive_desc, + const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine, + const_mkldnn_primitive_desc_t hint_forward_primitive_desc); + +/** Makes a copy of a @p primitive_desc. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone( + mkldnn_primitive_desc_t *primitive_desc, + const_mkldnn_primitive_desc_t existing_primitive_desc); + +/** Returns a constant reference to the attribute of a @p primitive_desc. + * + * @warning + * The user should not destroy the obtained @p attr. + * + * @warning + * The lifetime of an @p attr is the same as that of a @p primitive_desc, + * so it is illegal to use the @p attr once @p primitive_desc has been + * destroyed. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr( + const_mkldnn_primitive_desc_t primitive_desc, + const_mkldnn_primitive_attr_t *attr); + +/** Deletes a @p primitive_desc. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy( + mkldnn_primitive_desc_t primitive_desc); + +/** Queries primitive descriptor + * + * One of the most typical use cases is to query a convolution primitive + * descriptor created with source, weights, and destination formats equal + * to #mkldnn_format_tag_any about the corresponding memory descriptors + * (@p what equals #mkldnn_query_src_md, #mkldnn_query_weights_md, and + * #mkldnn_query_dst_md respectively) to be able to prepare memory and + * create reorders if required. + * + * Another quite typical use case is to query an operation primitive + * descriptor for a workspace (@p what equals #mkldnn_query_workspace_md). + * The returned status #mkldnn_not_required indicates that a workspace is + * not required. + * + * A few other possibilities: + * - query an operation primitive descriptor for the underlying operation + * descriptor (#mkldnn_query_convolution_d, #mkldnn_query_eltwise_d, + * #mkldnn_query_rnn_d, etc.) + * - query an operation primitive descriptor for the implementation + * information string (#mkldnn_query_impl_info_str) + * - query an operation primitive descriptor for the number of inputs and + * outputs (#mkldnn_query_num_of_inputs_s32 and + * #mkldnn_query_num_of_outputs_s32 respectively) + * + * @sa mkldnn_query_t for more options + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index, void *result); + +/** Queries primitive descriptor for memory descriptor + * + * @returns NULL in case of any error. + * + * This is just a specialized version of mkldnn_primitive_desc_query + * used for convenience. + */ +const mkldnn_memory_desc_t MKLDNN_API *mkldnn_primitive_desc_query_md( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index); + +/** Queries primitive descriptor for signed 32bit int + * + * @returns 0 in case of any error (in particular if the queried entity is + * not of type int32_t). Note that 0 might also be the actual returned + * value. + * + * This is just a specialized version of mkldnn_primitive_desc_query + * used for convenience. + */ +int MKLDNN_API mkldnn_primitive_desc_query_s32( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index); + +/** Creates a @p primitive using a @p primitive_desc descriptor. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_create( + mkldnn_primitive_t *primitive, + const_mkldnn_primitive_desc_t primitive_desc); + +/** Executes a @p primitive using a @p stream, and @p nargs arguments + * @p args. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_execute( + const_mkldnn_primitive_t primitive, mkldnn_stream_t stream, + int nargs, const mkldnn_exec_arg_t *args); + +/** Retrieves a reference to the @p primitive_desc descriptor of given @p + * primitive. + * + * @warning + * The returned object must not be destroyed by the user. The @c const + * qualifier of the returned object prevents such attempts. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc( + const_mkldnn_primitive_t primitive, + const_mkldnn_primitive_desc_t *primitive_desc); + +/** Deletes a @p primitive. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy( + mkldnn_primitive_t primitive); + +/** @} */ + +/** @addtogroup c_api_attributes Attributes + * An extension for controlling primitive behavior. + * @{ */ + +/** Creates an empty (default) @p attr attribute. All the parameters are set to + * default values. + * + * An empty attribute is used in primitive descriptor creation whenever it + * is not passed explicitly, e.g. in mkldnn_primitive_desc_create. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create( + mkldnn_primitive_attr_t *attr); + +/** Makes a copy of an @p existing_attr. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone( + mkldnn_primitive_attr_t *attr, + const_mkldnn_primitive_attr_t existing_attr); + +/** Deletes an @p attr. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy( + mkldnn_primitive_attr_t attr); + +/** Returns the scratchpad @p mode set in the attribute @p attr */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_scratchpad_mode( + const_mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t *mode); + +/** Sets scratchpad @p mode. + * + * The possible values are: #mkldnn_scratchpad_mode_library (default) and + * #mkldnn_scratchpad_mode_user. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_scratchpad_mode( + mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t mode); + +/** Returns @p count, correspondence scale @p mask, and a pointer to a constant + * floating point array of output @p scales for given @p attr, previously set + * by mkldnn_primitive_attr_set_output_scales. + * + * @warning + * The @p scales array points to the internal @p attr field, so the user + * should not modify or destroy @p scales. + * + * @warning + * The lifetime of @p scales is the same as that of the @p attr to which it + * belongs, so it is illegal to use @p scales after @p attr is destroyed. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales( + const_mkldnn_primitive_attr_t attr, mkldnn_dim_t *count, int *mask, + const float **scales); + +/** Sets output @p scales for primitive operations. The number of elements @p + * count and correspondence scale @p mask are stored for future use. + * + * The @p mask argument defines the correspondence between the output tensor + * dimensions and the @p scales array. Set the i-th bit of @p mask to 1 to use a + * dedicated scaling factor for each slice of the output tensor over the i-th + * dimension. Set @p mask to 0 to use a common scaling factor for the whole + * output tensor. + * + * @note + * The dimension order is always native and does not depend on the actual + * layout used. Examples: + * - 2D dimensional data the order of dimensions is always: (n, c) + * - 4D dimensional data the order is always: (n, c, h, w) + * - 5D dimensional weights the order is always: (g, oc, ic, kh, kw) + * + * Example usage: + * @code + * int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params + * float scales[oc] = { ... }; // unique output scales per output channel + * int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ... + * + * mkldnn_convolution_desc_t cd; // create & configure convolution op_desc + * + * mkldnn_primitive_attr_t attr; + * mkldnn_primitive_attr_create(&attr); // create default attributes + * mkldnn_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales); + * + * mkldnn_primitive_desc_t cpd; + * mkldnn_primitive_desc_create(&cpd, &cd, attr, NULL); + * @endcode + * + * @note + * There is no way to check that @p count corresponds to @p mask until an + * actual primitive descriptor is created, so it is the user's + * responsibility to set proper values. The following formula must hold: + * + * \f[count = \prod\limits_{d \in mask} output.dims[d]\f] + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales( + mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask, + const float *scales); + +/** Returns @p post_ops for given @p attr. + * + * @warning + * @p post_ops points to the internal @p attr field, so the user should not + * modify or destroy @p post_ops. Also, the lifetime of @p post_ops is the + * same as that of the @p attr it belongs to, so it is illegal to use @p + * post_ops after @p attr has been destroyed. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops( + const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops); + +/** Sets configured @p post_ops to an attribute @p attr for future use (when + * primitive descriptor is being created). + * + * @note + * At this point in time, there is no way to check whether the primitive + * descriptor does or does not support a given sequence of post operations. + * Therefore the user should handle an error that might occur at the + * mkldnn_primitive_desc_create call. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops( + mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops); + +/** @addtogroup c_api_attributes_post_ops Sequence of post operations + * An extension for performing extra operations after a base operation. + * @{ */ + +/** Creates an empty sequence of post operations @p post_ops. */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops); + +/** Deletes a @p post_ops sequence. */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops); + +/** Returns the @p length of post operations for given @p post_ops. */ +int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops); + +/** Returns the type of post operation with index @p index in given + * @p post_ops. In case of error, returns #mkldnn_undefined_primitive. */ +mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind( + const_mkldnn_post_ops_t post_ops, int index); + +/** Appends accumulation (sum) post operation to the @p post_ops. Prior to + * accumulating the result, the previous value would be multiplied by @p scale. + * + * The kind of this post operation is #mkldnn_sum. + * + * This feature might improve performance for cases like residual learning + * blocks, where the result of convolution is accumulated to the previously + * computed activations. The parameter @p scale might be extreme for the + * integer-based computations when the result and previous activations have + * different logical scaling factors. + * + * In the simplest case when the accumulation is the only post operation, the + * computations would be: + * dst[] <- scale * dst[] + op(...) // instead of dst[] <- op(...) + * + * @note + * This post operation (as well as all the others) disregards the original + * layout of the destination; that is, the layout of the original + * destination is expected to be the same as the layout of the stored + * destination. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum( + mkldnn_post_ops_t post_ops, float scale); + +/** Gets the parameters of the accumulation (sum) post operation with index + * @p index in the sequence of @p post_ops. + * + * @note + * If index @p index would not correspond to the accumulation post + * operation, the function returns #mkldnn_invalid_arguments. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum( + const_mkldnn_post_ops_t post_ops, int index, float *scale); + +/** Appends eltwise post operation to the @p post_ops with given parameters + * @p kind, @p alpha, and @p beta (@sa mkldnn_eltwise_forward_desc_init and + * mkldnn_eltwise_desc_t). + * + * The kind of this post operation is #mkldnn_eltwise. + * + * In the simplest case when the eltwise is the only post operation, the + * computations would be: + * dst[] <- scale * eltwise_op ( op(...) ) // instead of dst[] <- op(...) + * where eltwise_op is configured with the given parameters. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise( + mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, + float alpha, float beta); + +/** Gets the eltwise parameters of the post operation with index @p index in + * the sequence of @p post_ops. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise( + const_mkldnn_post_ops_t post_ops, int index, float *scale, + mkldnn_alg_kind_t *alg, float *alpha, float *beta); + +/** @} */ + +/** @} */ + +/** @addtogroup c_api_memory Memory + * A primitive to describe and store data. + * + * The library supports various data types and formats. Memory hierarchy + * consists of three levels of abstraction: + * 1. **Memory descriptor** -- engine agnostic logical description of data + * (number of dimensions, dimensions themselves, and data type), and + * optionally the format/layout that describes the physical representation + * of data in memory. If the format is not known yet, one can pass + * #mkldnn_format_tag_any. This approach is used to allow compute-intensive + * primitives to specify the most appropriate format on their own with + * users required to reorder the data if the incoming format doesn't match + * the primitive's selection. Memory descriptor can be initialized with + * mkldnn_memory_desc_init_by_tag() or mkldnn_memory_desc_init_by_strides() + * functions, or by directly filling the mkldnn_memory_desc_t structure. + * The latter requires deep knowledge of how the physical data + * representation is mapped to the structure. + * The @ref understanding_memory_formats topic should shed some light on + * that. + * For the fully defined memory descriptors (i.e. where the format kind is + * not equal to #mkldnn_format_kind_any) a user can the size, using the + * mkldnn_memory_desc_get_size() function. As described in + * @ref understanding_memory_formats, the size of data sometimes cannot + * be computed as the product of dimensions times the size of the data + * type. So users are encouraged to use this function for better code + * portability. + * Two memory descriptors can be compared with mkldnn_memory_desc_equal(). + * The comparison is especially useful when checking whether a primitive + * requires reorder from the user's data format to the primitive's format. + * 2. **Memory** -- an engine-specific object that handles the data and its + * description (a memory descriptor). For CPU enigne, the data handle is + * simply a pointer to @c void. The data handle can be queried using + * mkldnn_memory_get_data_handle() and set using + * mkldnn_memory_set_data_handle(). The latter function always sets the + * memory in the padding region to zero, which is the invariant maintained + * by all the primitives in Intel MKL-DNN. + * See @ref understanding_memory_formats for more details. + * A memory can be created using mkldnn_memory_create() function. + * A memory can also be queried for the underlying memory descriptor and + * engine using mkldnn_memory_get_memory_desc() and + * mkldnn_memory_get_engine() functions. + * + * Along with ordinary memory with all dimensions being positive, Intel + * MKL-DNN supports *zero-volume* memory with one or more dimensions set to + * zero. This is to support the NumPy\* convention. + * If a *zero-volume* memory is passed to a primitive, the primitive does + * not perform any computations on this memory. For example: + * - Convolution with `(0 batch, 3 input channels, 13 height, 13 width)` + * source and `(16 output channels, 3 inputs, channel, 3 height, 3 width)` + * weights would produce `(0 batch, 16 output channels, 11 height, 11 width)` + * destination (assuming strides are `1` and paddings are zero) and perform + * zero multiply-add operations. + * - Concatenation of three memories of shapes `(3, 4, 13, 13)`, + * `(3, 0, 13, 13)`, and `(3, 1, 13, 13)` along the second axis would produce + * the output of the shape `(3, 5, 13, 13)`, effectively ignoring the second + * input (however, if the user created a concatenation primitive descriptor + * with three inputs they should also provide all three memories to the + * concatenation primitive, including the one with zero second dimension). + * - However, Intel MKL-DNN would return an error when attempting to create a + * convolution with *zero-volume* memory passed for weights because such a + * convolution is not well-defined: + * ~~~ + * dst(1, 16, 11, 11) <-- src(1, 0, 13, 13) (*) wei(16, 0, 3, 3) + * ~~~ + * Should the values in the destination be zeroes or just not accessed at + * all? Moreover, backward pass w.r.t. weights in such cases is also not + * well-defined. + * + * Data handle of *zero-volume* memory is never accessed and hence can be + * unset (NULL in case of CPU engine). + * + * @sa @ref understanding_memory_formats + * @{ */ + +/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p + * data_type, and @p strides. + * + * The @p strides might be NULL, which means the order of physical dimensions + * is the same as the order of logical ones. + * + * @note The logical order of dimensions is defined by a primitive that + * consumes the memory. + */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_strides( + mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, + mkldnn_data_type_t data_type, const mkldnn_dims_t strides); + +/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p + * data_type, and format @p tag. + * + * @p tag can be #mkldnn_format_tag_any, which allows a primitive to define + * the appropriate memory format. In this case, the @p format_kind would be set + * to #mkldnn_format_kind_any */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_tag( + mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, + mkldnn_data_type_t data_type, mkldnn_format_tag_t tag); + +/** Initializes a @p memory_desc for a given @p parent_memory_desc, with + * @p dims sizes and @p offsets. May fail if layout used does not allow + * obtain desired submemory. In this case consider using `extract` or `insert` + * primitive */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_submemory( + mkldnn_memory_desc_t *memory_desc, + const mkldnn_memory_desc_t *parent_memory_desc, + const mkldnn_dims_t dims, const mkldnn_dims_t offsets); + +/** Compares two memory descriptors. + * @return 1 if the descriptors are the same. + * @return 0 if the descriptors are different. + * + * Use this function to identify whether a reorder is required between the + * two memories */ +int MKLDNN_API mkldnn_memory_desc_equal( + const mkldnn_memory_desc_t *lhs, + const mkldnn_memory_desc_t *rhs); + +/** Returns the size (in bytes) that is required for given @p memory_desc */ +size_t MKLDNN_API mkldnn_memory_desc_get_size( + const mkldnn_memory_desc_t *memory_desc); + +/** Creates a memory for given @p memory_desc and @p engine. Also sets handle + * to @p native_handle. + * The @p native_handle can: + * - point to the user allocated memory, i.e. valid handle. In this case the + * library doesn't own allocated memory. + * - be MKLDNN_NATIVE_HANDLE_ALLOCATE to ask the library to allocate and + * attach memory. In this case the library owns allocated memory. + * - be MKLDNN_NATIVE_HANDLE_NONE to create mkldnn_memory w/o attached memory. + */ +mkldnn_status_t MKLDNN_API mkldnn_memory_create(mkldnn_memory_t *memory, + const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine, + void *native_handle); + +/** Returns a @p memory_desc associated with @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_memory_desc( + const_mkldnn_memory_t memory, + const mkldnn_memory_desc_t **memory_desc); + +/** Returns an @p engine associated with @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_engine( + const_mkldnn_memory_t memory, mkldnn_engine_t *engine); + +/** For a @p memory, returns the data @p handle. + * + * For the CPU engine, the data handle is a pointer to the actual data. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle( + const_mkldnn_memory_t memory, void **handle); + +/** For a @p memory, sets the data @p handle. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle( + mkldnn_memory_t memory, void *handle); + +/** Deletes a @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_destroy(mkldnn_memory_t memory); + +/** @} */ + +/** @addtogroup c_api_reorder Reorder + * A primitive to copy data between memory formats. + * @{ */ + +/** Initializes a @p reorder_primitive_desc using the description of the source + * (@p src_engine and @p src_md) and destination (@p dst_engine and @p dst_md) + * memory, and an @p attr attribute. + * + * Inputs: + * - input (#mkldnn_query_src_md, 0) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create( + mkldnn_primitive_desc_t *reorder_primitive_desc, + mkldnn_engine_t src_engine, const mkldnn_memory_desc_t *src_md, + mkldnn_engine_t dst_engine, const mkldnn_memory_desc_t *dst_md, + const_mkldnn_primitive_attr_t attr); + +/** @} */ + +/** @addtogroup c_api_concat Concat + * A primitive to concatenate data by arbitrary dimension. + * @{ */ + +/** Creates out-of-place @p concat_primitive_desc for concatenation of @p n + * inputs by @p concat_dimension with resulting @p output_desc memory + * descriptor. @p output_desc can be NULL or specified with the + * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory + * format would be chosen automatically. + * + * Inputs: + * - input 0 (#mkldnn_query_src_md, 0) + * - input 1 (#mkldnn_query_src_md, 1) + * - ... + * - input @p n - 1 (#mkldnn_query_src_md, @p n - 1) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create( + mkldnn_primitive_desc_t *concat_primitive_desc, + const mkldnn_memory_desc_t *dst_md, + int n, int concat_dimension, + const mkldnn_memory_desc_t *src_mds, + const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_sum Sum + * A primitive to sum data. + * @{ */ + +/** Creates out-of-place @p sum_primitive_desc for sum of @p n + * inputs multiplied by scale with resulting @p output_desc memory + * descriptor. @p output_desc can be NULL or specified with the + * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory + * format would be chosen automatically. + * + * Inputs: + * - src 0 (#mkldnn_query_src_md, 0) + * - src 1 (#mkldnn_query_src_md, 1) + * - ... + * - src @p n - 1 (#mkldnn_query_src_md, @p n - 1) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create( + mkldnn_primitive_desc_t *sum_primitive_desc, + const mkldnn_memory_desc_t *dst_mds, + int n, const float *scales, + const mkldnn_memory_desc_t *src_mds, + const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_convolution Convolution + * A primitive to compute convolution using different algorithms. + * + * \f[dst[n][oc][oh][ow] = + * \sum_{kw=0}^{KW}\sum_{kh=0}^{KH}\sum_{ic=0}^{IC} + * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw] + * \cdot weights[g][oc][ic][kh][kw] + * + bias[g][oc],\f] + * + * where size of output spatial domain is given by + * \f$ OH = \left\lfloor{\frac{IH - KH + p_l[0] + p_r[0]}{s_h}} + * \right\rfloor + 1\f$, + * \f$ OW = \left\lfloor{\frac{IW - KW + p_l[1] + p_r[1]}{s_w}} + * \right\rfloor + 1\f$, + * + * and summation is carried over input channels \f$ic\f$ in + * group \f$g\f$, and \f$s_h, s_w\f$ are @p strides and + * \f$p_l, p_r\f$ are @p padding_l and @p padding_r. + * @{ */ + +/** Initializes a convolution descriptor @p conv_desc for forward propagation + * using @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. In order to create a + * convolution without bias, @p bias_desc should either be @c NULL or point to + * a descriptor with memory format kind equal to #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated convolution descriptor @p conv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * In order to create a dilated convolution without bias, @p bias_desc + * should either be @c NULL or point to a descriptor with memory format kind + * equals #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to data using @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated convolution descriptor @p conv_desc for backward + * propagation with respect to data using @p alg_kind, memory descriptors, @p + * strides, @p dilates @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p dilates @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API +mkldnn_dilated_convolution_backward_weights_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_deconvolution Deconvolution + * A primitive to compute deconvolution using different algorithms. + * + * @{ */ + + +/** Initializes a deconvolution descriptor @p deconv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. In order to create a + * deconvolution without bias, @p bias_desc should either be @c NULL or point to + * a descriptor with memory format kind equals #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p deconv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. In order to + * create a dilated deconvolution without bias, @p bias_desc should either be + * @c NULL or point to a descriptor with memory format kind equal + * #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a deconvolution descriptor @p conv_desc for backward propagation + * with respect to data using @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p conv_desc for backward + * propagation with respect to data using @p alg_kind, memory descriptors, @p + * strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a deconvolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p conv_desc for backward + * propagation with respect to weights using @p alg_kind, memory descriptors, + * @p strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_shuffle Shuffle + * A primitive to shuffle data along the axis. + * @{ */ + +/** Initializes a @p shuffle_desc for forward propagation using @p prop_kind, + * memory descriptor @p data_desc, @p axis, and @p group_size. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * + */ +mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init( + mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *data_desc, int axis, + mkldnn_dim_t group_size); + +/** Initializes a @p shuffle_desc for backward propagation using memory + * descriptor @p diff_data_desc, @p axis, and @p group_size. + * + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + * + */ +mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init( + mkldnn_shuffle_desc_t *shuffle_desc, + const mkldnn_memory_desc_t *diff_data_desc, int axis, + mkldnn_dim_t group_size); + +/** @} */ + +/** @addtogroup c_api_eltwise Eltwise + * A primitive to compute element-wise operations like parametric rectifier + * linear unit (ReLU). + * + * Both forward and backward passes support in-place operation; that is, src + * and dst point to the same memory for forward pass, and diff_dst and diff_src + * point to the same memory for backward pass. + * + * @warning Because the original src is required for backward pass, in-place + * forward pass in general cannot be applied during training. However, for some + * kinds of element-wise operations (namely ReLU with alpha parameter equals 0), + * dst and src can be interchangeable for the backward pass, which enables + * performing in-place forward even for training. + * + * @{ */ + +/** Initializes an @p eltwise_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference), + * @p alg_kind algorithm, memory descriptor @p data_desc, @p alpha, and + * @p beta parameters. + * + * @sa mkldnn_eltwise_desc_t for details. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init( + mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, + float alpha, float beta); + +/** Initializes an @p eltwise_desc for backward propagation using @p alg_kind + * algorithm memory descriptors @p diff_data_desc and @p data_desc, and the + * @p alpha and @p beta parameters. + * + * @sa mkldnn_eltwise_desc_t for details. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init( + mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, float alpha, float beta); + +/** @} */ + +/** @addtogroup c_api_softmax Softmax + * A primitive to perform softmax. + * + * \f[dst[u][c][in] = + * \frac{\exp(src[ou][c][in]) - \max\limits_{c}(src[ou][c][in])} + * {\sum\limits_{c}\{\exp(src[ou][c][in]) + * - \max\limits_{c}(src[ou][c][in])\}},\f] + * + * where \f$ou, iu\f$ are outer and inner sizes repectively, defined + * by @p data_desc.dims and @p softmax_axis. + * @{ */ + +/** Initializes a @p softmax_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference) + * and memory descriptor @p data_desc. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init( + mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *data_desc, int softmax_axis); + +/** Initializes a @p softmax_desc for backward propagation using memory + * descriptors @p diff_desc and @p data_desc. + * + * Inputs: + * - dst (#mkldnn_query_dst_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init( + mkldnn_softmax_desc_t *softmax_desc, + const mkldnn_memory_desc_t *diff_desc, + const mkldnn_memory_desc_t *data_desc, int softmax_axis); + +/** @} */ + +/** @addtogroup c_api_pooling Pooling + * A primitive to perform max or average pooling. + * + * Max pooling: + * \f[dst[n][oc][oh][ow] = + * \max\limits_{kw,kh} + * (src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw]),\f] + * + * Average pooling: + * \f[dst[n][oc][oh][ow] = + * \frac{1}{KW \cdot KH}\sum\limits_{kw,kh} + * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw],\f] + * + * where \f$p_l, p_r\f$ are @p padding_l and @p padding_r respectively, and + * output spatial dimensions are calculated similarly to how they are done in + * convolution. + * + * During training, max pooling requires a workspace on forward + * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes to + * save indices where maximum was found. The workspace layout is opaque, and + * the indices cannot be restored from it. However, one can use backward + * pooling to perform up-sampling (used in some detection topologies). + * + * @{ */ + +/** Initializes a pooling descriptor @p pool_desc for forward propagation using + * @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference), @p alg_kind, memory descriptors, and pooling + * parameters in the spatial domain: @p strides, @p kernel sizes, @p padding_l, + * @p padding_r, and @p padding_kind. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p alg_kind = #mkldnn_pooling_max and + * @p prop_kind = #mkldnn_forward_training + */ +mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init( + mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a pooling descriptor @p pool_desc for backward propagation + * using @p alg_kind, memory descriptors, and pooling parameters in the spatial + * domain: @p strides, @p kernel sizes, @p padding_l, @p padding_r, and @p + * padding_kind. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p alg_kind = #mkldnn_pooling_max + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init( + mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_lrn LRN + * A primitive to perform local response normalization (LRN) across or within + * channels. + * + * LRN accross channels: + * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}} + * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2} + * (src[n][c+i][h][w])^2\right\}^{-\beta} + * src[n][c][h][w],\f] + * + * LRN within channels: + * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}} + * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2} + * (src[n][c][h+i][w+i])^2\right\}^{-\beta} + * src[n][c][h][w],\f] + * + * where \f$n_{l}\f$ is the @p local_size. + * + * During training, LRN might or might not require a workspace on forward + * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes. The + * behavior is implementation specific. Optimized implementations typically + * require a workspace and use it to save some intermediate results from the + * forward pass that accelerate computations on the backward pass. + * + * To check whether a workspace is required, query the LRN primitive descriptor + * for the workspace (#mkldnn_query_workspace_md). Success indicates that the + * workspace is required and its description will be returned. + * @sa mkldnn_primitive_desc_query and mkldnn_primitive_desc_query_pd + * + * @{ */ + +/** Initializes an @p lrn_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference), + * @p alg_kind, memory descriptor @p data_desc, and regularization + * parameters @p local_size, @p alpha, @p beta, and @p k. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if the underlying implementation requires + */ +mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init( + mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, + mkldnn_dim_t local_size, float alpha, float beta, float k); + +/** Initializes an @p lrn_desc for backward propagation using @p alg_kind, + * memory descriptors @p data_desc and @p diff_data_desc, and regularization + * parameters @p local_size, @p alpha, @p beta, and @p k. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if the underlying implementation requires + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init( + mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, mkldnn_dim_t local_size, + float alpha, float beta, float k); + +/** @} */ + +/** @addtogroup c_api_batch_normalization Batch Normalization + * A primitive to perform batch normalization. + * + * \f[dst[n][c][h][w] = \gamma[c] \frac{src[n][c][h][w] - \mu[c]} + * {\sqrt{\sigma[c] + eps}} + \beta[c],\f] + * + * where \f$\gamma[c], \beta[c]\f$ are weights and bias for a channel and, + * + * \f$\mu[c] = \frac{1}{NHW} \sum\limits_{whn} src[n][c][h][w]\f$, + * \f$\sigma[c] = \frac{1}{NHW} \sum\limits_{whn} + * (src[n][c][h][w] - \mu[c])^2\f$, + * + * and @c eps is a constant to improve numerical stability. + * + * Both forward and backward passes support in-place operation; that is, src + * and dst point to the same memory for forward pass, and diff_dst and diff_src + * point to the same memory for backward pass. + * + * Batch normalization supports different flavors controlled by + * mkldnn_batch_normalization_desc_t. For example, batch normalization can + * compute the mean and variance on its own or take them as inputs. It can + * either perform scaling and shifting using gamma and beta parameters or not. + * Optionally it can also perform a fused ReLU, which in case of training would + * also require a workspace. + * + * @sa mkldnn_batch_normalization_desc_t + * @{ */ + +/** Initializes a batch normalization descriptor @p bnrm_desc for forward + * propagation using @p prop_kind (possible values are + * #mkldnn_forward_training and #mkldnn_forward_inference), memory descriptor + * @p data_desc, normalization parameter @p epsilon, and @p flags set using bit + * flags of type mkldnn_batch_normalization_desc_t. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - mean (#mkldnn_query_src_md, 1), + * if #mkldnn_use_global_stats bit-flags is set in @p flags + * - variance (#mkldnn_query_src_md, 2), + * if #mkldnn_use_global_stats bit-flags is set in @p flags + * - scale_and_shift (#mkldnn_query_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - mean (#mkldnn_query_dst_md, 1), + * if #mkldnn_use_global_stats bit-flags is not set in @p flags + * @p prop_kind = #mkldnn_forward_training + * - variance (#mkldnn_query_dst_md, 2), + * if #mkldnn_use_global_stats bit-flags is not set in @p flags + * and @p prop_kind = #mkldnn_forward_training + * - workspace (#mkldnn_query_workspace_md, 0), + * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags + * and @p prop_kind = #mkldnn_forward_training + * + * @note In-place operation is supported; that is, dst points to the same memory + * as src. + * + * @sa mkldnn_batch_normalization_desc_t + */ +mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init( + mkldnn_batch_normalization_desc_t *bnrm_desc, + mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, + float epsilon, unsigned flags); + +/** Initializes a batch normalization descriptor @p bnrm_desc for backward + * propagation with respect to data and scale-shift parameters using memory + * descriptors @p data_desc and @p diff_data_desc, normalization parameter + * @p epsilon, and @p flags set using bit flags of type + * mkldnn_batch_normalization_desc_t. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - mean (#mkldnn_query_src_md, 1) + * - variance (#mkldnn_query_src_md, 2) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - scale_and_shift (#mkldnn_query_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * - workspace (#mkldnn_query_workspace_md, 0), + * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + * - diff_scale_and_shift (#mkldnn_query_diff_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * and @p prop_kind = #mkldnn_backward + * + * @note in-place operation is supported, + * i.e. diff_src points to the same memory as diff_dst. + * + * @sa mkldnn_batch_normalization_desc_t + */ +mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init( + mkldnn_batch_normalization_desc_t *bnrm_desc, + mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, + float epsilon, unsigned flags); + +/** @} */ + +/** @addtogroup c_api_inner_product Inner product + * A primitive to compute an inner product. + * + * Inner product layer is also known as fully connected layer. + * With spatial dimension: + * + * \f[dst[n][oc] = \sum\limits_{ic, kh, kw} + * src[n][ic][kh][kw] \cdot weights[oc][ic][kh][kw] + * + bias[oc]\f] + * @{ */ + +/** Initializes an inner product descriptor @p ip_desc for forward propagation + * using @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference) and memory descriptors. In order to create an + * inner product without bias, @p bias_desc should be either @c NULL or a + * pointer to a descriptor with memory format kind equals + * #mkldnn_format_kind_undef. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init( + mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc); + +/** Initializes an inner product descriptor @p ip_desc for backward propagation + * with respect to data using memory descriptors. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init( + mkldnn_inner_product_desc_t *ip_desc, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc); + +/** Initializes an inner product descriptor @p ip_desc for backward propagation + * with respect to weights using memory descriptors. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init( + mkldnn_inner_product_desc_t *ip_desc, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc); + +/** @} */ + +/** @addtogroup c_api_rnn RNN + * A primitive to compute the common recurrent layer. + * @todo add additional description for the group + * @{ */ + +/** + * Initializes a recurrent cell descriptor @p rnn_cell_desc + * using @p rnn_cell_desc, @p kind (possible values are + * #mkldnn_vanilla_rnn, #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, and + * #mkldnn_gru_linear_before_reset), + * @p f (possible values are #mkldnn_eltwise_relu and + * #mkldnn_eltwise_tanh), @p flags, @p alpha, and @p clipping. + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init( + mkldnn_rnn_cell_desc_t *rnn_cell_desc, + mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, + unsigned int flags, float alpha, float clipping); + +/** Returns the number of gates of a particular @p rnn_cell_desc. */ +int MKLDNN_API mkldnn_rnn_cell_get_gates_count( + const mkldnn_rnn_cell_desc_t *rnn_cell_desc); + +/** Returns the number of states of a particular @p rnn_cell_desc. */ +int MKLDNN_API mkldnn_rnn_cell_get_states_count( + const mkldnn_rnn_cell_desc_t *rnn_cell_desc); + +/** Sets quantization @p scale and @p shift for RNN data tensors. + * For performance reasons, low precision configuration of RNN primitive + * expects input activations to have unsigned int8 data type. Scale and shift + * used to quantize floating point data to unsigned integer must be passed to + * RNN primitive using attributes. + * Example usage: + * @code + * // rnn parameters + * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; + * // activations quantization parameters + * float scale = ..., shift = ..; + * + * mkldnn_primitive_attr_t rnn_attr; + * // create default attributes + * mkldnn_primitive_attr_create(&rnn_attr); + * + * // set scale and shift for int8 quantization of activation + * mkldnn_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift); + * + * // create & configure rnn op_desc + * mkldnn_rnn_desc_t rnn_d; + * mkldnn_primitive_desc_t rnn_pd; + * mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL); + * @endcode + * @note + * Quantization scale and shift are common for src_layer, src_iter, + * dst_iter and dst_layer. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams( + mkldnn_primitive_attr_t attr, const float scale, const float shift); + +/** Sets quantization scales @p weights_scales for RNN weights tensors. + * Low precision configuration of RNN primitive expects input weights to have + * signed int8 data type. Scales used to quantize floating point data + * to signed integer must be passed to RNN primitive using attributes. + * The @p mask argument defines correspondence between output tensor dimensions + * and the @p weights_scales array. Set i-th bit of @p mask to 1 to use + * dedicated scaling factor for each slice of the output tensor over i-th + * dimension. Set @p mask to 0 to use common scaling factor for the whole output + * tensor. Example usage: + * @code + * // rnn parameters + * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; + * // unique output scales per output channel + * float weights_scales[dic * n_gates] = { ... }; + * // mask that specifies last two dimensions of ldigo format + * int mask = 0x3; + * + * mkldnn_primitive_attr_t attr; + * // create default attributes + * mkldnn_primitive_attr_create(&attr); + * + * // set output channel-wise weights scales + * mkldnn_primitive_attr_set_rnn_weights_qparams(attr, dic * n_gates, mask, + * weights_scales); + * + * // create & configure rnn op_desc + * mkldnn_rnn_desc_t rnn_d; + * mkldnn_primitive_desc_t rnn_pd; + * mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL); + * @endcode + * @note + * The dimension order is always native and does not depend on the actual + * layout used. For example, 5 dimensional weights always have + * (l, d, i, g, o) logical dimension ordering. + * @note + * Quantization sales are common for weights_layer and weights_iteration + * @note + * There is no way to check that @p count corresponds to @p mask until an + * actual primitive descriptor is created, so it is user's responsibility + * to set proper values. The following formula must be held: + * + * \f[count = \prod\limits_{d \in mask} output.dims[d]\f] + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams ( + mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask, + const float *weights_scales); + +/** Initializes a rnn descriptor @p rnn_desc for forward propagation + * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors. + * @note If @p prop_kind equals #mkldnn_forward_training, you must query a + * workspace memory descriptor before creating the primitive. + * + * @p src_iter_desc, @p bias_desc, and @p dst_iter_desc are allowed to either be + * @c NULL or point to a zero memory descriptor, which would indicate that the + * RNN primitive should not use them. + * + * @note All memory descriptors except @p src_iter_desc are allowed to be + * initialized with #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src_layer (#mkldnn_query_src_md, 0) + * - src_iter (#mkldnn_query_src_md, 1), if used + * - weights_layer (#mkldnn_query_weights_md, 0) + * - weights_iter (#mkldnn_query_weights_md, 1) + * - bias (#mkldnn_query_weights_md, 2), if used + * + * Outputs: + * - dst_layer (#mkldnn_query_dst_md, 0) + * - dst_iter (#mkldnn_query_dst_md, 1), if used + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p prop_kind equals #mkldnn_forward_training + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init( + mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_rnn_cell_desc_t *rnn_cell_desc, + const mkldnn_rnn_direction_t direction, + const mkldnn_memory_desc_t *src_layer_desc, + const mkldnn_memory_desc_t *src_iter_desc, + const mkldnn_memory_desc_t *weights_layer_desc, + const mkldnn_memory_desc_t *weights_iter_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_layer_desc, + const mkldnn_memory_desc_t *dst_iter_desc); + +/** Initializes a rnn descriptor @p rnn_desc for backward propagation + * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors. + * + * @note All memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * @p src_iter_desc (simultaneously with @p diff_src_iter_desc), + * @p bias_desc (simultaneously with @p diff_bias_desc), and + * @p dst_iter_desc (simultaneously with @p diff_src_iter_desc) are allowed to + * either be @c NULL or point to a zero memory descriptor, which would indicate + * that the RNN primitive should not use them. + * + * Inputs: + * - src_layer (#mkldnn_query_src_md, 0) + * - src_iter (#mkldnn_query_src_md, 1), if used + * - weights_layer (#mkldnn_query_weights_md, 0) + * - weights_iter (#mkldnn_query_weights_md, 1) + * - bias (#mkldnn_query_weights_md, 2), if used + * - dst_layer (#mkldnn_query_dst_md, 0) + * - dst_iter (#mkldnn_query_dst_md, 1), if used + * - diff_dst_layer (#mkldnn_query_diff_dst_md, 0) + * - diff_dst_iter (#mkldnn_query_diff_dst_md, 1), if used + * - workspace (#mkldnn_query_workspace_md, 0) + * + * Outputs: + * - diff_src_layer (#mkldnn_query_diff_src_md, 0) + * - diff_src_iter (#mkldnn_query_diff_src_md, 1), if used + * - diff_weights_layer (#mkldnn_query_diff_weights_md, 0) + * - diff_weights_iter (#mkldnn_query_diff_weights_md, 1) + * - diff_bias (#mkldnn_query_diff_weights_md, 2), if used + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init( + mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_rnn_cell_desc_t *rnn_cell_desc, + const mkldnn_rnn_direction_t direction, + const mkldnn_memory_desc_t *src_layer_desc, + const mkldnn_memory_desc_t *src_iter_desc, + const mkldnn_memory_desc_t *weights_layer_desc, + const mkldnn_memory_desc_t *weights_iter_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_layer_desc, + const mkldnn_memory_desc_t *dst_iter_desc, + const mkldnn_memory_desc_t *diff_src_layer_desc, + const mkldnn_memory_desc_t *diff_src_iter_desc, + const mkldnn_memory_desc_t *diff_weights_layer_desc, + const mkldnn_memory_desc_t *diff_weights_iter_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_layer, + const mkldnn_memory_desc_t *diff_dst_iter_desc); + +/** @} */ + +/** @} */ + +/** @addtogroup c_api_engine Engine operations + * @{ */ + +/** Returns the number of engines of a particular @p kind. */ +size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind); + +/** Creates an @p engine of particular @p kind and @p index. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, + mkldnn_engine_kind_t kind, size_t index); + +/** Returns the kind of an @p engine. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_get_kind(mkldnn_engine_t engine, + mkldnn_engine_kind_t *kind); + +/** Destroys an @p engine. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_stream Execution stream operations + * @{ */ + +/** Creates an execution @p stream for @p engine and with @p flags. */ +mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, + mkldnn_engine_t engine, unsigned flags); + +/** Destroys an execution @p stream. */ +mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream); + +/** @} */ + +/** @addtogroup c_api_service Service functions + * @{ */ + +/** Sets verbosity level (print information to stdout). + * Possible levels are: + * - 0 -- no verbose output (default) + * - 1 -- primitive information at execution + * - 2 -- primitive information at creation and execution + * + * @note + * Dumping information might affect performance. + * This setting overrides the MKLDNN_VERBOSE environment variable. */ +mkldnn_status_t MKLDNN_API mkldnn_set_verbose(int level); + +/** Enables or disables dumping of JIT-generated code. + * The enable parameter can be: + * - 0 -- disable + * - any other value -- enable + * + * @note + * This setting overrides the MKLDNN_JIT_DUMP environment variable. */ +mkldnn_status_t MKLDNN_API mkldnn_set_jit_dump(int enable); + +/** Gets library version information. + * Version information includes: + * - major -- major version number + * - minor -- minor version number + * - patch -- patch release number + * - hash -- git commit hash */ +const mkldnn_version_t MKLDNN_API *mkldnn_version(); + +/** @} */ + +/** @addtogroup c_api_blas BLAS functions + * A subset of Basic Linear ALgebra (BLAS) functions to perform + * matrix-matrix multiplication. + * @{ */ + +/** SGEMM performs a matrix-matrix multiplication operation defined as + * + * C := alpha*op( A )*op( B ) + beta*C + * + * where + * - op( X ) is one of op( X ) = X or op( X ) = X**T, + * - alpha and beta are scalars, + * - A, B and C are matrices, with op( A ) an m by k matrix, op( B ) a k by n matrix + * and C an m by n matrix. + * + * The matrices are assumed to be stored in column-major order (the elements + * in a matrix columns are contiguous in memory). + * + * @note + * The API is different from the standard BLAS routine + * because it returns mkldnn_status_t for error handling. + * XERBLA is not supported: no error message will be printed + * in case of incorrect parameters. */ +mkldnn_status_t MKLDNN_API mkldnn_sgemm( + const char *transa, const char *transb, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, const float *A, const mkldnn_dim_t *lda, + const float *B, const mkldnn_dim_t *ldb, + const float *beta, float *C, const mkldnn_dim_t *ldc); + +/** gemm_s8u8s32 and gemm_s8s8s32 perform a matrix-matrix multiplication + * operation and add the result to a scalar-matrix product. For the final + * result, a vector is added to each row or column of the output matrix. + * The operation is defined as: + * + * C := alpha*(op(A) + A_offset) * (op(B) + B_offset) + beta*C + C_offset + * + * where + * - op( X ) = X or op( X ) = X**T, + * - A_offset is an m-by-k matrix with every element equal to the value oa, + * - B_offset is an k-by-n matrix with every element equal to the value ob, + * - C_offset is an m-by-n matrix defined by the oc array, size len: + * - if offsetc = F: len must be at least 1 + * - if offsetc = C: len must be at least max(1, m) + * - if offsetc = R: len must be at least max(1, n) + * - alpha and beta are scalars, and A, B and C are matrices, with op( A ) + * an m-by-k matrix, op( B ) a k-by-n matrix and C an m-by-n matrix. + * + * The matrices are assumed to be stored in column-major order (the elements + * in a matrix columns are contiguous in memory). + * + * @note + * The API is different compared with the standard BLAS routine + * because it returns mkldnn_status_t for error handling. + * XERBLA is not supported: no error message will be printed + * in case of incorrect parameters. */ +mkldnn_status_t MKLDNN_API mkldnn_gemm_s8u8s32( + const char *transa, const char *transb, const char *offsetc, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, + const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao, + const uint8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo, + const float *beta, + int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co); + +mkldnn_status_t MKLDNN_API mkldnn_gemm_s8s8s32( + const char *transa, const char *transb, const char *offsetc, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, + const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao, + const int8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo, + const float *beta, + int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co); +/** @} */ + +/** @} */ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp new file mode 100644 index 0000000000..581400a013 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp @@ -0,0 +1,2615 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_HPP +#define MKLDNN_HPP + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +#include +#include +#include +#include +#include +#include + +#include "mkldnn.h" +#endif + +namespace mkldnn { + +/// @addtogroup cpp_api C++ API +/// @{ + +/// @addtogroup cpp_api_utils Utils +/// @{ + +/// A class that provides the destructor for an Intel(R) MKL-DNN C handle +template class handle_traits {}; + +/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base +/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and +/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class +/// can be passed by value. This class enables wrapping: +/// - Newly constructed handles. +/// @n In this case, the constructed handle uses reference counting provided +/// by @p std::shared_ptr with a proper deleter function specified through +/// the @p handle_traits class. +/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for +/// example, through mkldnn_primitive_get_primitive_desc()). +/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a +/// deleter because it is assumed that the handle wrapper for the original +/// object deletes the handle (this model is similar to @p std::weak_ptr). +template > class handle { +private: + std::shared_ptr::type> _data; + handle(const handle &&) = delete; + handle &operator=(const handle &&other) = delete; +protected: + bool operator==(const T other) const { return other == _data.get(); } + bool operator!=(const T other) const { return !(*this == other); } +public: + /// Constructs a C handle wrapper. + /// @param t The C handle to wrap. + /// @param weak A flag to specify whether to construct a weak wrapper. + handle(T t = 0, bool weak = false): _data(0) { + reset(t, weak); + } + + handle(const handle &other): _data(other._data) {} + handle &operator=(const handle &other) { + _data = other._data; + return *this; + } + /// Resets the value of a C handle. + /// @param t The new value of the C handle. + /// @param weak A flag to specify whether the wrapper should be weak. + void reset(T t, bool weak = false) { + auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); }; + _data.reset(t, weak ? dummy_destructor : traits::destructor); + } + + /// Returns the value of the underlying C handle. + T get() const { return _data.get(); } + + bool operator==(const handle &other) const { return other._data.get() == _data.get(); } + bool operator!=(const handle &other) const { return !(*this == other); } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_memory_destroy; +}; + +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_desc_destroy; +}; + +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_destroy; +}; + +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy; +}; +#endif + +struct memory; +struct primitive_desc; + +/// Base class for all computational primitives. +class primitive: public handle { + friend struct error; + friend struct stream; + using handle::handle; +public: + /// A proxy to C primitive kind enum + enum class kind { + undefined_primitive = mkldnn_undefined_primitive, + reorder = mkldnn_reorder, + concat = mkldnn_concat, + sum = mkldnn_sum, + convolution = mkldnn_convolution, + deconvolution = mkldnn_deconvolution, + shuffle = mkldnn_shuffle, + eltwise = mkldnn_eltwise, + softmax = mkldnn_softmax, + pooling = mkldnn_pooling, + lrn = mkldnn_lrn, + batch_normalization = mkldnn_batch_normalization, + inner_product = mkldnn_inner_product, + rnn = mkldnn_rnn, + }; + + primitive(const_mkldnn_primitive_desc_t c_pd); + primitive(const primitive_desc &pd); + + /// Returns the descriptor of the underlying C API primitive. + inline const_mkldnn_primitive_desc_t get_primitive_desc() const; + // TODO: use the C++ API wrapper structure. + + void execute(struct stream &astream, + const std::unordered_map &args) const; +}; + +inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) { + return static_cast(akind); +} +/// Intel(R) MKL-DNN exception class. +/// +/// This class captures the status returned by the failed C API function, error +/// message, and, optionally, handle of the primitive that caused the error. +struct error: public std::exception { + mkldnn_status_t status; + const char *message; + + /// Constructs an error instance. + /// + /// @param astatus The error status returned by the C API. + /// @param amessage The error message. + error(mkldnn_status_t astatus, const char *amessage) + : status(astatus), message(amessage) {} + + /// A convenience function for wrapping calls to the C API. Checks the + /// return status and throws an #error in case of failure. + /// + /// @param status The error status returned by the C API. + /// @param message The error message. + static void wrap_c_api(mkldnn_status_t status, const char *message) { + if (status != mkldnn_success) + throw error(status, message); + } +}; + +const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const { + const_mkldnn_primitive_desc_t pd; + error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd), + "could not get primitive descriptor by primitive"); + return pd; +} +/// @} + +/// @addtogroup cpp_api_enums Common data types and enumerations +/// A proxy to @ref c_api_types in @ref c_api. +/// +/// @{ + +enum scratchpad_mode { + scratchpad_mode_library = mkldnn_scratchpad_mode_library, + scratchpad_mode_user = mkldnn_scratchpad_mode_user, +}; + +inline mkldnn_scratchpad_mode_t convert_to_c(scratchpad_mode mode) { + return static_cast(mode); +} + +enum padding_kind { + zero = mkldnn_padding_zero +}; + +inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) { + return static_cast(kind); +} + +enum prop_kind { + forward_training = mkldnn_forward_training, + forward_scoring = mkldnn_forward_scoring, + forward_inference = mkldnn_forward_inference, + forward = mkldnn_forward, + backward = mkldnn_backward, + backward_data = mkldnn_backward_data, + backward_weights = mkldnn_backward_weights, + backward_bias = mkldnn_backward_bias +}; + +inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) { + return static_cast(kind); +} + +enum algorithm { + algorithm_undef = mkldnn_alg_kind_undef, + convolution_auto = mkldnn_convolution_auto, + convolution_direct = mkldnn_convolution_direct, + convolution_winograd = mkldnn_convolution_winograd, + deconvolution_direct = mkldnn_deconvolution_direct, + deconvolution_winograd = mkldnn_deconvolution_winograd, + eltwise_relu = mkldnn_eltwise_relu, + eltwise_tanh = mkldnn_eltwise_tanh, + eltwise_elu = mkldnn_eltwise_elu, + eltwise_square = mkldnn_eltwise_square, + eltwise_abs = mkldnn_eltwise_abs, + eltwise_sqrt = mkldnn_eltwise_sqrt, + eltwise_linear = mkldnn_eltwise_linear, + eltwise_bounded_relu = mkldnn_eltwise_bounded_relu, + eltwise_soft_relu = mkldnn_eltwise_soft_relu, + eltwise_logistic = mkldnn_eltwise_logistic, + lrn_across_channels = mkldnn_lrn_across_channels, + lrn_within_channel = mkldnn_lrn_within_channel, + pooling_max = mkldnn_pooling_max, + pooling_avg = mkldnn_pooling_avg, + pooling_avg_include_padding = mkldnn_pooling_avg_include_padding, + pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding, + vanilla_rnn = mkldnn_vanilla_rnn, + vanilla_lstm = mkldnn_vanilla_lstm, + vanilla_gru = mkldnn_vanilla_gru, + gru_linear_before_reset = mkldnn_gru_linear_before_reset +}; + +inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) { + return static_cast(aalgorithm); +} + +enum batch_normalization_flag { + use_global_stats = mkldnn_use_global_stats, + use_scale_shift = mkldnn_use_scaleshift, + fuse_bn_relu = mkldnn_fuse_bn_relu +}; + +inline mkldnn_batch_normalization_flag_t convert_to_c( + batch_normalization_flag aflag) { + return static_cast(aflag); +} + +enum rnn_direction { + unidirectional_left2right = mkldnn_unidirectional_left2right, + unidirectional_right2left = mkldnn_unidirectional_right2left, + unidirectional = mkldnn_unidirectional, + bidirectional_concat = mkldnn_bidirectional_concat, + bidirectional_sum = mkldnn_bidirectional_sum, +}; + +inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) { + return static_cast(adir); +} + +enum query { + undef = mkldnn_query_undef, + + query_engine = mkldnn_query_engine, + primitive_kind = mkldnn_query_primitive_kind, + + num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32, + num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32, + + time_estimate_f64 = mkldnn_query_time_estimate_f64, + memory_consumption_s64 = mkldnn_query_memory_consumption_s64, + + query_scratchpad_engine = mkldnn_query_scratchpad_engine, + + impl_info_str = mkldnn_query_impl_info_str, + + op_d = mkldnn_query_op_d, + convolution_d = mkldnn_query_convolution_d, + deconvolution_d = mkldnn_query_deconvolution_d, + shuffle_d = mkldnn_query_shuffle_d, + eltwise_d = mkldnn_query_eltwise_d, + softmax_d = mkldnn_query_softmax_d, + pooling_d = mkldnn_query_pooling_d, + lrn_d = mkldnn_query_lrn_d, + batch_normalization_d = mkldnn_query_batch_normalization_d, + inner_product_d = mkldnn_query_inner_product_d, + rnn_d = mkldnn_query_rnn_d, + + src_md = mkldnn_query_src_md, + diff_src_md = mkldnn_query_diff_src_md, + weights_md = mkldnn_query_weights_md, + diff_weights_md = mkldnn_query_diff_weights_md, + dst_md = mkldnn_query_dst_md, + diff_dst_md = mkldnn_query_diff_dst_md, + workspace_md = mkldnn_query_workspace_md, + scratchpad_md = mkldnn_query_scratchpad_md, +}; + +inline mkldnn_query_t convert_to_c(query aquery) { + return static_cast(aquery); +} + +/// @} + +/// @addtogroup cpp_api_attr Attributes +/// An extension for controlling primitive behavior. +/// +/// @sa @ref c_api_attributes in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_post_ops_destroy; +}; +#endif + +struct post_ops: public handle { + post_ops() { + mkldnn_post_ops_t result; + error::wrap_c_api(mkldnn_post_ops_create(&result), + "could not create post operation sequence"); + reset(result); + } + + int len() const { return mkldnn_post_ops_len(get()); } + + primitive::kind kind(int index) const { + error::wrap_c_api( + index < len() ? mkldnn_success : mkldnn_invalid_arguments, + "post_ops index is out of range"); + return static_cast(mkldnn_post_ops_get_kind(get(), + index)); + } + + void append_sum(float scale = 1.) { + error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale), + "could not append sum"); + } + + void get_params_sum(int index, float &scale) const { + error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale), + "could not get sum params"); + } + + void append_eltwise(float scale, algorithm alg, float alpha, + float beta) { + error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale, + convert_to_c(alg), alpha, beta), + "could not append eltwise"); + } + + void get_params_eltwise(int index, float &scale, algorithm &alg, + float &alpha, float &beta) const { + mkldnn_alg_kind_t c_alg; + error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index, + &scale, &c_alg, &alpha, &beta), + "could not get eltwise params"); + alg = static_cast(c_alg); + } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_attr_destroy; +}; +#endif + +struct primitive_attr: public handle { + primitive_attr() { + mkldnn_primitive_attr_t result; + error::wrap_c_api(mkldnn_primitive_attr_create(&result), + "could not create a primitive attr"); + reset(result); + } + + scratchpad_mode get_scratchpad_mode() const { + mkldnn_scratchpad_mode_t result; + error::wrap_c_api(mkldnn_primitive_attr_get_scratchpad_mode( + get(), &result), "could not get scratchpad mode"); + return scratchpad_mode(result); + } + + void set_scratchpad_mode(scratchpad_mode mode) { + error::wrap_c_api(mkldnn_primitive_attr_set_scratchpad_mode( + get(), mkldnn::convert_to_c(mode)), + "could not set scratchpad mode"); + } + + void get_output_scales(int &mask, std::vector &scales) const + { + mkldnn_dim_t count; + int c_mask; + const float *c_scales; + error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(), + &count, &c_mask, &c_scales), + "could not get int output scales"); + scales.resize(count); + + mask = c_mask; + for (mkldnn_dim_t c = 0; c < count; ++c) + scales[c] = c_scales[c]; + } + + void set_output_scales(int mask, const std::vector &scales) + { + error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(), + (mkldnn_dim_t)scales.size(), mask, &scales[0]), + "could not set int output scales"); + } + + const post_ops get_post_ops() const { + post_ops result; + const_mkldnn_post_ops_t c_result; + error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result), + "could not get post operation sequence"); + result.reset(const_cast(c_result), true); + return result; + } + + void set_post_ops(post_ops ops) { + error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()), + "could not set post operation sequence"); + } + + void set_rnn_data_qparams(const float scale, const float shift) + { + error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(), + scale, shift), "could not set rnn data int scale/shift"); + } + + void set_rnn_weights_qparams(int mask, const std::vector &scales) + { + error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(), + (int)scales.size(), mask, &scales[0]), + "could not set rnn weights int scales"); + } +}; + +/// @} + +/// @addtogroup cpp_api_engine Engine +/// Engine operations. +/// +/// @sa @ref c_api_engine in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_engine_destroy; +}; +#endif + +/// An execution engine. +struct engine: public handle { + friend class primitive; + // gcc bug??? using handle::handle; + + /// Kinds of engines. + enum kind { + /// An unspecified engine + any = mkldnn_any_engine, + /// CPU engine + cpu = mkldnn_cpu, + }; + + /// Returns the number of engines of a certain kind. + /// + /// @param akind The kind of engines to count. + + static size_t get_count(kind akind) { + return mkldnn_engine_get_count(convert_to_c(akind)); + } + + /// Constructs an engine. + /// + /// @param akind The kind of engine to construct. + /// @param index The index of the engine. Must be less than the value + /// returned by #get_count() for this particular kind of engine. + + engine(kind akind, size_t index) { + mkldnn_engine_t aengine; + error::wrap_c_api( + mkldnn_engine_create(&aengine, + convert_to_c(akind), index), + "could not create an engine"); + reset(aengine); + } + + explicit engine(const mkldnn_engine_t& aengine) + : handle(aengine, true) {} + + engine(const handle &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(pd.get(), + mkldnn::convert_to_c(query_engine), 0, &engine_q), + "could not get engine from primitive_desc"); + reset(engine_q, true); + } + + template + static engine query(const primitive_desc &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(pd.get(), + mkldnn::convert_to_c(query_engine), 0, &engine_q), + "could not get engine from primitive_desc"); + + return engine(engine_q); + } + +private: + static mkldnn_engine_kind_t convert_to_c(kind akind) { + return static_cast(akind); + } +}; + +/// @} + +/// @addtogroup cpp_api_stream Stream +/// Execution stream operations +/// +/// @sa @ref c_api_stream in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_stream_destroy; +}; +#endif + +struct stream: public handle { + using handle::handle; + + enum: unsigned { + default_flags = mkldnn_stream_default_flags, + }; + + /// Constructs a stream. + stream(const engine &aengine, + unsigned flags = static_cast(default_flags)) { + mkldnn_stream_t astream; + error::wrap_c_api(mkldnn_stream_create(&astream, aengine.get(), flags), + "could not create a stream"); + reset(astream); + } +}; + +/// @} + +/// @addtogroup cpp_api_memory_related Memory and memory related operations +/// @{ + +/// @addtogroup cpp_api_memory Memory +/// A primitive to describe and store data. +/// +/// For more information, refer to @ref c_api_memory in @ref c_api. +/// @{ + +/// Memory that describes the data. +struct memory: public handle { + public: + typedef mkldnn_dim_t dim; + typedef std::vector dims; + + template static void validate_dims(const std::vector &v) { + if (v.size() > MKLDNN_MAX_NDIMS) + throw error(mkldnn_invalid_arguments, "invalid dimensions"); + } + + /// Data type specification. See #mkldnn_data_type_t for a detailed + /// description. + enum data_type { + data_undef = mkldnn_data_type_undef, + f32 = mkldnn_f32, + s32 = mkldnn_s32, + s8 = mkldnn_s8, + u8 = mkldnn_u8, + }; + + /// Memory format tag specification. See #mkldnn_format_tag_t + /// for a detailed description. + enum format_tag { + format_tag_undef = mkldnn_format_tag_undef, + any = mkldnn_format_tag_any, + a = mkldnn_a, + ab = mkldnn_ab, + abc = mkldnn_abc, + abcd = mkldnn_abcd, + abcde = mkldnn_abcde, + abcdef = mkldnn_abcdef, + abdec = mkldnn_abdec, + acb = mkldnn_acb, + acbde = mkldnn_acbde, + acdb = mkldnn_acdb, + acdeb = mkldnn_acdeb, + ba = mkldnn_ba, + bac = mkldnn_bac, + bacd = mkldnn_bacd, + bcda = mkldnn_bcda, + cba = mkldnn_cba, + cdba = mkldnn_cdba, + cdeba = mkldnn_cdeba, + decab = mkldnn_decab, + Abc16a = mkldnn_Abc16a, + ABc16a16b = mkldnn_ABc16a16b, + aBc16b = mkldnn_aBc16b, + ABc16b16a = mkldnn_ABc16b16a, + Abc4a = mkldnn_Abc4a, + aBc4b = mkldnn_aBc4b, + ABc4b16a4b = mkldnn_ABc4b16a4b, + ABc4b4a = mkldnn_ABc4b4a, + ABc8a16b2a = mkldnn_ABc8a16b2a, + ABc8a8b = mkldnn_ABc8a8b, + aBc8b = mkldnn_aBc8b, + ABc8b16a2b = mkldnn_ABc8b16a2b, + ABc8b8a = mkldnn_ABc8b8a, + Abcd16a = mkldnn_Abcd16a, + ABcd16a16b = mkldnn_ABcd16a16b, + aBcd16b = mkldnn_aBcd16b, + ABcd16b16a = mkldnn_ABcd16b16a, + aBCd16b16c = mkldnn_aBCd16b16c, + aBCd16c16b = mkldnn_aBCd16c16b, + Abcd4a = mkldnn_Abcd4a, + aBcd4b = mkldnn_aBcd4b, + ABcd4b16a4b = mkldnn_ABcd4b16a4b, + ABcd4b4a = mkldnn_ABcd4b4a, + aBCd4c16b4c = mkldnn_aBCd4c16b4c, + aBCd4c4b = mkldnn_aBCd4c4b, + ABcd8a16b2a = mkldnn_ABcd8a16b2a, + ABcd8a8b = mkldnn_ABcd8a8b, + aBcd8b = mkldnn_aBcd8b, + ABcd8b16a2b = mkldnn_ABcd8b16a2b, + aBCd8b16c2b = mkldnn_aBCd8b16c2b, + ABcd8b8a = mkldnn_ABcd8b8a, + aBCd8b8c = mkldnn_aBCd8b8c, + aBCd8c16b2c = mkldnn_aBCd8c16b2c, + aBCd8c8b = mkldnn_aBCd8c8b, + Abcde16a = mkldnn_Abcde16a, + ABcde16a16b = mkldnn_ABcde16a16b, + aBcde16b = mkldnn_aBcde16b, + ABcde16b16a = mkldnn_ABcde16b16a, + aBCde16b16c = mkldnn_aBCde16b16c, + aBCde16c16b = mkldnn_aBCde16c16b, + aBCde2c8b4c = mkldnn_aBCde2c8b4c, + Abcde4a = mkldnn_Abcde4a, + aBcde4b = mkldnn_aBcde4b, + ABcde4b4a = mkldnn_ABcde4b4a, + aBCde4b4c = mkldnn_aBCde4b4c, + aBCde4c16b4c = mkldnn_aBCde4c16b4c, + aBCde4c4b = mkldnn_aBCde4c4b, + Abcde8a = mkldnn_Abcde8a, + ABcde8a8b = mkldnn_ABcde8a8b, + aBcde8b = mkldnn_aBcde8b, + ABcde8b16a2b = mkldnn_ABcde8b16a2b, + aBCde8b16c2b = mkldnn_aBCde8b16c2b, + ABcde8b8a = mkldnn_ABcde8b8a, + aBCde8b8c = mkldnn_aBCde8b8c, + aBCde8c16b2c = mkldnn_aBCde8c16b2c, + aBCde8c8b = mkldnn_aBCde8c8b, + aBcdef16b = mkldnn_aBcdef16b, + aBCdef16b16c = mkldnn_aBCdef16b16c, + aBCdef16c16b = mkldnn_aBCdef16c16b, + aBcdef4b = mkldnn_aBcdef4b, + aBCdef4c4b = mkldnn_aBCdef4c4b, + aBCdef8b8c = mkldnn_aBCdef8b8c, + aBCdef8c16b2c = mkldnn_aBCdef8c16b2c, + aBCdef8c8b = mkldnn_aBCdef8c8b, + aBdc16b = mkldnn_aBdc16b, + aBdc4b = mkldnn_aBdc4b, + aBdc8b = mkldnn_aBdc8b, + aBdec16b = mkldnn_aBdec16b, + aBdec4b = mkldnn_aBdec4b, + aBdec8b = mkldnn_aBdec8b, + aBdefc16b = mkldnn_aBdefc16b, + aBdefc4b = mkldnn_aBdefc4b, + aBdefc8b = mkldnn_aBdefc8b, + Acb16a = mkldnn_Acb16a, + Acb4a = mkldnn_Acb4a, + Acb8a = mkldnn_Acb8a, + aCBd16b16c = mkldnn_aCBd16b16c, + aCBde16b16c = mkldnn_aCBde16b16c, + Acdb16a = mkldnn_Acdb16a, + Acdb4a = mkldnn_Acdb4a, + Acdb8a = mkldnn_Acdb8a, + Acdeb16a = mkldnn_Acdeb16a, + Acdeb4a = mkldnn_Acdeb4a, + Acdeb8a = mkldnn_Acdeb8a, + BAc16a16b = mkldnn_BAc16a16b, + BAcd16a16b = mkldnn_BAcd16a16b, + format_tag_last = mkldnn_format_tag_last, + + x = mkldnn_x, + nc = mkldnn_nc, + cn = mkldnn_cn, + ncw = mkldnn_ncw, + nwc = mkldnn_nwc, + nchw = mkldnn_nchw, + nhwc = mkldnn_nhwc, + chwn = mkldnn_chwn, + ncdhw = mkldnn_ncdhw, + ndhwc = mkldnn_ndhwc, + oi = mkldnn_oi, + io = mkldnn_io, + oiw = mkldnn_oiw, + wio = mkldnn_wio, + oihw = mkldnn_oihw, + hwio = mkldnn_hwio, + ihwo = mkldnn_ihwo, + iohw = mkldnn_iohw, + oidhw = mkldnn_oidhw, + dhwio = mkldnn_dhwio, + goiw = mkldnn_goiw, + goihw = mkldnn_goihw, + hwigo = mkldnn_hwigo, + giohw = mkldnn_giohw, + goidhw = mkldnn_goidhw, + tnc = mkldnn_tnc, + ntc = mkldnn_ntc, + ldsnc = mkldnn_ldsnc, + ldigo = mkldnn_ldigo, + ldgoi = mkldnn_ldgoi, + ldgo = mkldnn_ldgo, + nCdhw16c = mkldnn_nCdhw16c, + nCdhw4c = mkldnn_nCdhw4c, + nCdhw8c = mkldnn_nCdhw8c, + nChw16c = mkldnn_nChw16c, + nChw4c = mkldnn_nChw4c, + nChw8c = mkldnn_nChw8c, + nCw16c = mkldnn_nCw16c, + nCw4c = mkldnn_nCw4c, + nCw8c = mkldnn_nCw8c, + IOw16o16i = mkldnn_IOw16o16i, + OIw16i16o = mkldnn_OIw16i16o, + OIw16o16i = mkldnn_OIw16o16i, + Oiw16o = mkldnn_Oiw16o, + OIw4i16o4i = mkldnn_OIw4i16o4i, + OIw4i4o = mkldnn_OIw4i4o, + Oiw4o = mkldnn_Oiw4o, + OIw8i16o2i = mkldnn_OIw8i16o2i, + OIw8i8o = mkldnn_OIw8i8o, + OIw8o16i2o = mkldnn_OIw8o16i2o, + OIw8o8i = mkldnn_OIw8o8i, + Owi16o = mkldnn_Owi16o, + Owi4o = mkldnn_Owi4o, + Owi8o = mkldnn_Owi8o, + IOhw16o16i = mkldnn_IOhw16o16i, + Ohwi16o = mkldnn_Ohwi16o, + Ohwi4o = mkldnn_Ohwi4o, + Ohwi8o = mkldnn_Ohwi8o, + OIhw16i16o = mkldnn_OIhw16i16o, + OIhw16o16i = mkldnn_OIhw16o16i, + Oihw16o = mkldnn_Oihw16o, + OIhw4i16o4i = mkldnn_OIhw4i16o4i, + OIhw4i4o = mkldnn_OIhw4i4o, + Oihw4o = mkldnn_Oihw4o, + OIhw8i16o2i = mkldnn_OIhw8i16o2i, + OIhw8i8o = mkldnn_OIhw8i8o, + OIhw8o16i2o = mkldnn_OIhw8o16i2o, + OIhw8o8i = mkldnn_OIhw8o8i, + Odhwi16o = mkldnn_Odhwi16o, + Odhwi4o = mkldnn_Odhwi4o, + Odhwi8o = mkldnn_Odhwi8o, + OIdhw16i16o = mkldnn_OIdhw16i16o, + OIdhw16o16i = mkldnn_OIdhw16o16i, + Oidhw16o = mkldnn_Oidhw16o, + OIdhw4i4o = mkldnn_OIdhw4i4o, + Oidhw4o = mkldnn_Oidhw4o, + OIdhw8i16o2i = mkldnn_OIdhw8i16o2i, + OIdhw8i8o = mkldnn_OIdhw8i8o, + OIdhw8o8i = mkldnn_OIdhw8o8i, + gIOw16o16i = mkldnn_gIOw16o16i, + gOIw16i16o = mkldnn_gOIw16i16o, + gOIw16o16i = mkldnn_gOIw16o16i, + gOiw16o = mkldnn_gOiw16o, + gOIw4i16o4i = mkldnn_gOIw4i16o4i, + gOIw4i4o = mkldnn_gOIw4i4o, + gOiw4o = mkldnn_gOiw4o, + gOIw8i16o2i = mkldnn_gOIw8i16o2i, + gOIw8i8o = mkldnn_gOIw8i8o, + gOIw8o16i2o = mkldnn_gOIw8o16i2o, + gOIw8o8i = mkldnn_gOIw8o8i, + gOwi16o = mkldnn_gOwi16o, + gOwi4o = mkldnn_gOwi4o, + gOwi8o = mkldnn_gOwi8o, + gIOhw16o16i = mkldnn_gIOhw16o16i, + gOhwi16o = mkldnn_gOhwi16o, + gOhwi4o = mkldnn_gOhwi4o, + gOhwi8o = mkldnn_gOhwi8o, + Goihw16g = mkldnn_Goihw16g, + gOIhw16i16o = mkldnn_gOIhw16i16o, + gOIhw16o16i = mkldnn_gOIhw16o16i, + gOihw16o = mkldnn_gOihw16o, + gOIhw2i8o4i = mkldnn_gOIhw2i8o4i, + gOIhw4i16o4i = mkldnn_gOIhw4i16o4i, + gOIhw4i4o = mkldnn_gOIhw4i4o, + gOIhw4o4i = mkldnn_gOIhw4o4i, + gOihw4o = mkldnn_gOihw4o, + Goihw8g = mkldnn_Goihw8g, + gOIhw8i16o2i = mkldnn_gOIhw8i16o2i, + gOIhw8i8o = mkldnn_gOIhw8i8o, + gOIhw8o16i2o = mkldnn_gOIhw8o16i2o, + gOIhw8o8i = mkldnn_gOIhw8o8i, + gOdhwi16o = mkldnn_gOdhwi16o, + gOdhwi4o = mkldnn_gOdhwi4o, + gOdhwi8o = mkldnn_gOdhwi8o, + gOIdhw16i16o = mkldnn_gOIdhw16i16o, + gOIdhw16o16i = mkldnn_gOIdhw16o16i, + gOidhw16o = mkldnn_gOidhw16o, + gOIdhw4i4o = mkldnn_gOIdhw4i4o, + gOidhw4o = mkldnn_gOidhw4o, + gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i, + gOIdhw8i8o = mkldnn_gOIdhw8i8o, + gOIdhw8o8i = mkldnn_gOIdhw8o8i, + }; + + /// A memory descriptor. + struct desc { + friend struct memory; + /// The underlying C API data structure. + mkldnn_memory_desc_t data; + + /// Constructs a zero memory descriptor + desc(): data() {} + + /// Constructs a memory descriptor. + /// + /// @param adims Data dimensions + /// @param adata_type Data precision/type. + /// @param aformat Data layout format tag. + desc(const dims &adims, data_type adata_type, + format_tag aformat) { + validate_dims(adims); + error::wrap_c_api(mkldnn_memory_desc_init_by_tag(&data, (int)adims.size(), + adims.size() == 0 ? nullptr : &adims[0], + convert_to_c(adata_type), convert_to_c(aformat)), + "could not initialize a memory descriptor"); + } + + /// Constructs a memory descriptor from a C API data structure. + /// + /// @param adata A C API #mkldnn_memory_desc_t structure. + desc(const mkldnn_memory_desc_t &adata): data(adata) {} + + /// Constructs a sub-memory descriptor + // + /// @param adims Sizes of a sub-memory + /// @param offsets Offsets of a sub-memory + desc submemory_desc(const dims &adims, const dims &offsets) { + mkldnn_memory_desc_t sub_md; + error::wrap_c_api(mkldnn_memory_desc_init_submemory(&sub_md, + &data, &adims[0], &offsets[0]), + "could not initialize a sub-memory"); + return desc(sub_md); + } + + /// Returns the number of bytes required to allocate the memory described + /// including the padding area. + size_t get_size() const { return mkldnn_memory_desc_get_size(&data); } + + bool operator==(const desc &other) const { + return mkldnn_memory_desc_equal(&data, &other.data) != 0; + } + + bool operator!=(const desc &other) const { return !operator==(other); } + }; + + /// Constructs a memory. + /// + /// @param md Memory descriptor. + /// @param aengine Engine. + /// @param ahandle Native handle. + memory(const desc &md, const engine &aengine, void *ahandle) { + mkldnn_memory_t result; + error::wrap_c_api(mkldnn_memory_create(&result, &md.data, + aengine.get(), ahandle), "could not create a memory"); + reset(result); + } + + /// Constructs a memory. + /// + /// @param md Memory descriptor. + /// @param aengine Engine. + memory(const desc &md, const engine &aengine) + : memory(md, aengine, MKLDNN_NATIVE_HANDLE_ALLOCATE) {} + + /// Returns the descriptor of the memory. + desc get_desc() const { + const mkldnn_memory_desc_t *cdesc; + error::wrap_c_api(mkldnn_memory_get_memory_desc(get(), &cdesc), + "could not get memory descriptor from a memory"); + return desc(*cdesc); + } + + /// Returns the engine of the memory. + engine get_engine() const { + mkldnn_engine_t engine_q; + error::wrap_c_api(mkldnn_memory_get_engine(get(), &engine_q), + "could not get engine from a memory"); + return engine(engine_q); + } + + /// Returns a handle of the data contained in the memory. + /// + /// On the CPU engine, this is a pointer to the allocated memory. + void *get_data_handle() const { + void *handle; + error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle), + "could not get native handle"); + return handle; + } + + void set_data_handle(void *handle) const { + error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle), + "could not set native handle"); + } + + // Must go away or be private: + static mkldnn_data_type_t convert_to_c(data_type adata_type) { + return static_cast(adata_type); + } + static mkldnn_format_tag_t convert_to_c(format_tag aformat) { + return static_cast(aformat); + } +}; + +inline bool operator==(mkldnn_data_type_t a, memory::data_type b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) { + return !(a == b); +} +inline bool operator==(memory::data_type a, mkldnn_data_type_t b) { + return b == a; +} +inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) { + return !(a == b); +} + +inline bool operator==(mkldnn_format_tag_t a, memory::format_tag b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(mkldnn_format_tag_t a, memory::format_tag b) { + return !(a == b); +} +inline bool operator==(memory::format_tag a, mkldnn_format_tag_t b) { + return b == a; +} +inline bool operator!=(memory::format_tag a, mkldnn_format_tag_t b) { + return !(a == b); +} + +/// @} + +/// @addtogroup cpp_api_reorder Reorder +/// A primitive to copy data between memory formats. +/// +/// @sa @ref c_api_reorder in @ref c_api +/// @{ + +struct reorder : public primitive { + struct primitive_desc : public handle { + primitive_desc(const engine &src_engine, const memory::desc &src_md, + const engine &dst_engine, const memory::desc &dst_md, + const primitive_attr &aattr) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src_engine.get(), &src_md.data, + dst_engine.get(), &dst_md.data, aattr.get()), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const engine &src_engine, const memory::desc &src_md, + const engine &dst_engine, const memory::desc &dst_md) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src_engine.get(), &src_md.data, + dst_engine.get(), &dst_md.data, nullptr), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const memory &src, const memory &dst, + const primitive_attr &aattr) { + mkldnn_primitive_desc_t result; + auto src_md = src.get_desc(); + auto dst_md = dst.get_desc(); + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src.get_engine().get(), &src_md.data, + dst.get_engine().get(), &dst_md.data, aattr.get()), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const memory &src, const memory &dst) { + mkldnn_primitive_desc_t result; + auto src_md = src.get_desc(); + auto dst_md = dst.get_desc(); + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src.get_engine().get(), &src_md.data, + dst.get_engine().get(), &dst_md.data, nullptr), + "could not create a reorder primitive descriptor"); + reset(result); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine scratchpad_engine() { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(get(), + mkldnn::convert_to_c(query_scratchpad_engine), 0, &engine_q), + "could not get scratchpad engine from reorder primitive_desc"); + + return engine(engine_q); + } + + engine get_engine() { return engine::query(*this); } + }; + + reorder(const primitive_desc &pd): primitive(pd.get()) {} + + reorder(const memory &src, const memory &dst): + primitive(primitive_desc(src, dst).get()) {} + + void execute(stream astream, memory &src, memory &dst) { + primitive::execute(astream, + {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}}); + } +}; + +/// @} + +/// @addtogroup cpp_api_concat Concat +/// A primitive to concatenate data by arbitrary dimension. +/// +/// @sa @ref c_api_concat in @ref c_api +/// @{ + +struct concat : public primitive { + struct primitive_desc : public handle { + std::vector cpp_to_c( + const std::vector &srcs) { + std::vector c_api_srcs; + c_api_srcs.reserve(srcs.size()); + for (const auto &s : srcs) c_api_srcs.push_back(s.data); + return c_api_srcs; + } + + primitive_desc(const memory::desc &dst, int concat_dimension, + const std::vector &srcs, const engine &aengine) { + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_concat_primitive_desc_create( + &result, &dst.data, (int)c_api_srcs.size(), + concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), + "could not create a concat primitive descriptor"); + reset(result); + } + + primitive_desc(int concat_dimension, + const std::vector &srcs, const engine &aengine) { + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_concat_primitive_desc_create( + &result, nullptr, (int)c_api_srcs.size(), + concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), + "could not create a concat primitive descriptor"); + reset(result); + } + + memory::desc dst_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(dst_md), 0); + error::wrap_c_api( + cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, + "could not get a dst memory descriptor"); + return memory::desc(*cdesc); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine get_engine() { return engine::query(*this); } + }; + + concat(const primitive_desc &pd): primitive(pd.get()) {} +}; + +/// @} + +/// @addtogroup cpp_api_sum Sum +/// A primitive to sum data. +/// +/// @sa @ref c_api_sum in @ref c_api +/// @{ + +struct sum : public primitive { + struct primitive_desc : public handle { + std::vector cpp_to_c( + const std::vector &srcs) { + std::vector c_api_srcs; + c_api_srcs.reserve(srcs.size()); + for (const auto &s : srcs) c_api_srcs.push_back(s.data); + return c_api_srcs; + } + + primitive_desc(const memory::desc &dst, + const std::vector &scales, + const std::vector &srcs, const engine &aengine) { + error::wrap_c_api(scales.size() == srcs.size() + ? mkldnn_success : mkldnn_invalid_arguments, + "number of scales not equal to number of srcs"); + + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_sum_primitive_desc_create( + &result, &dst.data, (int)c_api_srcs.size(), + &scales[0], &c_api_srcs[0], nullptr, aengine.get()), + "could not create a sum primitive descriptor"); + reset(result); + } + + primitive_desc(const std::vector &scales, + const std::vector &srcs, const engine &aengine) { + error::wrap_c_api(scales.size() == srcs.size() + ? mkldnn_success : mkldnn_invalid_arguments, + "number of scales not equal to number of srcs"); + + auto c_api_srcs = cpp_to_c(srcs); + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_sum_primitive_desc_create(&result, + nullptr, (int)c_api_srcs.size(), &scales[0], + &c_api_srcs[0], nullptr, aengine.get()), + "could not create a sum primitive descriptor"); + reset(result); + } + + memory::desc dst_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(dst_md), 0); + error::wrap_c_api( + cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, + "could not get a dst memory descriptor"); + return memory::desc(*cdesc); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine get_engine() { return engine::query(*this); } + }; + + sum(const primitive_desc &pd): primitive(pd.get()) {} +}; + +/// @} + +/// @} + +/// @addtogroup cpp_api_primitives Primitives +/// @{ + +/// @addtogroup cpp_api_primitive_descriptors Primitive descriptors +/// @{ + +/// A base class for all primitive descriptors. +struct primitive_desc : public handle { + primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, + const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) { + mkldnn_primitive_desc_iterator_t iterator = nullptr; + mkldnn_status_t status = mkldnn_primitive_desc_iterator_create( + &iterator, desc, attr ? attr->get() : nullptr, e.get(), + hint_fwd_pd); + error::wrap_c_api(status, + "could not create a primitive descriptor iterator"); + pd_iterator.reset(iterator); + fetch_impl(); + } + + engine get_engine() { return engine::query(*this); } + + primitive_attr get_primitive_attr() const { + const_mkldnn_primitive_attr_t const_cattr; + error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr), + "could not get attributes"); + mkldnn_primitive_attr_t cattr; + error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr), + "could not clone attributes"); + + primitive_attr attr; + attr.reset(cattr); + return attr; + } + + /// Returns implementation name + const char *impl_info_str() const { + const char *res; + error::wrap_c_api(mkldnn_primitive_desc_query(get(), + mkldnn_query_impl_info_str, 0, &res), + "could not query implementation info string"); + return res; + } + + /// Queries the memory::dim value (same as int64_t) + memory::dim query_s64(query q) const { + memory::dim res; + mkldnn_status_t status = mkldnn_primitive_desc_query(get(), + mkldnn::convert_to_c(q), 0, &res); + return status == mkldnn_success ? res : 0; + } + + /// Advances the next implementation for the given op descriptor. + /// + /// Returns: + /// - @c true on success + /// - @c false if the last implementation reached, and + /// the primitive descriptor itself is kept unchanged + bool next_impl() { + mkldnn_status_t status = mkldnn_primitive_desc_iterator_next( + pd_iterator.get()); + if (status == mkldnn_iterator_ends) return false; + error::wrap_c_api(status, "primitive descriptor iterator next failed"); + + fetch_impl(); + return true; + } + + /// Queries and returns requested memory descriptor. + memory::desc query_md(query what, int idx = 0) const { + std::vector valid_q{src_md, diff_src_md, weights_md, + diff_weights_md, dst_md, diff_dst_md, workspace_md, scratchpad_md}; + if (!std::any_of(valid_q.cbegin(), valid_q.cend(), + [=](query q) { return what == q; })) + throw error(mkldnn_invalid_arguments, "invalid memory query"); + + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(what), idx); + if (cdesc == nullptr) return memory::desc(); + + return memory::desc(*cdesc); + } + + // register specialized queries, e.g. src_desc() +# define REG_QUERY_MD(name, what, idx) \ + memory::desc name ## _desc() const { return query_md(what ## _md, idx); } + + private: + handle pd_iterator; + void fetch_impl() { + mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch( + pd_iterator.get()); + error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error, + "could not fetch a primitive descriptor from the iterator"); + reset(pd); + } +}; + +/// @} + +/// @addtogroup cpp_api_convolution Convolution +/// A primitive to compute convolution using different algorithms. +/// +/// @sa @ref c_api_convolution in @ref c_api +/// @{ + +struct convolution_forward: public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &dilates[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &dilates[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct convolution_backward_data : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct convolution_backward_weights : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} +// +/// @addtogroup cpp_api_deconvolution Deconvolution +/// A primitive to compute deconvolution using different algorithms. +/// +/// @sa @ref c_api_deconvolution in @ref c_api +/// @{ + +struct deconvolution_forward: public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], + &padding_r[0], mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], + &padding_r[0], mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct deconvolution_backward_data : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward data descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct deconvolution_backward_weights : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward weights descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_lrn LRN +/// A primitive to perform local response normalization (LRN) across or within +/// channels. +/// +/// @sa @ref c_api_lrn in @ref c_api +/// @{ + +struct lrn_forward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, memory::dim local_size, + float alpha, float beta, float k = 1.f) { + error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, local_size, alpha, beta, k), + "could not create a lrn forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + lrn_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct lrn_backward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + + desc(algorithm aalgorithm, const memory::desc &data_desc, + const memory::desc &diff_data_desc, memory::dim local_size, + float alpha, float beta, float k = 1.f) { + error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data, + convert_to_c(aalgorithm), &diff_data_desc.data, + &data_desc.data, local_size, alpha, beta, k), + "could not create a lrn backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const lrn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const lrn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + lrn_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_pooling Pooling +/// A primitive to perform max or average pooling. +/// +/// @sa @ref c_api_pooling in @ref c_api +/// @{ + +struct pooling_forward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims kernel, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, &dst_desc.data, + &strides[0], &kernel[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not init a forward pooling descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + pooling_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct pooling_backward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const memory::dims &strides, + const memory::dims &kernel, + const memory::dims &padding_l, + const memory::dims &padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data, + convert_to_c(aalgorithm), + &diff_src_desc.data, &diff_dst_desc.data, + &strides[0], &kernel[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not init a backward pooling descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const pooling_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const pooling_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + pooling_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_eltwise Eltwise +/// A primitive to compute element-wise operations like parametric rectifier +/// linear unit (ReLU). +/// +/// @sa @ref c_api_eltwise in @ref c_api +/// @{ + +struct eltwise_forward : public primitive { + struct desc { + mkldnn_eltwise_desc_t data; + template + desc(prop_kind aprop_kind, algorithm alg_kind, + const memory::desc &src_desc, T alpha = 0, T beta = 0) { + error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + mkldnn::convert_to_c(alg_kind), &src_desc.data, + static_cast(alpha), static_cast(beta)), + "could not create a eltwise forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + eltwise_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct eltwise_backward : public primitive { + struct desc { + mkldnn_eltwise_desc_t data; + + template + desc(algorithm alg_kind, const memory::desc &diff_data_desc, + const memory::desc &data_desc, T alpha = 0, T beta = 0) { + error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data, + mkldnn::convert_to_c(alg_kind), &diff_data_desc.data, + &data_desc.data, static_cast(alpha), + static_cast(beta)), + "could not create a eltwise backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const eltwise_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const eltwise_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + eltwise_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_softmax Softmax +/// A primitive to perform softmax. +/// +/// @sa @ref c_api_softmax in @ref c_api +/// @{ + +struct softmax_forward : public primitive { + struct desc { + mkldnn_softmax_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &data_desc, + int softmax_axis) { + error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &data_desc.data, + softmax_axis), + "could not create a softmax forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + softmax_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct softmax_backward : public primitive { + struct desc { + mkldnn_softmax_desc_t data; + desc(const memory::desc &diff_desc, const memory::desc &data_desc, + int softmax_axis) { + error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data, + &diff_desc.data, &data_desc.data, softmax_axis), + "could not init a backward softmax descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const softmax_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const softmax_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + softmax_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_batch_norm Batch normalization +/// A primitive to perform batch normalization. +/// +/// @sa @ref c_api_batch_normalization in @ref c_api +/// @{ + +struct batch_normalization_forward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template + desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, + unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + static_cast(epsilon), flags), + "could not create a batch normalization forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + + memory::desc mean_desc() const { return stat_desc(mean); } + memory::desc variance_desc() const { return stat_desc(var); } + + private: + enum { mean = 1, var = 2, }; + memory::desc stat_desc(int kind) const { + mkldnn_batch_normalization_desc_t *p; + error::wrap_c_api(mkldnn_primitive_desc_query( + get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p), + "could not get a batch-normalization descriptor"); + return query_md(p->flags & use_global_stats ? src_md : dst_md, kind); + } + }; + + batch_normalization_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct batch_normalization_backward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template + desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, + const memory::desc &data_desc, T epsilon, unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_backward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + &diff_data_desc.data, &data_desc.data, + static_cast(epsilon), flags), + "could not create a batch normalization backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const batch_normalization_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const batch_normalization_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(mean, src, 1); + REG_QUERY_MD(variance, src, 2); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + batch_normalization_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_inner_product Inner Product +/// A primitive to compute an inner product. +/// +/// @sa @ref c_api_inner_product in @ref c_api +/// @{ + +struct inner_product_forward: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + &weights_desc.data, &bias_desc.data, &dst_desc.data), + "could not create a inner product forward descriptor"); + } + + desc(prop_kind aprop_kind, const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + &weights_desc.data, nullptr, &dst_desc.data), + "could not create a inner product forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct inner_product_backward_data: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_data_desc_init(&data, + &diff_src_desc.data, &weights_desc.data, + &diff_dst_desc.data), + "could not create a inner product backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct inner_product_backward_weights: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, &src_desc.data, &diff_weights_desc.data, + &diff_bias_desc.data, &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, &src_desc.data, &diff_weights_desc.data, + nullptr, &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_rnn RNN +/// A primitive to compute common recurrent layer. +/// +/// @sa @ref c_api_rnn in @ref c_api +/// @{ + +struct rnn_cell { + struct desc { + mkldnn_rnn_cell_desc_t c_rnn_cell_; + + desc(algorithm kind, algorithm activation_f) { + error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_, + mkldnn::convert_to_c(kind), + mkldnn::convert_to_c(activation_f), 0U, 0, 0), + "could not init an rnn cell descriptor"); + } + desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {} + + operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; } + + algorithm get_cell_kind() const + { return algorithm(c_rnn_cell_.cell_kind); } + algorithm get_activation() const + { return algorithm(c_rnn_cell_.activation_kind); } + + float get_alpha() const { return c_rnn_cell_.alpha; } + void set_alpha(float alpha) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu; + c_rnn_cell_.alpha = alpha; + } + + float get_clipping() const { return c_rnn_cell_.clipping; } + void set_clipping(float clipping) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping; + c_rnn_cell_.clipping = clipping; + } + + int get_gates_count() const { + return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_); + } + int get_state_count() const { + return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_); + } + }; +}; + +struct rnn_forward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc + ) { + error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, &src_iter_desc.data, + &weights_layer_desc.data, &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, &dst_iter_desc.data), + "could not create an RNN forward descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src_layer, src, 0); + REG_QUERY_MD(src_iter, src, 1); + REG_QUERY_MD(weights_layer, weights, 0); + REG_QUERY_MD(weights_iter, weights, 1); + REG_QUERY_MD(bias, weights, 2); + REG_QUERY_MD(dst_layer, dst, 0); + REG_QUERY_MD(dst_iter, dst, 1); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + rnn_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct rnn_backward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc) { + error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, &src_iter_desc.data, + &weights_layer_desc.data, &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, &dst_iter_desc.data, + &diff_src_layer_desc.data, &diff_src_iter_desc.data, + &diff_weights_layer_desc.data, + &diff_weights_iter_desc.data, &diff_bias_desc.data, + &diff_dst_layer_desc.data, &diff_dst_iter_desc.data), + "could not create an RNN backward descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const rnn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const rnn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src_layer, src, 0); + REG_QUERY_MD(src_iter, src, 1); + REG_QUERY_MD(weights_layer, weights, 0); + REG_QUERY_MD(weights_iter, weights, 1); + REG_QUERY_MD(bias, weights, 2); + REG_QUERY_MD(dst_layer, dst, 0); + REG_QUERY_MD(dst_iter, dst, 1); + REG_QUERY_MD(workspace, workspace, 0); + + REG_QUERY_MD(diff_src_layer, diff_src, 0); + REG_QUERY_MD(diff_src_iter, diff_src, 1); + REG_QUERY_MD(diff_weights_layer, diff_weights, 0); + REG_QUERY_MD(diff_weights_iter, diff_weights, 1); + REG_QUERY_MD(diff_bias, diff_weights, 2); + REG_QUERY_MD(diff_dst_layer, diff_dst, 0); + REG_QUERY_MD(diff_dst_iter, diff_dst, 1); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + // With last iteration (with and without input src_iter) + rnn_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_shuffle Shuffle +/// A primitive to shuffle data along the axis. +/// +/// @sa @ref c_api_shuffle in @ref c_api +/// @{ + +struct shuffle_forward : public primitive { + struct desc { + mkldnn_shuffle_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &data_desc, + int axis, int group_size) { + error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &data_desc.data, + axis, group_size), + "could not create a shuffle forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + shuffle_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct shuffle_backward : public primitive { + struct desc { + mkldnn_shuffle_desc_t data; + desc(const memory::desc &diff_data_desc, int axis, int group_size) { + error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data, + &diff_data_desc.data, axis, group_size), + "could not create a shuffle backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const shuffle_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + shuffle_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @} Primitives + +/// @} C++ API + +#undef REG_QUERY_MD + +// implementation section +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +inline primitive::primitive(const_mkldnn_primitive_desc_t c_pd) { + mkldnn_primitive_t result; + error::wrap_c_api(mkldnn_primitive_create(&result, c_pd), + "could not create a primitive"); + reset(result); +} + +inline primitive::primitive(const primitive_desc &pd): primitive(pd.get()) {} + +inline void primitive::execute(stream &astream, + const std::unordered_map &args) const { + std::vector c_args; + c_args.reserve(args.size()); + for (const auto &a: args) + c_args.push_back({a.first, a.second.get()}); + + error::wrap_c_api(mkldnn_primitive_execute(get(), astream.get(), + (int)c_args.size(), c_args.data()), + "primitive execution fail"); +} +#endif // DOXYGEN_SHOULD_SKIP_THIS + +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h new file mode 100644 index 0000000000..f4dc2fdfa6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h @@ -0,0 +1,98 @@ +/******************************************************************************* +* Copyright 2018-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* DO NOT EDIT, AUTO-GENERATED */ + +#ifndef MKLDNN_DEBUG_H +#define MKLDNN_DEBUG_H + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +/* All symbols shall be internal unless marked as MKLDNN_API */ +#if defined _WIN32 || defined __CYGWIN__ +# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport) +# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport) +#else +# if __GNUC__ >= 4 +# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default"))) +# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default"))) +# else +# define MKLDNN_HELPER_DLL_IMPORT +# define MKLDNN_HELPER_DLL_EXPORT +# endif +#endif + +#ifdef MKLDNN_DLL +# ifdef MKLDNN_DLL_EXPORTS +# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT +# else +# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT +# endif +#else +# define MKLDNN_API +#endif + +#if defined (__GNUC__) +# define MKLDNN_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +# define MKLDNN_DEPRECATED __declspec(deprecated) +#else +# define MKLDNN_DEPRECATED +#endif + +#include "mkldnn_types.h" +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +#ifdef __cplusplus +extern "C" { +#endif + +const char MKLDNN_API *mkldnn_status2str(mkldnn_status_t v); +const char MKLDNN_API *mkldnn_dt2str(mkldnn_data_type_t v); +const char MKLDNN_API *mkldnn_fmt_kind2str(mkldnn_format_kind_t v); +const char MKLDNN_API *mkldnn_fmt_tag2str(mkldnn_format_tag_t v); +const char MKLDNN_API *mkldnn_prop_kind2str(mkldnn_prop_kind_t v); +const char MKLDNN_API *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v); +const char MKLDNN_API *mkldnn_alg_kind2str(mkldnn_alg_kind_t v); +const char MKLDNN_API *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v); + +/** Forms a format string for a given memory descriptor. + * + * The format is defined as: 'dt:[p|o|0]:fmt_kind:fmt:extra'. + * Here: + * - dt -- data type + * - p -- indicates there is non-trivial padding + * - o -- indicates there is non-trivial padding offset + * - 0 -- indicates there is non-trivial offset0 + * - fmt_kind -- format kind (blocked, wino, etc...) + * - fmt -- extended format string (format_kind specific) + * - extra -- shows extra fields (underspecified) + */ +int MKLDNN_API mkldnn_md2fmt_str(char *fmt_str, size_t fmt_str_len, + const mkldnn_memory_desc_t *md); + +/** Forms a dimension string for a given memory descriptor. + * + * The format is defined as: 'dim0xdim1x...xdimN + */ +int MKLDNN_API mkldnn_md2dim_str(char *dim_str, size_t dim_str_len, + const mkldnn_memory_desc_t *md); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h new file mode 100644 index 0000000000..1b6c356982 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h @@ -0,0 +1,1415 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_TYPES_H +#define MKLDNN_TYPES_H + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +#include +#include +#endif + +/** @addtogroup c_api C API + * @{ + * + * @addtogroup c_api_types Types + * @{ + * + * @addtogroup c_api_types_generic Generic + * @{ */ + +/** Intel(R) MKL-DNN Version type */ +typedef struct { + int major; + int minor; + int patch; + const char *hash; +} mkldnn_version_t; + +/** Status values returned by Intel(R) MKL-DNN functions. */ +typedef enum { + /** The operation was successful */ + mkldnn_success = 0, + /** The operation failed due to an out-of-memory condition */ + mkldnn_out_of_memory = 1, + /** The operation failed and should be retried */ + mkldnn_try_again = 2, + /** The operation failed because of incorrect function arguments */ + mkldnn_invalid_arguments = 3, + /** The operation failed because a primitive was not ready for execution */ + mkldnn_not_ready = 4, + /** The operation failed because requested functionality is not implemented + */ + mkldnn_unimplemented = 5, + /** Primitive iterator passed over last primitive descriptor */ + mkldnn_iterator_ends = 6, + /** Primitive or engine failed on execution */ + mkldnn_runtime_error = 7, + /** Queried element is not required for given primitive */ + mkldnn_not_required = 8, +} mkldnn_status_t; + +/** Data type specification */ +typedef enum { + /** Undefined data type, used for empty memory descriptors. */ + mkldnn_data_type_undef = 0, + /** 32-bit/single-precision floating point. */ + mkldnn_f32 = 1, + /** 32-bit signed integer. */ + mkldnn_s32 = 2, + /** 8-bit signed integer. */ + mkldnn_s8 = 3, + /** 8-bit unsigned integer. */ + mkldnn_u8 = 4, +} mkldnn_data_type_t; + +/** Memory format kind */ +typedef enum { + /** Undefined memory format, used for empty memory descriptors. */ + mkldnn_format_kind_undef = 0, + /** Unspecified format. The primitive selects a format automatically. */ + mkldnn_format_kind_any, + /** A tensor in a generic format described by the stride and blocking + * values in each dimension. See #mkldnn_blocking_desc_t for more + * information. */ + mkldnn_blocked, + /** Weights format used in 8bit Winograd convolution */ + mkldnn_format_kind_wino, + /** Packed weights format used in RNN */ + mkldnn_format_kind_rnn_packed, +} mkldnn_format_kind_t; + +/** Memory format tag specification. + * + * Intel MKL-DNN formats describe physical data layout. The physical layout + * is described as a sequence of the dimensions as they are laid out in the + * memory (from the outer-most to the inner-most). Note that this order + * doesn't affect the logical order of the dimensions that is kept in the + * `dims` field of the mkldnn_memory_desc_t structure. The logical order of the + * dimensions is specified by the type of tensor. + * + * For example, CNN 5D tensor always has its logical dimensions in the order + * `(batch, channels, depth, height, width)`, while the physical layout might be + * #mkldnn_ncdhw or #mkldnn_ndhwc: + * + * ~~~cpp + * int batch = 2, channels = 16, depth = 13, height = 13, width = 13; + * + * int ndims = 5; // 5D tensor + * mkldnn_dims_t dims = {batch, channels, depth, height, width}; + * mkldnn_memory_desc_t data_in_ncdhw; + * mkldnn_memory_desc_init_by_tag( + * &data_in_ncdhw, 5, dims, mkldnn_f32, mkldnn_ncdhw); + * + * // note that in both cases dims passed are the same + * mkldnn_memory_desc_t data_in_ndhwc; + * mkldnn_memory_desc_init_by_tag( + * &data_in_ndhwc, 5, dims, mkldnn_f32, mkldnn_ndhwc); + * ~~~ + * + * The following notation applies to memory format names: + * - @c 'n' denotes the mini-batch dimension + * - @c 'c' denotes a channels dimension + * - When there are multiple channel dimensions (for example, in convolution + * weights tensor), @c 'i' and @c 'o' denote dimensions of input and output + * channels + * - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width + * respectively + * - Upper-case letters indicate that the data is laid out in blocks + * for a particular dimension. In such cases, the format name contains both + * upper- and lower-case letters for that dimension with a lower-case letter + * preceded by the block size. For example: @c 'mkldnn_nChw8c' describes a + * format where the outermost dimension is mini-batch, followed by the + * channel block number, followed by the spatial height and width, and + * finally followed by 8-element channel blocks. + * + * @note + * Channel designations can be different. For example, both the @c + * 'mkldnn_nc' and @c 'mkldnn_io' formats can be used to describe a 2D + * tensor. + * + * @sa @ref understanding_memory_formats + */ +typedef enum { + /** Undefined memory format tag */ + mkldnn_format_tag_undef = 0, + /** Undefined memory format tag. + * The primitive selects a format automatically. */ + mkldnn_format_tag_any, + + /* Semantic agnostic section */ + /* The physical order of dimensions is defined by the permutation of the + * characters, assuming that ab..z defines the natural order. + */ + + /* Plain formats */ + + mkldnn_a, + mkldnn_ab, + mkldnn_abc, + mkldnn_abcd, + mkldnn_abcde, + mkldnn_abcdef, + mkldnn_abdec, + mkldnn_acb, + mkldnn_acbde, + mkldnn_acdb, + mkldnn_acdeb, + mkldnn_ba, + mkldnn_bac, + mkldnn_bacd, + mkldnn_bcda, + mkldnn_cba, + mkldnn_cdba, + mkldnn_cdeba, + mkldnn_decab, + + /* Opaque blocked formats */ + + mkldnn_Abc16a, + mkldnn_ABc16a16b, + mkldnn_aBc16b, + mkldnn_ABc16b16a, + mkldnn_Abc4a, + mkldnn_aBc4b, + mkldnn_ABc4b16a4b, + mkldnn_ABc4b4a, + mkldnn_ABc8a16b2a, + mkldnn_ABc8a8b, + mkldnn_aBc8b, + mkldnn_ABc8b16a2b, + mkldnn_ABc8b8a, + mkldnn_Abcd16a, + mkldnn_ABcd16a16b, + mkldnn_aBcd16b, + mkldnn_ABcd16b16a, + mkldnn_aBCd16b16c, + mkldnn_aBCd16c16b, + mkldnn_Abcd4a, + mkldnn_aBcd4b, + mkldnn_ABcd4b16a4b, + mkldnn_ABcd4b4a, + mkldnn_aBCd4c16b4c, + mkldnn_aBCd4c4b, + mkldnn_ABcd8a16b2a, + mkldnn_ABcd8a8b, + mkldnn_aBcd8b, + mkldnn_ABcd8b16a2b, + mkldnn_aBCd8b16c2b, + mkldnn_ABcd8b8a, + mkldnn_aBCd8b8c, + mkldnn_aBCd8c16b2c, + mkldnn_aBCd8c8b, + mkldnn_Abcde16a, + mkldnn_ABcde16a16b, + mkldnn_aBcde16b, + mkldnn_ABcde16b16a, + mkldnn_aBCde16b16c, + mkldnn_aBCde16c16b, + mkldnn_aBCde2c8b4c, + mkldnn_Abcde4a, + mkldnn_aBcde4b, + mkldnn_ABcde4b4a, + mkldnn_aBCde4b4c, + mkldnn_aBCde4c16b4c, + mkldnn_aBCde4c4b, + mkldnn_Abcde8a, + mkldnn_ABcde8a8b, + mkldnn_aBcde8b, + mkldnn_ABcde8b16a2b, + mkldnn_aBCde8b16c2b, + mkldnn_ABcde8b8a, + mkldnn_aBCde8b8c, + mkldnn_aBCde8c16b2c, + mkldnn_aBCde8c8b, + mkldnn_aBcdef16b, + mkldnn_aBCdef16b16c, + mkldnn_aBCdef16c16b, + mkldnn_aBcdef4b, + mkldnn_aBCdef4c4b, + mkldnn_aBCdef8b8c, + mkldnn_aBCdef8c16b2c, + mkldnn_aBCdef8c8b, + mkldnn_aBdc16b, + mkldnn_aBdc4b, + mkldnn_aBdc8b, + mkldnn_aBdec16b, + mkldnn_aBdec4b, + mkldnn_aBdec8b, + mkldnn_aBdefc16b, + mkldnn_aBdefc4b, + mkldnn_aBdefc8b, + mkldnn_Acb16a, + mkldnn_Acb4a, + mkldnn_Acb8a, + mkldnn_aCBd16b16c, + mkldnn_aCBde16b16c, + mkldnn_Acdb16a, + mkldnn_Acdb4a, + mkldnn_Acdb8a, + mkldnn_Acdeb16a, + mkldnn_Acdeb4a, + mkldnn_Acdeb8a, + mkldnn_BAc16a16b, + mkldnn_BAcd16a16b, + + /** Just a sentinel, not real memory format tag. Must be changed after new + * format tag is added. */ + mkldnn_format_tag_last, + + /* Aliases */ + + mkldnn_x = mkldnn_a, + mkldnn_nc = mkldnn_ab, + mkldnn_cn = mkldnn_ba, + mkldnn_ncw = mkldnn_abc, + mkldnn_nwc = mkldnn_acb, + mkldnn_nchw = mkldnn_abcd, + mkldnn_nhwc = mkldnn_acdb, + mkldnn_chwn = mkldnn_bcda, + mkldnn_ncdhw = mkldnn_abcde, + mkldnn_ndhwc = mkldnn_acdeb, + + mkldnn_oi = mkldnn_ab, + mkldnn_io = mkldnn_ba, + mkldnn_oiw = mkldnn_abc, + mkldnn_wio = mkldnn_cba, + mkldnn_oihw = mkldnn_abcd, + mkldnn_hwio = mkldnn_cdba, + mkldnn_ihwo = mkldnn_bcda, + mkldnn_iohw = mkldnn_bacd, + mkldnn_oidhw = mkldnn_abcde, + mkldnn_dhwio = mkldnn_cdeba, + mkldnn_goiw = mkldnn_abcd, + mkldnn_goihw = mkldnn_abcde, + mkldnn_hwigo = mkldnn_decab, + mkldnn_giohw = mkldnn_acbde, + mkldnn_goidhw = mkldnn_abcdef, + + /** 3D RNN data tensor in the format (seq_length, batch, input channels). */ + mkldnn_tnc = mkldnn_abc, + /** 3D RNN data tensor in the format (batch, seq_length, input channels). */ + mkldnn_ntc = mkldnn_bac, + /** 5D RNN states tensor in the format (num_layers, num_directions, + * num_states, batch, state channels). */ + mkldnn_ldsnc = mkldnn_abcde, + /** 5D RNN weights tensor in the format (num_layers, num_directions, + * input_channels, num_gates, output_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. */ + mkldnn_ldigo = mkldnn_abcde, + /** 5D RNN weights tensor in the format (num_layers, num_directions, + * num_gates, output_channels, input_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. */ + mkldnn_ldgoi = mkldnn_abdec, + /** 4D RNN bias tensor in the format (num_layers, num_directions, + * num_gates, output_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. */ + mkldnn_ldgo = mkldnn_abcd, + + /* Opaque data types, are not to be used explicitly */ + + /* data */ + mkldnn_nCdhw16c = mkldnn_aBcde16b, + mkldnn_nCdhw4c = mkldnn_aBcde4b, + mkldnn_nCdhw8c = mkldnn_aBcde8b, + mkldnn_nChw16c = mkldnn_aBcd16b, + mkldnn_nChw4c = mkldnn_aBcd4b, + mkldnn_nChw8c = mkldnn_aBcd8b, + mkldnn_nCw16c = mkldnn_aBc16b, + mkldnn_nCw4c = mkldnn_aBc4b, + mkldnn_nCw8c = mkldnn_aBc8b, + + /* weights, 3D */ + mkldnn_IOw16o16i = mkldnn_BAc16a16b, + mkldnn_OIw16i16o = mkldnn_ABc16b16a, + mkldnn_OIw16o16i = mkldnn_ABc16a16b, + mkldnn_Oiw16o = mkldnn_Abc16a, + mkldnn_OIw4i16o4i = mkldnn_ABc4b16a4b, + mkldnn_OIw4i4o = mkldnn_ABc4b4a, + mkldnn_Oiw4o = mkldnn_Abc4a, + mkldnn_OIw8i16o2i = mkldnn_ABc8b16a2b, + mkldnn_OIw8i8o = mkldnn_ABc8b8a, + mkldnn_OIw8o16i2o = mkldnn_ABc8a16b2a, + mkldnn_OIw8o8i = mkldnn_ABc8a8b, + mkldnn_Owi16o = mkldnn_Acb16a, + mkldnn_Owi4o = mkldnn_Acb4a, + mkldnn_Owi8o = mkldnn_Acb8a, + + /* weights, 4D */ + mkldnn_IOhw16o16i = mkldnn_BAcd16a16b, + mkldnn_Ohwi16o = mkldnn_Acdb16a, + mkldnn_Ohwi4o = mkldnn_Acdb4a, + mkldnn_Ohwi8o = mkldnn_Acdb8a, + mkldnn_OIhw16i16o = mkldnn_ABcd16b16a, + mkldnn_OIhw16o16i = mkldnn_ABcd16a16b, + mkldnn_Oihw16o = mkldnn_Abcd16a, + mkldnn_OIhw4i16o4i = mkldnn_ABcd4b16a4b, + mkldnn_OIhw4i4o = mkldnn_ABcd4b4a, + mkldnn_Oihw4o = mkldnn_Abcd4a, + mkldnn_OIhw8i16o2i = mkldnn_ABcd8b16a2b, + mkldnn_OIhw8i8o = mkldnn_ABcd8b8a, + mkldnn_OIhw8o16i2o = mkldnn_ABcd8a16b2a, + mkldnn_OIhw8o8i = mkldnn_ABcd8a8b, + + /* weights, 5D */ + mkldnn_Odhwi16o = mkldnn_Acdeb16a, + mkldnn_Odhwi4o = mkldnn_Acdeb4a, + mkldnn_Odhwi8o = mkldnn_Acdeb8a, + mkldnn_OIdhw16i16o = mkldnn_ABcde16b16a, + mkldnn_OIdhw16o16i = mkldnn_ABcde16a16b, + mkldnn_Oidhw16o = mkldnn_Abcde16a, + mkldnn_OIdhw4i4o = mkldnn_ABcde4b4a, + mkldnn_Oidhw4o = mkldnn_Abcde4a, + mkldnn_OIdhw8i16o2i = mkldnn_ABcde8b16a2b, + mkldnn_OIdhw8i8o = mkldnn_ABcde8b8a, + mkldnn_OIdhw8o8i = mkldnn_ABcde8a8b, + + /* weights w/ groups, 3D */ + mkldnn_Goiw16g = mkldnn_Abcd16a, + mkldnn_gIOw16o16i = mkldnn_aCBd16b16c, + mkldnn_gOIw16i16o = mkldnn_aBCd16c16b, + mkldnn_gOIw16o16i = mkldnn_aBCd16b16c, + mkldnn_gOiw16o = mkldnn_aBcd16b, + mkldnn_gOIw4i16o4i = mkldnn_aBCd4c16b4c, + mkldnn_gOIw4i4o = mkldnn_aBCd4c4b, + mkldnn_gOiw4o = mkldnn_aBcd4b, + mkldnn_gOIw8i16o2i = mkldnn_aBCd8c16b2c, + mkldnn_gOIw8i8o = mkldnn_aBCd8c8b, + mkldnn_gOIw8o16i2o = mkldnn_aBCd8b16c2b, + mkldnn_gOIw8o8i = mkldnn_aBCd8b8c, + mkldnn_gOwi16o = mkldnn_aBdc16b, + mkldnn_gOwi4o = mkldnn_aBdc4b, + mkldnn_gOwi8o = mkldnn_aBdc8b, + + /* weights w/ groups, 4D */ + mkldnn_gIOhw16o16i = mkldnn_aCBde16b16c, + mkldnn_gOhwi16o = mkldnn_aBdec16b, + mkldnn_gOhwi4o = mkldnn_aBdec4b, + mkldnn_gOhwi8o = mkldnn_aBdec8b, + mkldnn_Goihw16g = mkldnn_Abcde16a, + mkldnn_gOIhw16i16o = mkldnn_aBCde16c16b, + mkldnn_gOIhw16o16i = mkldnn_aBCde16b16c, + mkldnn_gOihw16o = mkldnn_aBcde16b, + mkldnn_gOIhw2i8o4i = mkldnn_aBCde2c8b4c, + mkldnn_gOIhw4i16o4i = mkldnn_aBCde4c16b4c, + mkldnn_gOIhw4i4o = mkldnn_aBCde4c4b, + mkldnn_gOIhw4o4i = mkldnn_aBCde4b4c, + mkldnn_gOihw4o = mkldnn_aBcde4b, + mkldnn_Goihw8g = mkldnn_Abcde8a, + mkldnn_gOIhw8i16o2i = mkldnn_aBCde8c16b2c, + mkldnn_gOIhw8i8o = mkldnn_aBCde8c8b, + mkldnn_gOIhw8o16i2o = mkldnn_aBCde8b16c2b, + mkldnn_gOIhw8o8i = mkldnn_aBCde8b8c, + + /* weights w/ groups, 6D */ + mkldnn_gOdhwi16o = mkldnn_aBdefc16b, + mkldnn_gOdhwi4o = mkldnn_aBdefc4b, + mkldnn_gOdhwi8o = mkldnn_aBdefc8b, + mkldnn_gOIdhw16i16o = mkldnn_aBCdef16c16b, + mkldnn_gOIdhw16o16i = mkldnn_aBCdef16b16c, + mkldnn_gOidhw16o = mkldnn_aBcdef16b, + mkldnn_gOIdhw4i4o = mkldnn_aBCdef4c4b, + mkldnn_gOidhw4o = mkldnn_aBcdef4b, + mkldnn_gOIdhw8i16o2i = mkldnn_aBCdef8c16b2c, + mkldnn_gOIdhw8i8o = mkldnn_aBCdef8c8b, + mkldnn_gOIdhw8o8i = mkldnn_aBCdef8b8c, +} mkldnn_format_tag_t; + +/** Kinds of padding. Define how to interpret the data in padding regions. */ +typedef enum { + /** The data in padding regions is zero. */ + mkldnn_padding_zero, +} mkldnn_padding_kind_t; + +/** Kinds of propagation. */ +typedef enum { + /* TODO: suggest renames */ + /** Undefined propagation type. */ + mkldnn_prop_kind_undef = 0, + /** Forward data propagation (training mode). In this mode primitives + * perform computations necessary for subsequent backward propagation. */ + mkldnn_forward_training = 64, + /** Forward data propagation (inference mode). In this mode primitives + * perform only computations that are necessary for inference and omit + * computations that are necessary only for backward propagation. */ + mkldnn_forward_inference = 96, + /** Forward data propagation (alias for @c mkldnn_forward_inference) */ + mkldnn_forward_scoring = mkldnn_forward_inference, + /** Forward data propagation (alias for @c mkldnn_forward_training) */ + mkldnn_forward = mkldnn_forward_training, + /** Backward propagation (with respect to all parameters */ + mkldnn_backward = 128, + /** Backward data propagation */ + mkldnn_backward_data = 160, + /** Backward weights propagation */ + mkldnn_backward_weights = 192, + /** Backward bias propagation */ + mkldnn_backward_bias = 193, +} mkldnn_prop_kind_t; + +/** Kinds of primitives. Used to implement a way to extend the library with new + * primitives without changing the ABI. */ +typedef enum { + /** Undefined primitive (XXX: why do we have it?). */ + mkldnn_undefined_primitive, + /** A reorder primitive.*/ + mkldnn_reorder, + /** A shuffle primitive.*/ + mkldnn_shuffle, + /** A (out-of-place) concat primitive. */ + mkldnn_concat, + /** A sum primitive. */ + mkldnn_sum, + /** A convolution primitive. */ + mkldnn_convolution, + /** A deconvolution primitive. */ + mkldnn_deconvolution, + /** An element-wise primitive. */ + mkldnn_eltwise, + /** A Softmax primitive. */ + mkldnn_softmax, + /** A pooling primitive. */ + mkldnn_pooling, + /** An LRN primitive. */ + mkldnn_lrn, + /** An batch normalization primitive. */ + mkldnn_batch_normalization, + /** An inner product primitive. */ + mkldnn_inner_product, + /** A rnn primitive. */ + mkldnn_rnn, +} mkldnn_primitive_kind_t; + +/** Kinds of algorithms. */ +typedef enum { + mkldnn_alg_kind_undef, + /** Direct convolution */ + mkldnn_convolution_direct = 0x1, + /** Winograd convolution */ + mkldnn_convolution_winograd = 0x2, + /** Convolution algorithm(either direct or Winograd) is chosen just in time **/ + mkldnn_convolution_auto = 0x3, + /** Direct deconvolution */ + mkldnn_deconvolution_direct = 0xa, + /** Winograd deconvolution */ + mkldnn_deconvolution_winograd = 0xb, + /** Eltwise: ReLU */ + mkldnn_eltwise_relu = 0x1f, + /** Eltwise: hyperbolic tangent non-linearity (tanh) */ + mkldnn_eltwise_tanh = 0x2f, + /** Eltwise: parametric exponential linear unit (elu) */ + mkldnn_eltwise_elu = 0x3f, + /** Eltwise: square */ + mkldnn_eltwise_square = 0x4f, + /** Eltwise: abs */ + mkldnn_eltwise_abs = 0x5f, + /** Eltwise: square root */ + mkldnn_eltwise_sqrt = 0x6f, + /** Eltwise: linear */ + mkldnn_eltwise_linear = 0x7f, + /** Eltwise: bounded_relu */ + mkldnn_eltwise_bounded_relu = 0x8f, + /** Eltwise: soft_relu */ + mkldnn_eltwise_soft_relu = 0x9f, + /** Eltwise: logistic */ + mkldnn_eltwise_logistic = 0xaf, + /** Max pooling */ + mkldnn_pooling_max = 0x1ff, + /** Average pooling include padding */ + mkldnn_pooling_avg_include_padding = 0x2ff, + /** Average pooling exclude padding */ + mkldnn_pooling_avg_exclude_padding = 0x3ff, + mkldnn_pooling_avg = mkldnn_pooling_avg_exclude_padding, + /** Local response normalization (LRN) across multiple channels */ + mkldnn_lrn_across_channels = 0xaff, + /** LRN within a single channel */ + mkldnn_lrn_within_channel = 0xbff, + /** RNN cell */ + mkldnn_vanilla_rnn = 0x1fff, + /** LSTM cell */ + mkldnn_vanilla_lstm = 0x2fff, + /** GRU cell */ + mkldnn_vanilla_gru = 0x3fff, + /** GRU cell with linear before reset + * + * Modification of original GRU cell. Differs from #mkldnn_vanilla_gru + * in how the new memory gate is calculated: + * \f[ c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f] + * Primitive expects 4 biases on input: + * \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$ + * */ + mkldnn_gru_linear_before_reset = 0x4fff, +} mkldnn_alg_kind_t; + +/** Flags for batch-normalization primititve. */ +typedef enum { + /** Use global statistics + * + * If specified + * - on forward propagation use mean and variance provided by user (input) + * - on backward propagation reduces the amount of computations, since + * mean and variance are considered as constants + * + * If not specified: + * - on forward propagation mean and variance are computed and stored in + * output + * - on backward propagation compute full derivative wrt to data + */ + mkldnn_use_global_stats = 0x1U, + /** Use scale and shift parameters + * + * If specified: + * - on forward propagation use scale and shift (aka scale and bias) for + * the batch normalization results + * - on backward propagation (for prop_kind == #mkldnn_backward) compute + * diff wrt to scale and shift (hence one extra output used) + * + * If no specified: + * - on backward propagation prop_kind == #mkldnn_backward_data has the + * same behavior as prop_kind == #mkldnn_backward + */ + mkldnn_use_scaleshift = 0x2U, + /** Fuse with ReLU + * + * If specified: + * - on inference this option behaves the same as if the primitive were + * fused with ReLU via post ops API + * - on training primitive requires workspace (required to be able to + * perform backward pass) + */ + mkldnn_fuse_bn_relu = 0x4U, +} mkldnn_batch_normalization_flag_t; + +/** @} */ + +/** @addtogroup c_api_types_memory Memory + * @{ */ + +/** Maximum number of dimensions a tensor can have. Only restricts the amount + * of space used for the tensor description. Individual computational + * primitives may support only tensors of certain dimensions. */ +#define MKLDNN_MAX_NDIMS 12 + +/** A type to describe tensor dimension. */ +typedef int64_t mkldnn_dim_t; + +/** A type to describe tensor dimensions. */ +typedef mkldnn_dim_t mkldnn_dims_t[MKLDNN_MAX_NDIMS]; + +/** A type to describe strides within a tensor. */ +typedef mkldnn_dim_t mkldnn_strides_t[MKLDNN_MAX_NDIMS]; + +/** Generic description of blocked data layout for most memory formats. + * + * @sa @ref understanding_memory_formats */ +typedef struct { + /** The strides between the outermost blocks. + * In case of plain (non-blocked) formats the strides between dimensions. */ + mkldnn_dims_t strides; + /* Innermost section + * ASSUMPTION: the innermost blocks are always dense */ + /** The number of innermost blocks, e.g. 3 in case of `OIhw_4i16o4i_` */ + int inner_nblks; + /** The size of the blocks, e.g. `{4, 16, 4}` in case of `OIhw_4i16o4i` */ + mkldnn_dims_t inner_blks; + /** The logical indices of the blocks, e.g. `{1, 0, 1}` in case of + * `4i16o4i`, because `i` is the 1st dim and `o` is the 0st dim */ + mkldnn_dims_t inner_idxs; +} mkldnn_blocking_desc_t; + +typedef enum { + /** Undefined memory format, used for empty memory descriptors. */ + mkldnn_wino_undef = 0, + /** Tensors of weights for 2x3 winograd convolutions. */ + mkldnn_wino_wei_aaOIoi, + mkldnn_wino_wei_aaOio, + mkldnn_wino_wei_aaOBiOo, + /** Tensor of weights for 4x3 convolution. */ + mkldnn_wino_wei_OBaaIBOIio +} mkldnn_wino_memory_format_t; + +/** Description of tensor of weights for winograd 2x3 convolution. */ +typedef struct { + mkldnn_wino_memory_format_t wino_format; + int r; + int alpha; + int ic; + int oc; + int ic_block; + int oc_block; + int ic2_block; + int oc2_block; + float adj_scale; + size_t size; +} mkldnn_wino_desc_t; + +typedef enum { + mkldnn_packed_format_undef = 0, + mkldnn_ldigo_p, + mkldnn_ldgoi_p +} mkldnn_rnn_packed_memory_format_t; + +/* Maximum number of parts of RNN weights tensor that require separate + * computation. */ +#define MKLDNN_RNN_MAX_N_PARTS 4 + +/** Description of tensor of packed weights for rnn. */ +typedef struct { + mkldnn_rnn_packed_memory_format_t format; + int n_parts; + int n; + int parts[MKLDNN_RNN_MAX_N_PARTS]; + size_t part_pack_size[MKLDNN_RNN_MAX_N_PARTS]; + size_t offset_compensation; + size_t size; +} mkldnn_rnn_packed_desc_t; + +typedef enum { + mkldnn_memory_extra_flag_none = 0x0U, + /** Indicates the weights have an additional buffer, that depends on the + * @p compensation_mask. + * + * For instance, in 4D case with the compensation mask equals (1 << 0) + * the additional buffer would consist of OC values: + * O[oc : 0,OC] = + * -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) } + */ + mkldnn_memory_extra_flag_compensation_conv_s8s8 = 0x1U, + mkldnn_memory_extra_flag_scale_adjust = 0x2U, +} mkldnn_memory_extra_flags_t; + +/** Description of extra information stored in memory */ +typedef struct { + /** The flags contain arbitrary extra information, such as compensation. + * @sa mkldnn_memory_extra_flags_t */ + uint64_t flags; + /** Compensation mask */ + int compensation_mask; + /** Scale applied to the data */ + float scale_adjust; + /** For future backwards compatibility */ + char reserved[64]; +} mkldnn_memory_extra_desc_t; + +/** Memory descriptor. The description is based on a number of dimensions, + * dimensions themselves, plus information about elements type and memory + * format. Additionally, contains format-specific descriptions of the data + * layout. */ +typedef struct { + /** Number of dimensions */ + int ndims; + /** Dimensions in the following order: + * - CNN data tensors: mini-batch, channel, spatial + * ({N, C, [[D,] H,] W}) + * - CNN weight tensors: group (optional), output channel, input channel, + * spatial ({[G,] O, I, [[D,] H,] W}) + * - RNN data tensors: time, mini-batch, channels ({T, N, C}) + * or layers, directions, states, mini-batch, channels ({L, D, S, N, C}) + * - RNN weight tensor: layers, directions, input channel, gates, output channels + * ({L, D, I, G, O}). + * + * @note + * The order of dimensions does not depend on the memory format, so + * whether the data is laid out in #mkldnn_nchw or #mkldnn_nhwc + * the dims for 4D CN data tensor would be {N, C, H, W}. + */ + mkldnn_dims_t dims; + /** Data type of the tensor elements. */ + mkldnn_data_type_t data_type; + + /** Size of the data including padding in each dimension. */ + mkldnn_dims_t padded_dims; + /** Per-dimension offset from the padding to actual data, the top-level + * tensor with offsets applied must lie within the padding area. */ + mkldnn_dims_t padded_offsets; + + /** Offset from memory origin to the current block, non-zero only in + * a description of a memory sub-block. */ + mkldnn_dim_t offset0; + + /** Memory format kind. */ + mkldnn_format_kind_t format_kind; + union { + /** Description of the data layout for memory formats that use + * blocking. */ + mkldnn_blocking_desc_t blocking; + /** Tensor of weights for integer 8bit winograd convolution. */ + mkldnn_wino_desc_t wino_desc; + /** Tensor of packed weights for RNN. */ + mkldnn_rnn_packed_desc_t rnn_packed_desc; + /* ... other descriptions possible */ + } format_desc; + + mkldnn_memory_extra_desc_t extra; +} mkldnn_memory_desc_t; + +/** @struct mkldnn_memory + * An opaque structure to describe a memory. */ +struct mkldnn_memory; + +/** A memory handle. */ +typedef struct mkldnn_memory *mkldnn_memory_t; + +/** A constant memory handle. */ +typedef const struct mkldnn_memory *const_mkldnn_memory_t; + +#define MKLDNN_NATIVE_HANDLE_NONE (NULL) +#define MKLDNN_NATIVE_HANDLE_ALLOCATE ((void *)(size_t)-1) + +/** @} */ + +/** @addtogroup c_api_types_op_descs Operation descriptors + * @{*/ + +/** A pointer to any of the operation descriptors. */ +typedef void *mkldnn_op_desc_t; +/** A pointer to any of the operation descriptors (constant variant). */ +typedef const void *const_mkldnn_op_desc_t; + +/** A descriptor of a convolution operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_convolution. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward_data, + * #mkldnn_backward_weights, and #mkldnn_backward_bias. */ + mkldnn_prop_kind_t prop_kind; + /** The kind of the convolution algorithm. Possible values: + * #mkldnn_convolution_direct. */ + mkldnn_alg_kind_t alg_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Weights memory descriptor. */ + mkldnn_memory_desc_t weights_desc; + /** Weights gradient memory descriptor. */ + mkldnn_memory_desc_t diff_weights_desc; + /** Bias memory descriptor. */ + mkldnn_memory_desc_t bias_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination memory descriptor. */ + mkldnn_memory_desc_t dst_desc; + /** Destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_dst_desc; + /** Convolution strides in each spatial dimension. */ + mkldnn_dims_t strides; + /** Convolution dilates in each spatial dimension. */ + mkldnn_dims_t dilates; + /** Padding in each spatial dimension. padding[0] is a padding in the + * beginning (@p padding_l), padding[1] is a padding in the end (@p + * padding_r). */ + mkldnn_dims_t padding[2]; + /** The kind of padding to use. */ + mkldnn_padding_kind_t padding_kind; + /** The accumulator data type. Initialized automatically. */ + mkldnn_data_type_t accum_data_type; +} mkldnn_convolution_desc_t; + +/** A descriptor of a deconvolution operation. */ +typedef mkldnn_convolution_desc_t mkldnn_deconvolution_desc_t; + +/** A descriptor of a shuffle operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_convolution. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, and #mkldnn_backward_data. */ + mkldnn_prop_kind_t prop_kind; + /** Source and destination memory descriptor, + * and source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** axis for shuffling. */ + int axis; + /** number of groups in group convolution */ + mkldnn_dim_t group_size; +} mkldnn_shuffle_desc_t; + +/** A descriptor of a element-wise operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_eltwise. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu, + * #mkldnn_eltwise_tanh, #mkldnn_eltwise_elu, #mkldnn_eltwise_square, + * #mkldnn_eltwise_abs, #mkldnn_eltwise_sqrt, #mkldnn_eltwise_linear, + * #mkldnn_eltwise_bounded_relu, #mkldnn_eltwise_soft_relu, and + * #mkldnn_eltwise_logistic. */ + mkldnn_alg_kind_t alg_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_data_desc; + /** Algorithm specific parameter. + * Accordance table: + * - #mkldnn_eltwise_relu: @p alpha -- negative slope, @p beta ignored + * - #mkldnn_eltwise_tanh: @p alpha and @p beta ignored + * - #mkldnn_eltwise_elu: @p alpha -- negative slope, @p beta ignored + * - #mkldnn_eltwise_square: @p alpha and @p beta ignored + * - #mkldnn_eltwise_abs: @p alpha and @p beta ignored + * - #mkldnn_eltwise_sqrt: @p alpha and @p beta ignored + * - #mkldnn_eltwise_linear: @p alpha -- scale, @p beta -- shift + * - #mkldnn_eltwise_bounded_relu: @p alpha -- upper bound, @p beta ignored + * - #mkldnn_eltwise_soft_relu: @p alpha and @p beta ignored + * - #mkldnn_eltwise_logistic: @p alpha and @p beta ignored + */ + float alpha, beta; +} mkldnn_eltwise_desc_t; + +/** A descriptor of a Softmax operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_softmax. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training and + * #mkldnn_forward_inference. */ + mkldnn_prop_kind_t prop_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and Destination of gradient memory descriptor. */ + mkldnn_memory_desc_t diff_desc; + /** The axis along which to perform the softmax. */ + int softmax_axis; +} mkldnn_softmax_desc_t; + +/** A descriptor of a pooling operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_pooling. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** The kind of pooling algorithm. Possible values: #mkldnn_pooling_max and + * #mkldnn_pooling_avg. */ + mkldnn_alg_kind_t alg_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Destination memory descriptor. */ + mkldnn_memory_desc_t dst_desc; + /** Destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_dst_desc; + /** Pooling kernel strides for spatial dimensions. */ + mkldnn_dims_t strides; + /** Pooling kernel spatial dimensions. */ + mkldnn_dims_t kernel; + /** Padding in each spatial dimension. padding[0] is a padding in the + * beginning (@p padding_l), padding[1] is a padding in the end (@p + * padding_r). */ + mkldnn_dims_t padding[2]; + /** The kind of padding to use. */ + mkldnn_padding_kind_t padding_kind; + /** The accumulator data type. Initialized automatically. */ + mkldnn_data_type_t accum_data_type; +} mkldnn_pooling_desc_t; + +/** A descriptor of a Local Response Normalization (LRN) operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_lrn. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** LRN algorithm. Possible values: #mkldnn_lrn_within_channel and + * #mkldnn_lrn_across_channels. */ + mkldnn_alg_kind_t alg_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_data_desc; + /** The number of channels to sum over (for cross-channel LRN) or the side + * length of the square region to sum over (for within-channel LRN). */ + mkldnn_dim_t local_size; + /** LRN alpha parameter. */ + float lrn_alpha; + /** LRN beta parameter. */ + float lrn_beta; + /** LRN k parameter. */ + float lrn_k; +} mkldnn_lrn_desc_t; + +/** A descriptor of a Batch Normalization operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_batch_normalization. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_data_desc; + /** Scale and shift data and gradient memory descriptors. + * + * Scaleshift memory descriptor uses 2D #mkldnn_nc format[2,Channels]. 1-st + * dimension contains gamma parameter, 2-nd dimension contains beta + * parameter. */ + mkldnn_memory_desc_t data_scaleshift_desc; + mkldnn_memory_desc_t diff_data_scaleshift_desc; + /** Mean and variance data memory descriptors. + * + * Mean and variance memory descriptors use 1D #mkldnn_x format[Channels]. + */ + mkldnn_memory_desc_t mean_desc; + mkldnn_memory_desc_t variance_desc; + /** Batch normalization epsilon parameter. */ + float batch_norm_epsilon; + unsigned flags; +} mkldnn_batch_normalization_desc_t; + +/** A descriptor of an inner product operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_inner_product. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward_data, + * #mkldnn_backward_weights, and #mkldnn_backward_bias. */ + mkldnn_prop_kind_t prop_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Weights memory descriptor. */ + mkldnn_memory_desc_t weights_desc; + /** Weights gradient memory descriptor. */ + mkldnn_memory_desc_t diff_weights_desc; + /** Bias memory descriptor. */ + mkldnn_memory_desc_t bias_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination memory descriptor. */ + mkldnn_memory_desc_t dst_desc; + /** Destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_dst_desc; + /** The accumulator data type. Initialized automatically. */ + mkldnn_data_type_t accum_data_type; +} mkldnn_inner_product_desc_t; + +/** Flags for RNN cell. */ +typedef enum { + mkldnn_rnn_cell_with_relu = 0x1U, + mkldnn_rnn_cell_with_clipping = 0x2U, +} mkldnn_rnn_cell_flags_t; + +typedef struct { + /** RNN cell kind. Must be one of #mkldnn_vanilla_rnn, + * #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, + * or #mkldnn_gru_linear_before_reset. */ + mkldnn_alg_kind_t cell_kind; + /** Activation function used. Must be either #mkldnn_eltwise_relu or + * #mkldnn_eltwise_tanh. */ + mkldnn_alg_kind_t activation_kind; + /** RNN cell flags */ + unsigned int flags; + /** @c alpha is a negative slope parameter (used only if + * `(flags & #mkldnn_rnn_cell_with_relu) != 0`) */ + float alpha; + /** clipping parameter (used only if + * `(flags & #mkldnn_rnn_cell_with_clipping) != 0`) */ + float clipping; +} mkldnn_rnn_cell_desc_t; + +/** A direction of RNN primitive execution. */ +typedef enum { + /* Unidirectional execution of RNN primitive from left to right. */ + mkldnn_unidirectional_left2right, + /* Unidirectional execution of RNN primitive from right to left. */ + mkldnn_unidirectional_right2left, + /* Bidirectional execution of RNN primitive with concatenation of the + * results. */ + mkldnn_bidirectional_concat, + /* Bidirectional execution of RNN primitive with summation of the + * results. */ + mkldnn_bidirectional_sum, + mkldnn_unidirectional = mkldnn_unidirectional_left2right, +} mkldnn_rnn_direction_t; + +/** A descriptor for an RNN operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_rnn. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, and #mkldnn_backward. */ + mkldnn_prop_kind_t prop_kind; + /** The RNN cell desc. */ + mkldnn_rnn_cell_desc_t cell_desc; + /** The direction of RNN primitive execution. */ + mkldnn_rnn_direction_t direction; + /** Source layer memory descriptor. */ + mkldnn_memory_desc_t src_layer_desc; + /** Source iteration memory descriptor. */ + mkldnn_memory_desc_t src_iter_desc; + /** Weights layer memory descriptor. */ + mkldnn_memory_desc_t weights_layer_desc; + /** Weights iteration memory descriptor. */ + mkldnn_memory_desc_t weights_iter_desc; + /** Bias memory descriptor. */ + mkldnn_memory_desc_t bias_desc; + /** Destination layer memory descriptor. */ + mkldnn_memory_desc_t dst_layer_desc; + /** Destination iter memory descriptor. */ + mkldnn_memory_desc_t dst_iter_desc; + /** Source gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_src_layer_desc; + /** Source gradient iter memory descriptor. */ + mkldnn_memory_desc_t diff_src_iter_desc; + /** Weights gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_weights_layer_desc; + /** Weights gradient iter memory descriptor. */ + mkldnn_memory_desc_t diff_weights_iter_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_dst_layer_desc; + /** Destination gradient iteration memory descriptor. */ + mkldnn_memory_desc_t diff_dst_iter_desc; +} mkldnn_rnn_desc_t; + +/** @} */ + +/** @addtogroup c_api_engine_types Engine + * @{ */ + +/** @brief Kinds of engines. */ +typedef enum { + /** An unspecified engine. */ + mkldnn_any_engine, + /** CPU engine. */ + mkldnn_cpu, +} mkldnn_engine_kind_t; + +/** @struct mkldnn_engine + * @brief An opaque structure to describe an engine. */ +struct mkldnn_engine; +/** @brief An engine handle. */ +typedef struct mkldnn_engine *mkldnn_engine_t; +#if 0 +/* FIXME: looks like this never happens */ +/** @brief A constant engine handle. */ +typedef const struct mkldnn_engine *const_mkldnn_engine_t; +#endif + +/** @} */ + +/** @addtogroup c_api_primitive_desc_iterators Primitive descriptor iterators + * @{ */ + +/** @struct mkldnn_primitive_desc_iterator + * @brief An opaque structure to describe a primitive descriptor iterator. */ +struct mkldnn_primitive_desc_iterator; + +/** @brief A primitive descriptor iterator handle. */ +typedef struct mkldnn_primitive_desc_iterator + *mkldnn_primitive_desc_iterator_t; + +/** @brief A constant primitive descriptor iterator handle. */ +typedef const struct mkldnn_primitive_desc_iterator + *const_mkldnn_primitive_desc_iterator_t; + +/** @} */ + +/** @addtogroup c_api_primitive_descs Primitive descriptors + * @{ */ + +/** @struct mkldnn_primitive_desc + * @brief An opaque structure to describe a primitive descriptor. */ +struct mkldnn_primitive_desc; + +/** @brief A primitive descriptor handle. */ +typedef struct mkldnn_primitive_desc *mkldnn_primitive_desc_t; + +/** @brief A constant primitive descriptor handle. */ +typedef const struct mkldnn_primitive_desc *const_mkldnn_primitive_desc_t; + +/** @} */ + +/** @addtogroup c_api_primitive_attr Primitive descriptor attributes + * @{ */ + +/** Scratchpad mode */ +typedef enum { + /** The library manages scratchpad (default) */ + mkldnn_scratchpad_mode_library, + /** A user shall query and provide the scratchpad memory to primitives */ + mkldnn_scratchpad_mode_user, +} mkldnn_scratchpad_mode_t; + +/** @struct mkldnn_primitive_attr + * @brief An opaque structure for primitive descriptor attributes. + * + * Attributes may contain: + * - output scales (to scale the result prior to storing it to the memory) + */ +struct mkldnn_primitive_attr; + +/** @brief A primitive descriptor attributes handle that controls primitive + * behavior. */ +typedef struct mkldnn_primitive_attr *mkldnn_primitive_attr_t; + +/** @brief A constant primitive descriptor attributes handle. */ +typedef const struct mkldnn_primitive_attr *const_mkldnn_primitive_attr_t; + +/** @struct mkldnn_post_ops + * @brief An opaque structure for a chain of post operations. + * + * mkldnn_post_ops can be used to perform some (trivial) operations like + * accumulation or eltwise after certain primitives like convolution. + * + * Post operations might be combined together, making a chain of post + * operations. For instance one can configure convolution followed by + * accumulation followed by eltwise. This might be especially beneficial + * for residual learning blocks. + * + * @warning + * Of course not all combinations are supported, so the user should handle + * errors accordingly. + * + * Supported post operations: + * - accumulation (base primitive: convolution) + * - eltwise (base primitive: convolution) + */ +struct mkldnn_post_ops; + +/** @brief A post operation chain handle. */ +typedef struct mkldnn_post_ops *mkldnn_post_ops_t; + +/** @brief A constant post operation chain handle. */ +typedef const struct mkldnn_post_ops *const_mkldnn_post_ops_t; + +/** @} */ + +/** @addtogroup c_api_types_primitive Primitive + * @{ */ + +/** @struct mkldnn_primitive + * An opaque structure to describe a primitive. */ +struct mkldnn_primitive; +/** A primitive handle. */ +typedef struct mkldnn_primitive *mkldnn_primitive_t; +/** A constant primitive handle. */ +typedef const struct mkldnn_primitive *const_mkldnn_primitive_t; + +/** @addtogroup c_api_types_arguments Argument indices + * @{ */ + +#define MKLDNN_ARG_SRC_0 1 +#define MKLDNN_ARG_SRC MKLDNN_ARG_SRC_0 +#define MKLDNN_ARG_SRC_LAYER MKLDNN_ARG_SRC_0 +#define MKLDNN_ARG_FROM MKLDNN_ARG_SRC_0 + +#define MKLDNN_ARG_SRC_1 2 +#define MKLDNN_ARG_SRC_ITER MKLDNN_ARG_SRC_1 + +#define MKLDNN_ARG_DST_0 17 +#define MKLDNN_ARG_DST MKLDNN_ARG_DST_0 +#define MKLDNN_ARG_TO MKLDNN_ARG_DST_0 +#define MKLDNN_ARG_DST_LAYER MKLDNN_ARG_DST_0 + +#define MKLDNN_ARG_DST_1 18 +#define MKLDNN_ARG_DST_ITER MKLDNN_ARG_DST_1 + +#define MKLDNN_ARG_WEIGHTS_0 33 +#define MKLDNN_ARG_WEIGHTS MKLDNN_ARG_WEIGHTS_0 +#define MKLDNN_ARG_SCALE_SHIFT MKLDNN_ARG_WEIGHTS_0 +#define MKLDNN_ARG_WEIGHTS_LAYER MKLDNN_ARG_WEIGHTS_0 + +#define MKLDNN_ARG_WEIGHTS_1 34 +#define MKLDNN_ARG_WEIGHTS_ITER MKLDNN_ARG_WEIGHTS_1 + +#define MKLDNN_ARG_BIAS 41 + +#define MKLDNN_ARG_MEAN 49 +#define MKLDNN_ARG_VARIANCE 50 + +#define MKLDNN_ARG_WORKSPACE 64 +#define MKLDNN_ARG_SCRATCHPAD 80 + +#define MKLDNN_ARG_DIFF_SRC_0 129 +#define MKLDNN_ARG_DIFF_SRC MKLDNN_ARG_DIFF_SRC_0 +#define MKLDNN_ARG_DIFF_SRC_LAYER MKLDNN_ARG_DIFF_SRC_0 + +#define MKLDNN_ARG_DIFF_SRC_1 130 +#define MKLDNN_ARG_DIFF_SRC_ITER MKLDNN_ARG_DIFF_SRC_1 + +#define MKLDNN_ARG_DIFF_DST_0 145 +#define MKLDNN_ARG_DIFF_DST MKLDNN_ARG_DIFF_DST_0 +#define MKLDNN_ARG_DIFF_DST_LAYER MKLDNN_ARG_DIFF_DST_0 + +#define MKLDNN_ARG_DIFF_DST_1 146 +#define MKLDNN_ARG_DIFF_DST_ITER MKLDNN_ARG_DIFF_DST_1 + +#define MKLDNN_ARG_DIFF_WEIGHTS_0 161 +#define MKLDNN_ARG_DIFF_WEIGHTS MKLDNN_ARG_DIFF_WEIGHTS_0 +#define MKLDNN_ARG_DIFF_SCALE_SHIFT MKLDNN_ARG_DIFF_WEIGHTS_0 +#define MKLDNN_ARG_DIFF_WEIGHTS_LAYER MKLDNN_ARG_DIFF_WEIGHTS_0 + +#define MKLDNN_ARG_DIFF_WEIGHTS_1 162 +#define MKLDNN_ARG_DIFF_WEIGHTS_ITER MKLDNN_ARG_DIFF_WEIGHTS_1 + +#define MKLDNN_ARG_DIFF_BIAS 169 + +#define MKLDNN_ARG_MULTIPLE_SRC 1024 +#define MKLDNN_ARG_MULTIPLE_DST 2048 + +/** @} */ + +/** An auxiliary structure to specify primitive's inputs/outputs at execution + * + * @warning + * With this API it's impossible to preserve constness of memory, so all + * memories are passed w/o const qualifier. However only memories with + * output semantics might be changed during the execution */ +typedef struct { + int arg; /**< An argument index, e.g. MKLDNN_ARG_SRC */ + mkldnn_memory_t memory; /**< Input/output memory */ +} mkldnn_exec_arg_t; + +/** @} */ + +/** @addtogroup c_api_types_query Queries + * @{ */ + +/** Primitive descriptor query specification + * + * For generic function mkldnn_primitive_desc_query(), the type of result must + * agree with the queried argument. The correspondence table: + * Query | type of result + * -------------------------------------------------------------- + * #mkldnn_query_engine | mkldnn_engine_t * + * #mkldnn_query_scratchpad_engine | mkldnn_engine_t * + * #mkldnn_query_primitive_kind | mkldnn_primitive_kind_t * + * *_s32 | int * + * *_s64 | mkldnn_dim_t * (same as int64_t *) + * *_f64 | double * + * *_str | const char ** + * #mkldnn_query_op_d | const_mkldnn_op_desc_t * + * *_md | const mkldnn_memory_desc_t ** + * *_${op}_d | const mkldnn_${op}_desc_t ** + * *_pd | const_mkldnn_primitive_desc_t * + * + * @note + * Rule of thumb: all opaque types and structures are returned by + * reference. All numbers are returned by value. + * + * @warning + * All returned references point to constant objects and are valid only + * during the lifetime of the queried primitive descriptor. Returned objects + * must not be destroyed by the user. If you need to keep the object longer + * than the lifetime of the queried primitive descriptor, use + * mkldnn_primitive_desc_clone() to make a copy. */ +typedef enum { + mkldnn_query_undef = 0, /**< no query */ + + mkldnn_query_engine, /**< execution engine */ + mkldnn_query_primitive_kind, /**< primitive kind */ + + mkldnn_query_num_of_inputs_s32, /**< number of inputs expected */ + mkldnn_query_num_of_outputs_s32, /**< number of outputs expected */ + + mkldnn_query_time_estimate_f64, /**< runtime estimation (seconds) */ + mkldnn_query_memory_consumption_s64, /**< memory consumption -- extra + (scratch) memory, additional to all + inputs and outputs memory (bytes) */ + + mkldnn_query_scratchpad_engine, /**< scratchpad engine -- engine to be used + for creating scratchpad memory */ + + mkldnn_query_impl_info_str, /**< implementation name */ + + /* memory and op descriptor section */ + mkldnn_query_some_d = 64, /**< stub */ + mkldnn_query_op_d, /**< op descriptor */ + mkldnn_query_convolution_d, /**< convolution descriptor */ + mkldnn_query_deconvolution_d, /**< deconvolution descriptor */ + mkldnn_query_shuffle_d, /**< shuffle descriptor */ + mkldnn_query_eltwise_d, /**< eltwise descriptor */ + mkldnn_query_softmax_d, /**< softmax descriptor */ + mkldnn_query_pooling_d, /**< pooling descriptor */ + mkldnn_query_lrn_d, /**< lrn descriptor */ + mkldnn_query_batch_normalization_d, /**< batch normalization descriptor */ + mkldnn_query_inner_product_d, /**< inner product descriptor */ + mkldnn_query_rnn_d, /**< rnn descriptor */ + + /* memory descriptor section */ + mkldnn_query_some_md = 128, /**< stub */ + mkldnn_query_src_md, /**< source memory desc */ + mkldnn_query_diff_src_md, /**< source gradient memory desc */ + mkldnn_query_weights_md, /**< weights memory descriptor desc */ + mkldnn_query_diff_weights_md, /**< weights grad. memory desc */ + mkldnn_query_dst_md, /**< destination memory desc */ + mkldnn_query_diff_dst_md, /**< destination grad. memory desc */ + mkldnn_query_workspace_md, /**< workspace memory desc */ + mkldnn_query_scratchpad_md, /**< scratchpad memory desc */ +} mkldnn_query_t; + +/** @} */ + +/** @addtogroup c_api_types_stream Execution stream + * @{ */ + +/** @brief Stream flags. */ +typedef enum { + /** A default stream configuration. */ + mkldnn_stream_default_flags = 0x0U, +} mkldnn_stream_flags_t; + +/** @struct mkldnn_stream + * An opaque structure to describe an execution stream. */ +struct mkldnn_stream; +/** An execution stream handle. */ +typedef struct mkldnn_stream *mkldnn_stream_t; +/** A constant execution stream handle. */ +typedef const struct mkldnn_stream *const_mkldnn_stream_t; + +/** @} */ +/** @} */ +/** @} */ + +#ifdef __cplusplus +} +#endif + + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h new file mode 100644 index 0000000000..a2713deccb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_VERSION_H +#define MKLDNN_VERSION_H + +/* Major version of MKL-DNN */ +#define MKLDNN_VERSION_MAJOR 0 + +/* Minor version of MKL-DNN */ +#define MKLDNN_VERSION_MINOR 90 + +/* Patch version of MKL-DNN */ +#define MKLDNN_VERSION_PATCH 0 + +/* Git Commit Hash of MKL-DNN */ +#define MKLDNN_VERSION_HASH "096bda1ca23324879f2df5a129e610e4405f775c" + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in new file mode 100644 index 0000000000..5ee0126188 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_VERSION_H +#define MKLDNN_VERSION_H + +/* Major version of MKL-DNN */ +#define MKLDNN_VERSION_MAJOR @MKLDNN_VERSION_MAJOR@ + +/* Minor version of MKL-DNN */ +#define MKLDNN_VERSION_MINOR @MKLDNN_VERSION_MINOR@ + +/* Patch version of MKL-DNN */ +#define MKLDNN_VERSION_PATCH @MKLDNN_VERSION_PATCH@ + +/* Git Commit Hash of MKL-DNN */ +#define MKLDNN_VERSION_HASH "@MKLDNN_VERSION_HASH@" + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp new file mode 100644 index 0000000000..1a51d8562b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp @@ -0,0 +1,104 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) { + bool args_ok = true + && !any_null(bnrm_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward_data, backward) + && IMPLICATION(prop_kind & backward, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto bd = batch_normalization_desc_t(); + bd.primitive_kind = primitive_kind::batch_normalization; + bd.prop_kind = prop_kind; + + bd.data_desc = *data_desc; + bd.diff_data_desc = zero_md(); + if ( one_of(bd.prop_kind,backward_data, backward) ) + bd.diff_data_desc = *diff_data_desc; + + dims_t scaleshift_dims = { 2, data_desc->dims[1] }; + mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2, + scaleshift_dims, data_type::f32, mkldnn_nc); + bd.diff_data_scaleshift_desc = zero_md(); + if (bd.prop_kind == backward) { + bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc; + } + + dims_t stats_dims = { data_desc->dims[1] }; + mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims, + data_type::f32, mkldnn_x); + bd.variance_desc = bd.mean_desc; + bd.batch_norm_epsilon = epsilon; + + unsigned bnorm_flags = + mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu; + if ((~bnorm_flags & flags) != 0) return invalid_arguments; + + bd.flags = flags; + + bool consistency = true + && utils::one_of(bd.data_desc.ndims, 2, 4, 5); + if (bd.prop_kind == backward_data) + consistency = consistency + && utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5) + && array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims, + bd.diff_data_desc.ndims); + if (!consistency) return invalid_arguments; + + *bnrm_desc = bd; + return success; +} +} + +status_t mkldnn_batch_normalization_forward_desc_init( + batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, float epsilon, unsigned flags) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr, + epsilon, flags); +} + +status_t mkldnn_batch_normalization_backward_desc_init( + batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, + const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc, + float epsilon, unsigned flags) { + if (!one_of(prop_kind, backward, backward_data)) + return invalid_arguments; + return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc, + epsilon, flags); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp new file mode 100644 index 0000000000..f61410b33c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp @@ -0,0 +1,240 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef BATCH_NORMALIZATION_PD_HPP +#define BATCH_NORMALIZATION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct batch_normalization_fwd_pd_t; + +struct batch_normalization_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::batch_normalization; + + batch_normalization_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + , stat_md_(desc_.mean_desc) + , scaleshift_md_(desc_.data_scaleshift_desc) + , ws_md_() + {} + + const batch_normalization_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::batch_normalization_d: + *(const batch_normalization_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common batch_normalization aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return desc_.data_desc.ndims; } + + bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; } + bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; } + bool use_global_stats() const + { return desc_.flags & mkldnn_use_global_stats; } + bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; } + bool with_relu_post_op() const { + const auto &p = this->attr()->post_ops_; + return p.len_ == 1 && p.entry_[0].is_relu(true, true); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + bool is_bwd() const { return !this->is_fwd(); } + bool is_training() const + { return desc_.prop_kind == prop_kind::forward_training; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + +protected: + batch_normalization_desc_t desc_; + const batch_normalization_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + memory_desc_t stat_md_; + memory_desc_t scaleshift_md_; + + memory_desc_t ws_md_; + + void init_default_ws(size_t bits_per_element) { + const auto data_mdw = memory_desc_wrapper(data_md_); + + const dim_t data_nelems = data_mdw.nelems(true); + const dim_t bits_per_byte = 8; + const dims_t ws_sz = { (dim_t)utils::div_up( + data_nelems * bits_per_element, bits_per_byte) }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8, + format_tag::x); + } + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t { + typedef batch_normalization_fwd_pd_t base_class; + typedef batch_normalization_fwd_pd_t hint_class; + + batch_normalization_fwd_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input; + if (arg == MKLDNN_ARG_DST) return arg_usage_t::output; + + if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) { + if (stats_is_src()) return arg_usage_t::input; + if (!stats_is_src() && is_training()) return arg_usage_t::output; + return arg_usage_t::unused; + } + + if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override { + if (index == 0) return &data_md_; + if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_; + return nullptr; + } + + virtual const memory_desc_t *dst_md(int index = 0) const override { + if (index == 0) return &data_md_; + if (!stats_is_src() && is_training() && (index == 1 || index == 2)) + return &stat_md_; + return nullptr; + } + + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &scaleshift_md_ : nullptr; } + + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; } + + const memory_desc_t *stat_md() const + { return stats_is_src() ? src_md(1) : dst_md(1); } + + virtual int n_inputs() const override + { return 1 + 2 * stats_is_src() + use_scaleshift(); } + virtual int n_outputs() const override + { return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); } +}; + +struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t { + typedef batch_normalization_bwd_pd_t base_class; + typedef batch_normalization_fwd_pd_t hint_class; + + batch_normalization_bwd_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN, + MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &scaleshift_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override + { return index == 0 ? &diff_scaleshift_md_ : nullptr; } + + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; } + + const memory_desc_t *stat_md() const { return src_md(1); } + + virtual int n_inputs() const override + { return 4 + use_scaleshift() + fuse_bn_relu(); } + virtual int n_outputs() const override + { return 1 + (desc_.prop_kind == prop_kind::backward); } + +protected: + memory_desc_t diff_data_md_; + memory_desc_t diff_scaleshift_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp new file mode 100644 index 0000000000..3d43a0fbee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp @@ -0,0 +1,550 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TYPE_MAPPING_HPP +#define TYPE_MAPPING_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { + +// TODO: autogenerate this + +using dim_t = mkldnn_dim_t; +using dims_t = mkldnn_dims_t; +using stride_t = mkldnn_dim_t; +using strides_t = mkldnn_strides_t; + +using status_t = mkldnn_status_t; +namespace status { + const status_t success = mkldnn_success; + const status_t out_of_memory = mkldnn_out_of_memory; + const status_t try_again = mkldnn_try_again; + const status_t invalid_arguments = mkldnn_invalid_arguments; + const status_t not_ready = mkldnn_not_ready; + const status_t unimplemented = mkldnn_unimplemented; + const status_t iterator_ends = mkldnn_iterator_ends; + const status_t runtime_error = mkldnn_runtime_error; + const status_t not_required = mkldnn_not_required; +} + +using prop_kind_t = mkldnn_prop_kind_t; +namespace prop_kind { + const prop_kind_t undef = mkldnn_prop_kind_undef; + const prop_kind_t forward_training = mkldnn_forward_training; + const prop_kind_t forward_inference = mkldnn_forward_inference; + const prop_kind_t forward_scoring = mkldnn_forward_scoring; + const prop_kind_t forward = mkldnn_forward; + const prop_kind_t backward = mkldnn_backward; + const prop_kind_t backward_data = mkldnn_backward_data; + const prop_kind_t backward_weights = mkldnn_backward_weights; + const prop_kind_t backward_bias = mkldnn_backward_bias; +} + +using alg_kind_t = mkldnn_alg_kind_t; +namespace alg_kind { + const alg_kind_t undef = mkldnn_alg_kind_undef; + const alg_kind_t convolution_auto = mkldnn_convolution_auto; + const alg_kind_t convolution_direct = mkldnn_convolution_direct; + const alg_kind_t convolution_winograd = mkldnn_convolution_winograd; + const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct; + const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd; + const alg_kind_t eltwise_relu = mkldnn_eltwise_relu; + const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh; + const alg_kind_t eltwise_elu = mkldnn_eltwise_elu; + const alg_kind_t eltwise_square = mkldnn_eltwise_square; + const alg_kind_t eltwise_abs = mkldnn_eltwise_abs; + const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt; + const alg_kind_t eltwise_linear = mkldnn_eltwise_linear; + const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu; + const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu; + const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic; + const alg_kind_t pooling_max = mkldnn_pooling_max; + const alg_kind_t pooling_avg = mkldnn_pooling_avg; + const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding; + const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding; + const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels; + const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel; + const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn; + const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm; + const alg_kind_t vanilla_gru = mkldnn_vanilla_gru; + const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset; +} + +using data_type_t = mkldnn_data_type_t; +namespace data_type { + const data_type_t undef = mkldnn_data_type_undef; + const data_type_t f32 = mkldnn_f32; + const data_type_t s32 = mkldnn_s32; + const data_type_t s8 = mkldnn_s8; + const data_type_t u8 = mkldnn_u8; +} + +using scratchpad_mode_t = mkldnn_scratchpad_mode_t; +namespace scratchpad_mode { + const scratchpad_mode_t library = mkldnn_scratchpad_mode_library; + const scratchpad_mode_t user = mkldnn_scratchpad_mode_user; +} + +using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t; +namespace rnn_packed_format { + const rnn_packed_format_t undef = mkldnn_packed_format_undef; + const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p; + const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p; +} + +using format_kind_t = mkldnn_format_kind_t; +namespace format_kind { + const format_kind_t undef = mkldnn_format_kind_undef; + const format_kind_t any = mkldnn_format_kind_any; + const format_kind_t blocked = mkldnn_blocked; + const format_kind_t wino = mkldnn_format_kind_wino; + const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed; +} + +using format_tag_t = mkldnn_format_tag_t; +namespace format_tag { + const format_tag_t undef = mkldnn_format_tag_undef; + const format_tag_t any = mkldnn_format_tag_any; + const format_tag_t a = mkldnn_a; + const format_tag_t ab = mkldnn_ab; + const format_tag_t abc = mkldnn_abc; + const format_tag_t abcd = mkldnn_abcd; + const format_tag_t abcde = mkldnn_abcde; + const format_tag_t abcdef = mkldnn_abcdef; + const format_tag_t abdec = mkldnn_abdec; + const format_tag_t acb = mkldnn_acb; + const format_tag_t acbde = mkldnn_acbde; + const format_tag_t acdb = mkldnn_acdb; + const format_tag_t acdeb = mkldnn_acdeb; + const format_tag_t ba = mkldnn_ba; + const format_tag_t bac = mkldnn_bac; + const format_tag_t bacd = mkldnn_bacd; + const format_tag_t bcda = mkldnn_bcda; + const format_tag_t cba = mkldnn_cba; + const format_tag_t cdba = mkldnn_cdba; + const format_tag_t cdeba = mkldnn_cdeba; + const format_tag_t decab = mkldnn_decab; + const format_tag_t Abc16a = mkldnn_Abc16a; + const format_tag_t ABc16a16b = mkldnn_ABc16a16b; + const format_tag_t aBc16b = mkldnn_aBc16b; + const format_tag_t ABc16b16a = mkldnn_ABc16b16a; + const format_tag_t Abc4a = mkldnn_Abc4a; + const format_tag_t aBc4b = mkldnn_aBc4b; + const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b; + const format_tag_t ABc4b4a = mkldnn_ABc4b4a; + const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a; + const format_tag_t ABc8a8b = mkldnn_ABc8a8b; + const format_tag_t aBc8b = mkldnn_aBc8b; + const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b; + const format_tag_t ABc8b8a = mkldnn_ABc8b8a; + const format_tag_t Abcd16a = mkldnn_Abcd16a; + const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b; + const format_tag_t aBcd16b = mkldnn_aBcd16b; + const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a; + const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c; + const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b; + const format_tag_t Abcd4a = mkldnn_Abcd4a; + const format_tag_t aBcd4b = mkldnn_aBcd4b; + const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b; + const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a; + const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c; + const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b; + const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a; + const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b; + const format_tag_t aBcd8b = mkldnn_aBcd8b; + const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b; + const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b; + const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a; + const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c; + const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c; + const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b; + const format_tag_t Abcde16a = mkldnn_Abcde16a; + const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b; + const format_tag_t aBcde16b = mkldnn_aBcde16b; + const format_tag_t ABcde16b16a = mkldnn_ABcde16b16a; + const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c; + const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b; + const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c; + const format_tag_t Abcde4a = mkldnn_Abcde4a; + const format_tag_t aBcde4b = mkldnn_aBcde4b; + const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a; + const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c; + const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c; + const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b; + const format_tag_t Abcde8a = mkldnn_Abcde8a; + const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b; + const format_tag_t aBcde8b = mkldnn_aBcde8b; + const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b; + const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b; + const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a; + const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c; + const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c; + const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b; + const format_tag_t aBcdef16b = mkldnn_aBcdef16b; + const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c; + const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b; + const format_tag_t aBcdef4b = mkldnn_aBcdef4b; + const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b; + const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c; + const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c; + const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b; + const format_tag_t aBdc16b = mkldnn_aBdc16b; + const format_tag_t aBdc4b = mkldnn_aBdc4b; + const format_tag_t aBdc8b = mkldnn_aBdc8b; + const format_tag_t aBdec16b = mkldnn_aBdec16b; + const format_tag_t aBdec4b = mkldnn_aBdec4b; + const format_tag_t aBdec8b = mkldnn_aBdec8b; + const format_tag_t aBdefc16b = mkldnn_aBdefc16b; + const format_tag_t aBdefc4b = mkldnn_aBdefc4b; + const format_tag_t aBdefc8b = mkldnn_aBdefc8b; + const format_tag_t Acb16a = mkldnn_Acb16a; + const format_tag_t Acb4a = mkldnn_Acb4a; + const format_tag_t Acb8a = mkldnn_Acb8a; + const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c; + const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c; + const format_tag_t Acdb16a = mkldnn_Acdb16a; + const format_tag_t Acdb4a = mkldnn_Acdb4a; + const format_tag_t Acdb8a = mkldnn_Acdb8a; + const format_tag_t Acdeb16a = mkldnn_Acdeb16a; + const format_tag_t Acdeb4a = mkldnn_Acdeb4a; + const format_tag_t Acdeb8a = mkldnn_Acdeb8a; + const format_tag_t BAc16a16b = mkldnn_BAc16a16b; + const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b; + const format_tag_t last = mkldnn_format_tag_last; + + const format_tag_t x = mkldnn_x; + const format_tag_t nc = mkldnn_nc; + const format_tag_t cn = mkldnn_cn; + const format_tag_t ncw = mkldnn_ncw; + const format_tag_t nwc = mkldnn_nwc; + const format_tag_t nchw = mkldnn_nchw; + const format_tag_t nhwc = mkldnn_nhwc; + const format_tag_t chwn = mkldnn_chwn; + const format_tag_t ncdhw = mkldnn_ncdhw; + const format_tag_t ndhwc = mkldnn_ndhwc; + const format_tag_t oi = mkldnn_oi; + const format_tag_t io = mkldnn_io; + const format_tag_t oiw = mkldnn_oiw; + const format_tag_t wio = mkldnn_wio; + const format_tag_t oihw = mkldnn_oihw; + const format_tag_t hwio = mkldnn_hwio; + const format_tag_t ihwo = mkldnn_ihwo; + const format_tag_t iohw = mkldnn_iohw; + const format_tag_t oidhw = mkldnn_oidhw; + const format_tag_t dhwio = mkldnn_dhwio; + const format_tag_t goiw = mkldnn_goiw; + const format_tag_t goihw = mkldnn_goihw; + const format_tag_t hwigo = mkldnn_hwigo; + const format_tag_t giohw = mkldnn_giohw; + const format_tag_t goidhw = mkldnn_goidhw; + const format_tag_t tnc = mkldnn_tnc; + const format_tag_t ntc = mkldnn_ntc; + const format_tag_t ldsnc = mkldnn_ldsnc; + const format_tag_t ldigo = mkldnn_ldigo; + const format_tag_t ldgoi = mkldnn_ldgoi; + const format_tag_t ldgo = mkldnn_ldgo; + const format_tag_t nCdhw16c = mkldnn_nCdhw16c; + const format_tag_t nCdhw4c = mkldnn_nCdhw4c; + const format_tag_t nCdhw8c = mkldnn_nCdhw8c; + const format_tag_t nChw16c = mkldnn_nChw16c; + const format_tag_t nChw4c = mkldnn_nChw4c; + const format_tag_t nChw8c = mkldnn_nChw8c; + const format_tag_t nCw16c = mkldnn_nCw16c; + const format_tag_t nCw4c = mkldnn_nCw4c; + const format_tag_t nCw8c = mkldnn_nCw8c; + const format_tag_t IOw16o16i = mkldnn_IOw16o16i; + const format_tag_t OIw16i16o = mkldnn_OIw16i16o; + const format_tag_t OIw16o16i = mkldnn_OIw16o16i; + const format_tag_t Oiw16o = mkldnn_Oiw16o; + const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i; + const format_tag_t OIw4i4o = mkldnn_OIw4i4o; + const format_tag_t Oiw4o = mkldnn_Oiw4o; + const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i; + const format_tag_t OIw8i8o = mkldnn_OIw8i8o; + const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o; + const format_tag_t OIw8o8i = mkldnn_OIw8o8i; + const format_tag_t Owi16o = mkldnn_Owi16o; + const format_tag_t Owi4o = mkldnn_Owi4o; + const format_tag_t Owi8o = mkldnn_Owi8o; + const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i; + const format_tag_t Ohwi16o = mkldnn_Ohwi16o; + const format_tag_t Ohwi4o = mkldnn_Ohwi4o; + const format_tag_t Ohwi8o = mkldnn_Ohwi8o; + const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o; + const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i; + const format_tag_t Oihw16o = mkldnn_Oihw16o; + const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i; + const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o; + const format_tag_t Oihw4o = mkldnn_Oihw4o; + const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i; + const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o; + const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o; + const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i; + const format_tag_t Odhwi16o = mkldnn_Odhwi16o; + const format_tag_t Odhwi4o = mkldnn_Odhwi4o; + const format_tag_t Odhwi8o = mkldnn_Odhwi8o; + const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o; + const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i; + const format_tag_t Oidhw16o = mkldnn_Oidhw16o; + const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o; + const format_tag_t Oidhw4o = mkldnn_Oidhw4o; + const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i; + const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o; + const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i; + const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i; + const format_tag_t Goiw16g = mkldnn_Goiw16g; + const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o; + const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i; + const format_tag_t gOiw16o = mkldnn_gOiw16o; + const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i; + const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o; + const format_tag_t gOiw4o = mkldnn_gOiw4o; + const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i; + const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o; + const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o; + const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i; + const format_tag_t gOwi16o = mkldnn_gOwi16o; + const format_tag_t gOwi4o = mkldnn_gOwi4o; + const format_tag_t gOwi8o = mkldnn_gOwi8o; + const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i; + const format_tag_t gOhwi16o = mkldnn_gOhwi16o; + const format_tag_t gOhwi4o = mkldnn_gOhwi4o; + const format_tag_t gOhwi8o = mkldnn_gOhwi8o; + const format_tag_t Goihw16g = mkldnn_Goihw16g; + const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o; + const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i; + const format_tag_t gOihw16o = mkldnn_gOihw16o; + const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i; + const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i; + const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o; + const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i; + const format_tag_t gOihw4o = mkldnn_gOihw4o; + const format_tag_t Goihw8g = mkldnn_Goihw8g; + const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i; + const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o; + const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o; + const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i; + const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o; + const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o; + const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o; + const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o; + const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i; + const format_tag_t gOidhw16o = mkldnn_gOidhw16o; + const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o; + const format_tag_t gOidhw4o = mkldnn_gOidhw4o; + const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i; + const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o; + const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i; +} + +using memory_extra_flags_t = mkldnn_memory_extra_flags_t; +namespace memory_extra_flags { + const memory_extra_flags_t none = mkldnn_memory_extra_flag_none; + const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8; + const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust; +} + +using padding_kind_t = mkldnn_padding_kind_t; +namespace padding_kind { + const padding_kind_t padding_zero = mkldnn_padding_zero; +} + +using engine_kind_t = mkldnn_engine_kind_t; +namespace engine_kind { + const engine_kind_t any_engine = mkldnn_any_engine; + const engine_kind_t cpu = mkldnn_cpu; +} + +using primitive_kind_t = mkldnn_primitive_kind_t; +namespace primitive_kind { + const primitive_kind_t undefined = mkldnn_undefined_primitive; + const primitive_kind_t reorder = mkldnn_reorder; + const primitive_kind_t concat = mkldnn_concat; + const primitive_kind_t sum = mkldnn_sum; + const primitive_kind_t convolution = mkldnn_convolution; + const primitive_kind_t deconvolution = mkldnn_deconvolution; + const primitive_kind_t shuffle = mkldnn_shuffle; + const primitive_kind_t eltwise = mkldnn_eltwise; + const primitive_kind_t softmax = mkldnn_softmax; + const primitive_kind_t pooling = mkldnn_pooling; + const primitive_kind_t lrn = mkldnn_lrn; + const primitive_kind_t batch_normalization = mkldnn_batch_normalization; + const primitive_kind_t inner_product = mkldnn_inner_product; + const primitive_kind_t rnn = mkldnn_rnn; +} + +using query_t = mkldnn_query_t; +namespace query { + const query_t undef = mkldnn_query_undef; + + const query_t engine = mkldnn_query_engine; + const query_t primitive_kind = mkldnn_query_primitive_kind; + + const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32; + const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32; + + const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64; + const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64; + + const query_t scratchpad_engine = mkldnn_query_scratchpad_engine; + + const query_t impl_info_str = mkldnn_query_impl_info_str; + + const query_t some_d = mkldnn_query_some_d; + const query_t op_d = mkldnn_query_op_d; + const query_t convolution_d = mkldnn_query_convolution_d; + const query_t deconvolution_d = mkldnn_query_deconvolution_d; + const query_t shuffle_d = mkldnn_query_shuffle_d; + const query_t eltwise_d = mkldnn_query_eltwise_d; + const query_t softmax_d = mkldnn_query_softmax_d; + const query_t pooling_d = mkldnn_query_pooling_d; + const query_t lrn_d = mkldnn_query_lrn_d; + const query_t batch_normalization_d = mkldnn_query_batch_normalization_d; + const query_t inner_product_d = mkldnn_query_inner_product_d; + const query_t rnn_d = mkldnn_query_rnn_d; + + const query_t some_md = mkldnn_query_some_md; + const query_t src_md = mkldnn_query_src_md; + const query_t diff_src_md = mkldnn_query_diff_src_md; + const query_t weights_md = mkldnn_query_weights_md; + const query_t diff_weights_md = mkldnn_query_diff_weights_md; + const query_t dst_md = mkldnn_query_dst_md; + const query_t diff_dst_md = mkldnn_query_diff_dst_md; + + const query_t workspace_md = mkldnn_query_workspace_md; + const query_t scratchpad_md = mkldnn_query_scratchpad_md; +} + +using blocking_desc_t = mkldnn_blocking_desc_t; +using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t; +using wino_desc_t = mkldnn_wino_desc_t; +using memory_extra_desc_t = mkldnn_memory_extra_desc_t; +using memory_desc_t = mkldnn_memory_desc_t; +using convolution_desc_t = mkldnn_convolution_desc_t; +using deconvolution_desc_t = mkldnn_deconvolution_desc_t; +using shuffle_desc_t = mkldnn_shuffle_desc_t; +using pooling_desc_t = mkldnn_pooling_desc_t; +using eltwise_desc_t = mkldnn_eltwise_desc_t; +using softmax_desc_t = mkldnn_softmax_desc_t; +using lrn_desc_t = mkldnn_lrn_desc_t; +using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t; +using inner_product_desc_t = mkldnn_inner_product_desc_t; + +using rnn_direction_t = mkldnn_rnn_direction_t; +using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t; +using rnn_desc_t = mkldnn_rnn_desc_t; + +/* C op_desc_t, which eventually are just (void*) */ +using c_op_desc_t = mkldnn_op_desc_t; +using const_c_op_desc_t = const_mkldnn_op_desc_t; + +struct op_desc_t { + union { + primitive_kind_t kind; + convolution_desc_t convolution; + deconvolution_desc_t deconvolution; + shuffle_desc_t shuffle; + pooling_desc_t pooling; + eltwise_desc_t eltwise; + softmax_desc_t softmax; + lrn_desc_t lrn; + batch_normalization_desc_t batch_normalization; + inner_product_desc_t inner_product; + rnn_desc_t rnn; + }; + + op_desc_t(const primitive_kind_t &_): kind(_) {} + +# define DECL_CTOR_AND_CONVERTERS(c_type, name) \ + op_desc_t(const c_type &_): name(_) {} \ + static op_desc_t *convert_from_c(c_type *_) \ + { return reinterpret_cast(_); } \ + static const op_desc_t *convert_from_c(const c_type *_) \ + { return reinterpret_cast(_); } + + DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution); + DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle); + DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling); + DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise); + DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax); + DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn); + DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization); + DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product); + DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn); + +# undef DECL_CTOR_AND_CONVERTERS +}; + +using engine_t = mkldnn_engine; +using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator; +using primitive_desc_t = mkldnn_primitive_desc; +using primitive_attr_t = mkldnn_primitive_attr; +using post_ops_t = mkldnn_post_ops; +using memory_t = mkldnn_memory; +using primitive_t = mkldnn_primitive; + +using primitive_arg_index_t = int; + +using stream_flags_t = mkldnn_stream_flags_t; +namespace stream_flags { + const stream_flags_t default_flags = mkldnn_stream_default_flags; +} +using stream_t = mkldnn_stream; + +/* forward declaration of the internal primitive_desc types */ +struct batch_normalization_bwd_pd_t; +struct batch_normalization_fwd_pd_t; +struct batch_normalization_pd_t; +struct concat_pd_t; +struct convolution_bwd_data_pd_t; +struct convolution_bwd_weights_pd_t; +struct convolution_fwd_pd_t; +struct convolution_pd_t; +struct deconvolution_bwd_data_pd_t; +struct deconvolution_bwd_weights_pd_t; +struct deconvolution_fwd_pd_t; +struct deconvolution_pd_t; +struct eltwise_bwd_pd_t; +struct eltwise_fwd_pd_t; +struct eltwise_pd_t; +struct inner_product_bwd_data_pd_t; +struct inner_product_bwd_weights_pd_t; +struct inner_product_fwd_pd_t; +struct inner_product_pd_t; +struct lrn_bwd_pd_t; +struct lrn_fwd_pd_t; +struct lrn_pd_t; +struct pooling_bwd_pd_t; +struct pooling_fwd_pd_t; +struct pooling_pd_t; +struct reorder_pd_t; +struct rnn_bwd_pd_t; +struct rnn_fwd_pd_t; +struct rnn_pd_t; +struct shuffle_pd_t; +struct softmax_bwd_pd_t; +struct softmax_fwd_pd_t; +struct softmax_pd_t; +struct sum_pd_t; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat.cpp b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp new file mode 100644 index 0000000000..ed4c35c6e9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp @@ -0,0 +1,86 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "concat_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd, + const memory_desc_t *dst_md, int n, int concat_dim, + const memory_desc_t *src_mds, + const primitive_attr_t *attr, + engine_t *engine) { + bool args_ok = !any_null(concat_pd, src_mds) && n > 0; + if (!args_ok) return invalid_arguments; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + const int ndims = src_mds[0].ndims; + const dims_t &dims = src_mds[0].dims; + const data_type_t dt = src_mds[0].data_type; + + int concat_dim_sz = dims[concat_dim]; + for (int i = 1; i < n; ++i) { + if (src_mds[i].ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (d == concat_dim) continue; + if (src_mds[i].dims[d] != dims[d]) + return invalid_arguments; + } + if (src_mds[i].data_type != dt) return invalid_arguments; + concat_dim_sz += src_mds[i].dims[concat_dim]; + } + + memory_desc_t dummy_dst_md; + if (dst_md) { + if (dst_md->ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (dst_md->dims[d] != + (d == concat_dim ? concat_dim_sz : dims[d])) + return invalid_arguments; + } + } else { + dummy_dst_md = src_mds[0]; + dummy_dst_md.dims[concat_dim] = concat_dim_sz; + dummy_dst_md.format_kind = format_kind::any; + dst_md = &dummy_dst_md; + } + + auto c_pd = reinterpret_cast(concat_pd); + + for (auto c = engine->get_concat_implementation_list(); *c; ++c) { + if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds) + == success) { + (*c_pd)->init_info(); + (*c_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp new file mode 100644 index 0000000000..29311927e2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp @@ -0,0 +1,211 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CONCAT_PD_HPP +#define CONCAT_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct concat_pd_t: public primitive_desc_t { + concat_pd_t(engine_t *engine, const primitive_attr_t *attr, + const memory_desc_t *dst_md, int n, int concat_dim, + const memory_desc_t *src_mds) + : primitive_desc_t(engine, attr, primitive_kind::concat) + , n_(n), concat_dim_(concat_dim), dst_md_(*dst_md) + { + src_mds_.reserve(n_); + for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); + } + + concat_pd_t(const concat_pd_t &rhs) = default; + + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg >= MKLDNN_ARG_MULTIPLE_SRC + && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index < n_inputs() ? &src_mds_[index] : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return n_; } + virtual int n_outputs() const override { return 1; } + + int concat_dim() const { return concat_dim_; } + + const memory_desc_t *src_image_md(int index = 0) const + { return index < n_inputs() ? &src_image_mds_[index] : nullptr; } + +protected: + int n_, concat_dim_; + memory_desc_t dst_md_; + nstl::vector src_mds_; + + /* contains images of srcs in the dst memory (if possible) + * Lives here to simplify some implementations. An implementation might + * use this auxiliary array iff init() returned success */ + nstl::vector src_image_mds_; + +protected: + /* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */ + status_t init() { + bool ok = true + && set_default_params() == status::success + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + if (!i_d.is_blocking_desc() || i_d.is_additional_buffer()) + return status::unimplemented; + } + + const int ndims = dst_md_.ndims; + int current_concat_dim_offset = 0; + for (int i = 0; i < n_; ++i) { + const int dim = src_mds_[i].dims[concat_dim_]; + dims_t dims, offsets = {}; + utils::array_copy(dims, dst_md_.dims, ndims); + dims[concat_dim_] = dim; + offsets[concat_dim_] = current_concat_dim_offset; + + memory_desc_t src_img_d; + status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, + &dst_md_, dims, offsets); + if (status != status::success) return status; + src_image_mds_.push_back(src_img_d); + current_concat_dim_offset += dim; + } + + return status::success; + } + + status_t set_default_params() { + if (dst_md_.format_kind != format_kind::any) + return status::success; + + const int ndims = dst_md_.ndims; + + /* The stupidest ever heuristics (but not the same as we had before): + * - Pick the first non-plain format; + * - If all formats are plain or it is not possible to create a + * blocked format for the output, pick the format of the plain input + * - If this fails as well, use plain layout (abcd...) + */ + status_t status = status::unimplemented; + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (src_d.is_blocking_desc() && !src_d.is_plain()) { + status = memory_desc_init_by_blocking_desc(dst_md_, + src_d.blocking_desc()); + if (status == status::success) break; + } + } + + if (status == status::success) { + /* check if we can create a sub-memory for the dst */ + bool desired_format_ok = true; + int current_concat_dim_offset = 0; + for (int i = 0; i < n_; ++i) { + const int dim = src_mds_[i].dims[concat_dim_]; + dims_t dims, offsets = {}; + utils::array_copy(dims, dst_md_.dims, ndims); + dims[concat_dim_] = dim; + offsets[concat_dim_] = current_concat_dim_offset; + + memory_desc_t src_img_d; + status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, + &dst_md_, dims, offsets); + if (status != status::success) { + desired_format_ok = false; + break; + } + current_concat_dim_offset += dim; + } + + if (!desired_format_ok) + status = status::unimplemented; + } + + /* if no success so far, try using the format of the first plain input */ + if (status != status::success) { + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (src_d.is_blocking_desc() && src_d.is_plain()) { + status = memory_desc_init_by_blocking_desc(dst_md_, + memory_desc_wrapper(src_mds_[0]).blocking_desc()); + if (status == status::success) return status; + } + } + } + + /* the last line of defense: use plain abcd... format */ + if (status != status::success) + status = memory_desc_init_by_strides(dst_md_, nullptr); + + return status; + } +}; + +#define DECLARE_CONCAT_PD_t(impl_name, ...) \ + static status_t create(concat_pd_t **concat_pd, \ + engine_t *engine, const primitive_attr_t *attr, \ + const memory_desc_t *dst_md, int n, int concat_dim, \ + const memory_desc_t *src_mds) { \ + using namespace status; \ + auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \ + if (_pd == nullptr) return out_of_memory; \ + if (_pd->init() != success) { delete _pd; return unimplemented; } \ + return safe_ptr_assign(*concat_pd, _pd); \ + } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual const char *name() const override { return impl_name; } \ + +#define DECLARE_CONCAT_PD_T(impl_name, ...) \ + DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__) + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp new file mode 100644 index 0000000000..0c5c02bcd1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp @@ -0,0 +1,200 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace mkldnn { +namespace impl { +status_t conv_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides, + padding_l) + && one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) return invalid_arguments; + + if (padding_r == nullptr) padding_r = padding_l; + + auto cd = convolution_desc_t(); + cd.primitive_kind = primitive_kind::convolution; + cd.prop_kind = prop_kind; + cd.alg_kind = alg_kind; + + cd.diff_src_desc = cd.src_desc = zero_md(); + cd.diff_dst_desc = cd.dst_desc = zero_md(); + cd.diff_weights_desc = cd.weights_desc = zero_md(); + cd.diff_bias_desc = cd.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias = + bias_desc && bias_desc->format_kind != format_kind::undef; + const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; + + (prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc; + (is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) = + *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) = + *bias_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(cd.strides, strides, sp_dims); + utils::array_copy(cd.padding[0], padding_l, sp_dims); + utils::array_copy(cd.padding[1], padding_r, sp_dims); + if (dilates) + utils::array_copy(cd.dilates, dilates, sp_dims); + else + utils::array_set(cd.dilates, 0, sp_dims); + + cd.padding_kind = padding_kind; + cd.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + const int g = with_groups ? weights_desc->dims[0] : 1; + const int bias_dim = prop_kind == backward_data + ? src_desc->dims[1] + : dst_desc->dims[1]; + + bool consistency = true + && memory_desc_wrapper(weights_desc).nelems() + && src_desc->ndims == dst_desc->ndims + && utils::one_of(src_desc->ndims, 3, 4, 5) + && utils::one_of(weights_desc->ndims, src_desc->ndims, + src_desc->ndims + 1) + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == bias_dim : true) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] + && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; + for (int i = 2; i < src_desc->ndims; ++i) + { + int src = src_desc->dims[i]; + int ker = weights_desc->dims[with_groups + i]; + int dil = cd.dilates[i - 2]; + int pad_l = padding_l[i - 2]; + int pad_r = padding_r[i - 2]; + int str = strides[i - 2]; + int dst = dst_desc->dims[i]; + int ker_range = 1 + (ker - 1) * (dil + 1); + + if (str < 1) return invalid_arguments; + consistency = consistency + && dil >= 0 + && pad_l >= 0 + && pad_r + str > 0 + && (src - ker_range + pad_l + pad_r) / str + 1 == dst; + } + if (!consistency) return invalid_arguments; + + *conv_desc = cd; + return success; +} +} +} + +status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_forward_desc_init( + convolution_desc_t *conv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_convolution_backward_data_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_backward_data_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_convolution_backward_weights_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, + nullptr, padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_backward_weights_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, + dilates, padding_l, padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp new file mode 100644 index 0000000000..9604e0acf5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "utils.hpp" + +#include "convolution_pd.hpp" + +namespace mkldnn { +namespace impl { + +using namespace prop_kind; + +memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_data + ? &desc->diff_src_desc : &desc->src_desc; +} + +memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_weights_desc : &desc->weights_desc; +} + +memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_bias_desc : &desc->bias_desc; +} + +memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) { + return utils::one_of(desc->prop_kind, forward_inference, forward_training) + ? &desc->dst_desc : &desc->diff_dst_desc; +} + +const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_src_d(const_cast(desc)); } +const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_wei_d(const_cast(desc)); } +const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_bia_d(const_cast(desc)); } +const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_dst_d(const_cast(desc)); } + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp new file mode 100644 index 0000000000..b10c36db49 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp @@ -0,0 +1,348 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CONVOLUTION_PD_HPP +#define CONVOLUTION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +status_t conv_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind); + +memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc); + +struct convolution_fwd_pd_t; + +struct convolution_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::convolution; + + convolution_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const convolution_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case pkind_traits::query_d: + *(const convolution_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common conv aux functions */ + + dim_t MB() const { return _src_md()->dims[0]; } + + dim_t IC() const { return _src_md()->dims[1]; } + dim_t OC() const { return _dst_md()->dims[1]; } + dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; } + + dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; } + dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; } + dim_t IW() const { return _src_md()->dims[ndims() - 1]; } + + dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; } + dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; } + dim_t OW() const { return _dst_md()->dims[ndims() - 1]; } + + dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; } + dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; } + dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } + dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } + dim_t KDW() const { return desc_.dilates[ndims() - 3]; } + + dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + int ndims() const { return _src_md()->ndims; } + + bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); } + bool with_groups() const { return _wei_md()->ndims == ndims() + 1; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*_src_md()); + const auto d_d = memory_desc_wrapper(*_dst_md()); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + +protected: + convolution_desc_t desc_; + const convolution_fwd_pd_t *hint_fwd_pd_; + + bool set_default_formats_common_template( + memory_desc_t &src_md, format_tag_t src_tag, + memory_desc_t &wei_md, format_tag_t wei_tag, + memory_desc_t &dst_md, format_tag_t dst_tag, + memory_desc_t &bia_md) { + using namespace format_tag; + +# define IS_OK(f) \ + do { if ((f) != status::success) return false; } while(0) + if (src_md.format_kind == format_kind::any + && !utils::one_of(src_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(src_md, src_tag)); + if (dst_md.format_kind == format_kind::any + && !utils::one_of(dst_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(dst_md, dst_tag)); + if (wei_md.format_kind == format_kind::any + && !utils::one_of(wei_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(wei_md, wei_tag)); + if (with_bias() && bia_md.format_kind == format_kind::any) + IS_OK(memory_desc_init_by_tag(bia_md, x)); +# undef IS_OK + + return true; + } + + bool set_default_alg_kind(alg_kind_t alg_kind) { + assert(utils::one_of(alg_kind, alg_kind::convolution_direct, + alg_kind::convolution_winograd)); + if (desc_.alg_kind == alg_kind::convolution_auto) + desc_.alg_kind = alg_kind; + return desc_.alg_kind == alg_kind; + } + + bool expect_data_types(data_type_t src_dt, data_type_t wei_dt, + data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const { + bool ok = true + && (src_dt == data_type::undef || _src_md()->data_type == src_dt) + && (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt) + && (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt) + && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt); + if (with_bias() && bia_dt != data_type::undef) + ok = ok && _bia_md()->data_type == bia_dt; + return ok; + } + +private: + const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); } + const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); } + const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); } + const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); } +}; + +struct convolution_fwd_pd_t: public convolution_pd_t { + typedef convolution_fwd_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_fwd_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; + + bool set_default_formats_common(format_tag_t src_tag, + format_tag_t wei_tag, format_tag_t dst_tag) { + return set_default_formats_common_template(src_md_, src_tag, + weights_md_, wei_tag, dst_md_, dst_tag, bias_md_); + } +}; + +struct convolution_bwd_data_pd_t: public convolution_pd_t { + typedef convolution_bwd_data_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_bwd_data_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + + virtual bool support_bias() const { return false; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t diff_dst_md_; + + bool set_default_formats_common(format_tag_t diff_src_tag, + format_tag_t wei_tag, format_tag_t diff_dst_tag) { + return set_default_formats_common_template(diff_src_md_, diff_src_tag, + weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_); + } +}; + +struct convolution_bwd_weights_pd_t: public convolution_pd_t { + typedef convolution_bwd_weights_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_bwd_weights_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; + + bool set_default_formats_common(format_tag_t src_tag, + format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) { + return set_default_formats_common_template(src_md_, src_tag, + diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag, + diff_bias_md_); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp new file mode 100644 index 0000000000..98063c1c37 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp @@ -0,0 +1,188 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t deconv_desc_init(deconvolution_desc_t *deconv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides, + padding_l) + && one_of(alg_kind, deconvolution_direct, deconvolution_winograd) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) + return invalid_arguments; + + if (padding_r == nullptr) + padding_r = padding_l; + + auto dd = deconvolution_desc_t(); + dd.primitive_kind = primitive_kind::deconvolution; + dd.prop_kind = prop_kind; + dd.alg_kind = alg_kind; + + dd.diff_src_desc = dd.src_desc = zero_md(); + dd.diff_dst_desc = dd.dst_desc = zero_md(); + dd.diff_weights_desc = dd.weights_desc = zero_md(); + dd.diff_bias_desc = dd.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias + = bias_desc && bias_desc->format_kind != format_kind::undef; + const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; + + (prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc; + (is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc) + = *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc) + = *bias_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(dd.strides, strides, sp_dims); + utils::array_copy(dd.padding[0], padding_l, sp_dims); + utils::array_copy(dd.padding[1], padding_r, sp_dims); + if (dilates) + utils::array_copy(dd.dilates, dilates, sp_dims); + else + utils::array_set(dd.dilates, 0, sp_dims); + + dd.padding_kind = padding_kind; + dd.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + const int g = with_groups ? weights_desc->dims[0] : 1; + bool consistency = true + && src_desc->ndims == dst_desc->ndims + && utils::one_of(src_desc->ndims, 3, 4, 5) + && utils::one_of(weights_desc->ndims, src_desc->ndims, + src_desc->ndims + 1) + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] + && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; + for (int i = 2; i < src_desc->ndims; ++i) { + int src = src_desc->dims[i]; + int ker = weights_desc->dims[with_groups + i]; + int dil = dd.dilates[i - 2]; + int pad = padding_l[i - 2] + padding_r[i - 2]; + int str = strides[i - 2]; + int dst = dst_desc->dims[i]; + int ker_range = 1 + (ker - 1) * (dil + 1); + + consistency + = consistency && (dst - ker_range + pad) / str + 1 == src; + } + if (!consistency) + return invalid_arguments; + + *deconv_desc = dd; + return success; +} +} + +status_t mkldnn_deconvolution_forward_desc_init( + deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_forward_desc_init( + deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, dilates, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_deconvolution_backward_data_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_backward_data_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides,dilates, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_deconvolution_backward_weights_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, + const dims_t strides, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_backward_weights_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, + const dims_t strides, const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp new file mode 100644 index 0000000000..539e44bd9b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp @@ -0,0 +1,293 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DECONVOLUTION_PD_HPP +#define DECONVOLUTION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "convolution_pd.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct deconvolution_fwd_pd_t; + +struct deconvolution_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::deconvolution; + + deconvolution_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const deconvolution_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case pkind_traits::query_d: + *(const deconvolution_desc_t **)result = desc(); + break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */ + + dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; } + + dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; } + dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; } + dim_t G() const + { return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; } + + dim_t ID() const { + return ndims() >= 5 + ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t IH() const { + return ndims() >= 4 + ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t IW() const { + return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1]; + } + + dim_t OD() const { + return ndims() >= 5 + ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t OH() const { + return ndims() >= 4 + ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t OW() const { + return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1]; + } + + dim_t KD() const { + const int w_ndims = ndims() + with_groups(); + return ndims() >= 5 + ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1; + } + dim_t KH() const { + const int w_ndims = ndims() + with_groups(); + return ndims() >= 4 + ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1; + } + dim_t KW() const { + const int w_ndims = ndims() + with_groups(); + return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1]; + } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } + dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } + dim_t KDW() const { return desc_.dilates[ndims() - 3]; } + + dim_t padFront() const + { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const + { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const + { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const + { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + bool with_bias() const { + return + !memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero(); + } + + bool with_groups() const + { return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; } + + int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_)); + const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_)); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + +protected: + deconvolution_desc_t desc_; + const deconvolution_fwd_pd_t *hint_fwd_pd_; +}; + +struct deconvolution_fwd_pd_t: public deconvolution_pd_t { + typedef deconvolution_fwd_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_fwd_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; +}; + +struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t { + typedef deconvolution_bwd_data_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_bwd_data_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &weights_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t diff_dst_md_; +}; + +struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t { + typedef deconvolution_bwd_weights_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_bwd_weights_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp new file mode 100644 index 0000000000..f1708fca52 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp @@ -0,0 +1,84 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, float alpha, float beta) { + bool args_ok = true + && !any_null(eltwise_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward_data) + && one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic) + && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto ed = eltwise_desc_t(); + ed.primitive_kind = primitive_kind::eltwise; + ed.prop_kind = prop_kind; + ed.alg_kind = alg_kind; + + ed.data_desc = *data_desc; + ed.diff_data_desc = + (ed.prop_kind == backward_data) ? *diff_data_desc : zero_md(); + + ed.alpha = alpha; + ed.beta = beta; + + bool consistency = true + && IMPLICATION(ed.prop_kind == backward_data, + array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims, + ed.diff_data_desc.ndims)); + if (!consistency) return invalid_arguments; + + *eltwise_desc = ed; + return success; +} +} + +status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, float alpha, float beta) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc, + nullptr, alpha, beta); +} + +status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc, + alg_kind_t alg_kind, const memory_desc_t *diff_data_desc, + const memory_desc_t *data_desc, float alpha, float beta) { + return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc, + diff_data_desc, alpha, beta); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp new file mode 100644 index 0000000000..9fd260fcee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp @@ -0,0 +1,161 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ELTWISE_PD_HPP +#define ELTWISE_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct eltwise_fwd_pd_t; + +struct eltwise_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::eltwise; + + eltwise_pd_t(mkldnn::impl::engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const eltwise_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::eltwise_d: + *(const eltwise_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common eltwise aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + +protected: + eltwise_desc_t desc_; + const eltwise_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct eltwise_fwd_pd_t: public eltwise_pd_t { + typedef eltwise_fwd_pd_t base_class; + typedef eltwise_fwd_pd_t hint_class; + + eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + bool is_zero_preserved() const + { return math::eltwise_fwd_preserves_zero(desc_.alg_kind); } +}; + +struct eltwise_bwd_pd_t: public eltwise_pd_t { + typedef eltwise_bwd_pd_t base_class; + typedef eltwise_fwd_pd_t hint_class; + + eltwise_bwd_pd_t(engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + + bool is_zero_preserved() const { return true; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.cpp b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp new file mode 100644 index 0000000000..3b3e25456d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp @@ -0,0 +1,75 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" +#include "engine.hpp" +#include "nstl.hpp" + +#include "c_types_map.hpp" +#include "../cpu/cpu_engine.hpp" + +namespace mkldnn { +namespace impl { + +engine_factory_t *engine_factories[] = { + &cpu::engine_factory, + nullptr, +}; + +static inline engine_factory_t *get_engine_factory(engine_kind_t kind) { + for (engine_factory_t **ef = engine_factories; *ef; ef++) + if ((*ef)->kind() == kind) + return *ef; + return nullptr; +} + +} +} + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +size_t mkldnn_engine_get_count(engine_kind_t kind) { + engine_factory_t *ef = get_engine_factory(kind); + return ef != nullptr ? ef->count() : 0; +} + +status_t mkldnn_engine_create(engine_t **engine, + engine_kind_t kind, size_t index) { + if (engine == nullptr) + return invalid_arguments; + + engine_factory_t *ef = get_engine_factory(kind); + if (ef == nullptr || index >= ef->count()) + return invalid_arguments; + + return ef->engine_create(engine, index); +} + +status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) { + if (engine == nullptr) + return invalid_arguments; + *kind = engine->kind(); + return success; +} + +status_t mkldnn_engine_destroy(engine_t *engine) { + /* TODO: engine->dec_ref_count(); */ + delete engine; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.hpp b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp new file mode 100644 index 0000000000..8ac8a29de5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp @@ -0,0 +1,119 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ENGINE_HPP +#define ENGINE_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive.hpp" +#include "utils.hpp" + +/** \brief An abstraction of an execution unit with shared resources + * + * Responsibilities: + * - Provide engine specific memory allocation + * - Provide engine specific primitive_desc_t creators + */ +struct mkldnn_engine: public mkldnn::impl::c_compatible { + mkldnn_engine(mkldnn::impl::engine_kind_t kind) + : kind_(kind) + {} + virtual ~mkldnn_engine() {} + + /** get kind of the current engine */ + virtual mkldnn::impl::engine_kind_t kind() const { return kind_; } + + /** allocate memory */ + virtual mkldnn::impl::status_t memory_create( + mkldnn::impl::memory_t **memory, + const mkldnn::impl::memory_desc_t *md, + void *handle) = 0; + + /** implementation section (typedefs) */ + + // TODO: remove engine? + typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)( + mkldnn::impl::reorder_pd_t **reorder_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *src_engine, + const mkldnn::impl::memory_desc_t *src_md, + mkldnn::impl::engine_t *dst_engine, + const mkldnn::impl::memory_desc_t *dst_md); + + typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)( + mkldnn::impl::concat_pd_t **concat_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + const mkldnn::impl::memory_desc_t *dst_md, + int n, int concat_dim, + const mkldnn::impl::memory_desc_t *src_mds); + + typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)( + mkldnn::impl::sum_pd_t **sum_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + const mkldnn::impl::memory_desc_t *dst_md, + int n, const float *scales, + const mkldnn::impl::memory_desc_t *src_mds); + + typedef mkldnn::impl::status_t (*primitive_desc_create_f)( + mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *); + + /* implementation section */ + + /** return the list of reorder implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const reorder_primitive_desc_create_f* + get_reorder_implementation_list() const = 0; + + /** return the list of concat implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const concat_primitive_desc_create_f* + get_concat_implementation_list() const = 0; + + /** return the list of sum implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const sum_primitive_desc_create_f* + get_sum_implementation_list() const = 0; + + /** return the list of implementations. engine guarantees to return a + * NULL-terminated list */ + virtual const primitive_desc_create_f* get_implementation_list() const = 0; + +protected: + mkldnn::impl::engine_kind_t kind_; +}; + +namespace mkldnn { +namespace impl { + +struct engine_factory_t: public c_compatible { + virtual size_t count() const = 0; + virtual engine_kind_t kind() const = 0; + virtual status_t engine_create(engine_t **engine, size_t index) const = 0; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp new file mode 100644 index 0000000000..5a9f58cb1e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp @@ -0,0 +1,106 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) { + bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc); + if (!args_ok) return invalid_arguments; + + auto id = inner_product_desc_t(); + id.primitive_kind = primitive_kind::inner_product; + id.prop_kind = prop_kind; + + id.diff_src_desc = id.src_desc = zero_md(); + id.diff_dst_desc = id.dst_desc = zero_md(); + id.diff_weights_desc = id.weights_desc = zero_md(); + id.diff_bias_desc = id.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias = + bias_desc && bias_desc->format_kind != format_kind::undef; + + (prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc; + (is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) = + *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) = + *bias_desc; + + id.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + bool consistency = true + && memory_desc_wrapper(weights_desc).nelems() + && one_of(src_desc->ndims, 2, 3, 4, 5) + && dst_desc->ndims == 2 + && weights_desc->ndims == src_desc->ndims + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true) + && src_desc->dims[0] == dst_desc->dims[0] + && array_cmp(&src_desc->dims[1], &weights_desc->dims[1], + src_desc->ndims - 1) + && dst_desc->dims[1] == weights_desc->dims[0]; + if (!consistency) return invalid_arguments; + + *ip_desc = id; + return success; +} +} + +status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc, + prop_kind_t prop_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc, + dst_desc); +} + +status_t mkldnn_inner_product_backward_data_desc_init( + inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc) +{ + return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc, + nullptr, diff_dst_desc); +} + +status_t mkldnn_inner_product_backward_weights_desc_init( + inner_product_desc_t *ip_desc, const memory_desc_t *src_desc, + const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc) { + return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc, + diff_bias_desc, diff_dst_desc); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp new file mode 100644 index 0000000000..091cf0f5d6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "utils.hpp" + +#include "inner_product_pd.hpp" + +namespace mkldnn { +namespace impl { + +using namespace prop_kind; + +memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_data + ? &desc->diff_src_desc : &desc->src_desc; +} + +memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_weights_desc : &desc->weights_desc; +} + +memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_bias_desc : &desc->bias_desc; +} + +memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) { + return utils::one_of(desc->prop_kind, forward_inference, forward_training) + ? &desc->dst_desc : &desc->diff_dst_desc; +} + +const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_src_d(const_cast(desc)); } +const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_wei_d(const_cast(desc)); } +const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_bia_d(const_cast(desc)); } +const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_dst_d(const_cast(desc)); } + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp new file mode 100644 index 0000000000..c426de632c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp @@ -0,0 +1,321 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef INNER_PRODUCT_PD_HPP +#define INNER_PRODUCT_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc); + +struct inner_product_fwd_pd_t; + +struct inner_product_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::inner_product; + + inner_product_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const inner_product_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::inner_product_d: + *(const inner_product_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common inner_product aux functions */ + + dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; } + dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; } + dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; } + + dim_t ID() const { + return ndims() >= 5 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t IH() const { + return ndims() >= 4 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t IW() const { + return ndims() >= 3 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t OD() const { + return ndims() >= 5 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t OH() const { + return ndims() >= 4 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t OW() const { + return ndims() >= 3 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t KD() const { + return ndims() >= 5 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t KH() const { + return ndims() >= 4 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t KW() const { + return ndims() >= 3 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t IC_total() const { + return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1], + ndims() - 1); + } + + dim_t IC_total_padded() const { + auto src_d = desc()->prop_kind == prop_kind::backward_data + ? memory_desc_wrapper(diff_src_md()) + : memory_desc_wrapper(src_md()); + assert(src_d.is_blocking_desc()); + if (!src_d.is_blocking_desc()) return -1; + return utils::array_product(src_d.padded_dims() + 1, ndims() - 1); + } + + int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; } + + bool with_bias() const + { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_)); + const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_)); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + inner_product_desc_t desc_; + const inner_product_fwd_pd_t *hint_fwd_pd_; + + status_t template_set_default_params(memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t *bias_md) { + using namespace format_tag; + if (src_md.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, + utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw))); + } + if (dst_md.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_md, nc)); + if (weights_md.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md, + utils::pick(ndims() - 2, oi, oiw, oihw, oidhw))); + } + if (bias_md && bias_md->format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(*bias_md, x)); + return status::success; + } +}; + +struct inner_product_fwd_pd_t: public inner_product_pd_t { + typedef inner_product_fwd_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_fwd_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; + + status_t set_default_params() { + return template_set_default_params(src_md_, weights_md_, dst_md_, + &bias_md_); + } +}; + +struct inner_product_bwd_data_pd_t: public inner_product_pd_t { + typedef inner_product_bwd_data_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_bwd_data_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &weights_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t diff_dst_md_; + + status_t set_default_params() { + return template_set_default_params(diff_src_md_, weights_md_, + diff_dst_md_, nullptr); + } +}; + +struct inner_product_bwd_weights_pd_t: public inner_product_pd_t { + typedef inner_product_bwd_weights_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_bwd_weights_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; + + status_t set_default_params() { + return template_set_default_params(src_md_, diff_weights_md_, + diff_dst_md_, &diff_bias_md_); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp new file mode 100644 index 0000000000..fcf18b556f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t lrn_desc_init(lrn_desc_t *lrn_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc, + dim_t local_size, float alpha, float beta, float k) { + bool args_ok = true + && !any_null(lrn_desc, data_desc) + && one_of(alg_kind, lrn_within_channel, lrn_across_channels) + && one_of(prop_kind, forward_training, forward_inference, backward_data) + && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto ld = lrn_desc_t(); + ld.primitive_kind = primitive_kind::lrn; + ld.prop_kind = prop_kind; + ld.alg_kind = alg_kind; + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + + ld.data_desc = *data_desc; + if (!is_fwd) + ld.diff_data_desc = *diff_data_desc; + else + ld.diff_data_desc = zero_md(); + ld.local_size = local_size; + ld.lrn_alpha = alpha; + ld.lrn_beta = beta; + ld.lrn_k = k; + + bool consistency = true + && ld.data_desc.ndims == 4; + if (ld.prop_kind == backward_data) + consistency = consistency + && ld.diff_data_desc.ndims == 4 + && array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4); + if (!consistency) return invalid_arguments; + + *lrn_desc = ld; + return success; +} +} + +status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, dim_t local_size, float alpha, + float beta, float k) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr, + local_size, alpha, beta, k); +} + +status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc, + alg_kind_t alg_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, dim_t local_size, float alpha, + float beta, float k) { + return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc, + diff_data_desc, local_size, alpha, beta, k); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp new file mode 100644 index 0000000000..90886e9656 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef LRN_PD_HPP +#define LRN_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct lrn_fwd_pd_t; + +struct lrn_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::lrn; + + lrn_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + , ws_md_() + {} + + const lrn_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::lrn_d: + *(const lrn_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common lrn aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + lrn_desc_t desc_; + const lrn_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + memory_desc_t ws_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct lrn_fwd_pd_t: public lrn_pd_t { + typedef lrn_fwd_pd_t base_class; + typedef lrn_fwd_pd_t hint_class; + + lrn_fwd_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } +}; + +struct lrn_bwd_pd_t: public lrn_pd_t { + typedef lrn_bwd_pd_t base_class; + typedef lrn_fwd_pd_t hint_class; + + lrn_bwd_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override + { return 2 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp new file mode 100644 index 0000000000..3fddc0bd45 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MATH_UTILS_HPP +#define MATH_UTILS_HPP + +#include +#include + +#include "utils.hpp" +#include "nstl.hpp" +#include "mkldnn_traits.hpp" + +#if defined(MKLDNN_X86_64) +#include "immintrin.h" +#endif + +namespace mkldnn { +namespace impl { +namespace math { + +/** rounds @p f to an integer according to the mxcsr register */ +inline int mxcsr_round(float f) { +#if defined(MKLDNN_X86_64) + return _mm_cvtss_si32(_mm_load_ss(&f)); +#else + return (int)nearbyintf(f); // optimism +#endif +} + +template +inline typename utils::enable_if::value, + typename utils::remove_reference::type>::type +saturate(const acc_t &x) { + return (typename utils::remove_reference::type)x; +} + +template +inline typename utils::enable_if::value, + typename utils::remove_reference::type>::type +saturate(const acc_t &x) { + acc_t v = x; + if (v < (acc_t)nstl::numeric_limits::lowest()) + v = (acc_t)nstl::numeric_limits::lowest(); + if (v > (acc_t)nstl::numeric_limits::max()) + v = (acc_t)nstl::numeric_limits::max(); + return (typename utils::remove_reference::type)v; +} + +template +double saturate(const double &x) { + double v = x; + if (v < (double)nstl::numeric_limits::lowest()) + v = (double)nstl::numeric_limits::lowest(); + if (v > (double)nstl::numeric_limits::max()) + v = (double)nstl::numeric_limits::max(); + return v; +} + +template <> inline int8_t saturate(const uint8_t &x) { + return x <= 127u ? x : 127; +} + +template <> inline uint8_t saturate(const int8_t &x) { + return x >= 0 ? x : 0; +} + +template +typename utils::enable_if::value, out_t>::type +out_round(float v) { return (out_t)mxcsr_round(v); } + +template +typename utils::enable_if::value, out_t>::type +out_round(double v) { return (out_t)mxcsr_round((float)v); } + +template +typename utils::enable_if::value, out_t>::type +out_round(float v) { return v; } + +inline int gcd(int a, int b) { + a = impl::nstl::abs(a); + b = impl::nstl::abs(b); + if (a < b) { int x = a; a = b; b = x; } + + if (b == 0) return a; + + int r; + while ((r = a % b) != 0) { a = b; b = r; } + + return b; +} + +template +inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; } + +/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */ +inline int ilog2q(size_t v) { + if (v == 0) + return -1; + + int p = 0; +# define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0) + CP(32); CP(16); CP(8); CP(4); CP(2); CP(1); +# undef CP + return p; +} + +template ::type> +inline U one_m_square(T x) { + return (U)(1 - x) * (1 + x); +} + +template ::type> +inline U x_m_square(T x) { + return (U)(1 - x) * x; +} + +/* activation */ +template ::type> +inline U relu_fwd(T s, A alpha) { + return s > 0 ? s : (U)(s * alpha); +} +template ::type> +inline U relu_bwd(T dd, T s, A alpha) { + return s > 0 ? dd : (U)(dd * alpha); +} + +template ::type> +inline U tanh_fwd(T s) { + const float e = tanhf((float) s); + return (U)e; +} + +template ::type> +inline U tanh_bwd(T dd, T s) { + const float e = tanh_fwd((float) s); + return (U)(dd * (1 - e) * (1 + e)); +} + +template ::type> +inline U elu_fwd(T s, A alpha) { + return s > 0 ? s : (U)(alpha * (::expm1f((float)s))); +} +template ::type> + inline U elu_bwd(T dd, T s, A alpha) { + return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s))); +} + +template ::type> +inline U square_fwd(T s) { + return s * s; +} + +template ::type> +inline U square_bwd(T dd, T s) { + return dd * 2 * s; +} + +template ::type> +inline U abs_fwd(T s) { + return s > 0 ? s : -s; +} + +template ::type> +inline U abs_bwd(T dd, T s) { + return s > 0 ? dd : s < 0 ? -dd : 0; +} + +template ::type> +inline U sqrt_fwd(T s) { + return s > 0 ? (U)(::sqrtf((float)(s))) : 0; +} + +template ::type> +inline U sqrt_bwd(T dd, T s) { + return s > 0 + ? (U)(dd / (2 * ::sqrtf((float)(s)))) + : 0; +} + +template ::type> +inline U linear_fwd(T s, A alpha, A beta) { + return (U)(alpha * s + beta); +} + +template ::type> +inline U linear_bwd(T dd, T s, A alpha, A beta) { + (void) s; + (void) beta; + return (U)(dd * alpha); +} + +template ::type> +inline U bounded_relu_fwd(T s, A alpha) { + s = s > 0 ? s : 0; + return s > alpha ? (U)(alpha) : s; +} + +template ::type> +inline U bounded_relu_bwd(T dd, T s, A alpha) { + return dd * (0 < s && s < alpha ? 1 : 0); +} + +template ::type> +inline U soft_relu_fwd(T s) { + float max_logf = 8.872284e+01; //::logf(FLT_MAX) + return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s; +} + +template ::type> +inline U soft_relu_bwd(T dd, T s) { + return (U)(dd / (1 + ::expf((float)(-s)))); +} + +template ::type> +inline U logistic_fwd(T s) { + U v = (U)(::expf((float) -s)); + return 1 / (1 + v); +} + +template ::type> +inline U logistic_bwd(T dd, T s) { + U v = logistic_fwd(s); + return dd * v * (1 - v); +} + +inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) { + using namespace alg_kind; + using namespace utils; + const bool preserves_zero = true + && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic) + && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh)); + return preserves_zero; +} + +inline float get_bias(const char *bias, size_t offset, data_type_t data_type) +{ + if (!bias) + return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits
::type *)bias)[offset] + + switch (data_type) { + CASE(data_type::s8); + CASE(data_type::u8); + CASE(data_type::s32); + CASE(data_type::f32); + default: assert(!"unimplemented"); + } + return 0; // never happens (should probably be a NaN) +#undef CASE +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp new file mode 100644 index 0000000000..cea849c96e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp @@ -0,0 +1,238 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::data_type; + +namespace { +bool memory_desc_sanity_check(int ndims,const dims_t dims, + data_type_t data_type, format_kind_t format_kind) { + if (ndims == 0) return true; + + bool ok = true + && dims != nullptr + && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS + && one_of(data_type, f32, s32, s8, u8) + && format_kind != format_kind::undef; + if (!ok) return false; + for (int d = 0; d < ndims; ++d) + if (dims[d] < 0) return false; + + return true; +} + +bool memory_desc_sanity_check(const memory_desc_t *md) { + if (md == nullptr) return false; + return memory_desc_sanity_check(md->ndims, md->dims, md->data_type, + format_kind::any); +} +} + +status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims, + const dims_t dims, data_type_t data_type, format_tag_t tag) { + if (any_null(memory_desc)) return invalid_arguments; + if (ndims == 0 || tag == format_tag::undef) { + *memory_desc = types::zero_md(); + return success; + } + + format_kind_t format_kind = types::format_tag_to_kind(tag); + + /* memory_desc != 0 */ + bool args_ok = !any_null(memory_desc) + && memory_desc_sanity_check(ndims, dims, data_type, format_kind); + if (!args_ok) return invalid_arguments; + + auto md = memory_desc_t(); + md.ndims = ndims; + array_copy(md.dims, dims, ndims); + md.data_type = data_type; + array_copy(md.padded_dims, dims, ndims); + md.format_kind = format_kind; + + status_t status = success; + if (tag == format_tag::undef) { + status = invalid_arguments; + } else if (tag == format_tag::any) { + // nop + } else if (format_kind == format_kind::blocked) { + status = memory_desc_wrapper::compute_blocking(md, tag); + } else { + assert(!"unreachable"); + status = invalid_arguments; + } + + if (status == success) + *memory_desc = md; + + return status; +} + +status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc, + int ndims, const dims_t dims, data_type_t data_type, + const dims_t strides) { + if (any_null(memory_desc)) return invalid_arguments; + if (ndims == 0) { + *memory_desc = types::zero_md(); + return success; + } + + /* memory_desc != 0 */ + bool args_ok = !any_null(memory_desc) + && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any); + if (!args_ok) return invalid_arguments; + + auto md = memory_desc_t(); + md.ndims = ndims; + array_copy(md.dims, dims, ndims); + md.data_type = data_type; + array_copy(md.padded_dims, dims, ndims); + md.format_kind = format_kind::blocked; + + dims_t default_strides = {0}; + if (strides == nullptr) { + default_strides[md.ndims - 1] = 1; + for (int d = md.ndims - 2; d >= 0; --d) + default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1]; + strides = default_strides; + } else { + /* TODO: add sanity check for the provided strides */ + } + + array_copy(md.format_desc.blocking.strides, strides, md.ndims); + + *memory_desc = md; + + return status::success; +} + +status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md, + const memory_desc_t *parent_md, const dims_t dims, + const dims_t offsets) { + if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md)) + return invalid_arguments; + + const memory_desc_wrapper src_d(parent_md); + + for (int d = 0; d < src_d.ndims(); ++d) { + if (dims[d] < 0 || offsets[d] < 0 + || (offsets[d] + dims[d] > src_d.dims()[d])) + return invalid_arguments; + } + + if (src_d.format_kind() != format_kind::blocked) + return unimplemented; + + dims_t blocks; + src_d.compute_blocks(blocks); + + memory_desc_t dst_d = *parent_md; + auto &dst_d_blk = dst_d.format_desc.blocking; + + /* TODO: put this into memory_desc_wrapper */ + for (int d = 0; d < src_d.ndims(); ++d) { + /* very limited functionality for now */ + const bool ok = true + && offsets[d] % blocks[d] == 0 /* [r1] */ + && src_d.padded_offsets()[d] == 0 + && (false + || dims[d] % blocks[d] == 0 + || dims[d] < blocks[d]); + if (!ok) + return unimplemented; + + const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d]; + + dst_d.dims[d] = dims[d]; + dst_d.padded_dims[d] = is_right_border + ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d]; + dst_d.padded_offsets[d] = src_d.padded_offsets()[d]; + dst_d.offset0 += /* [r1] */ + offsets[d] / blocks[d] * dst_d_blk.strides[d]; + } + + *md = dst_d; + + return success; +} + +int mkldnn_memory_desc_equal(const memory_desc_t *lhs, + const memory_desc_t *rhs) { + if (lhs == rhs) return 1; + if (any_null(lhs, rhs)) return 0; + return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs); +} + +size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) { + if (md == nullptr) return 0; + return memory_desc_wrapper(*md).size(); +} + +status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md, + engine_t *engine, void *handle) { + if (any_null(memory, engine)) return invalid_arguments; + memory_desc_t z_md = types::zero_md(); + return engine->memory_create(memory, md ? md : &z_md, handle); +} + +status_t mkldnn_memory_get_memory_desc(const memory_t *memory, + const memory_desc_t **md) { + if (any_null(memory, md)) return invalid_arguments; + *md = memory->md(); + return success; +} + +status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) { + if (any_null(memory, engine)) return invalid_arguments; + *engine = memory->engine(); + return success; +} + +status_t mkldnn_memory_get_data_handle(const memory_t *memory, + void **handle) { + if (any_null(handle)) + return invalid_arguments; + if (memory == nullptr) { + *handle = nullptr; + return success; + } + return memory->get_data_handle(handle); +} + +status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) { + if (any_null(memory)) return invalid_arguments; + return memory->set_data_handle(handle); +} + +status_t mkldnn_memory_destroy(memory_t *memory) { + delete memory; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp new file mode 100644 index 0000000000..03dfee01ff --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp @@ -0,0 +1,63 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_HPP +#define MEMORY_HPP + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" + +struct mkldnn_memory: public mkldnn::impl::c_compatible { + mkldnn_memory(mkldnn::impl::engine_t *engine, + const mkldnn::impl::memory_desc_t *md) + : engine_(engine), md_(*md) {} + virtual ~mkldnn_memory() {} + + /** allocates/initializes memory */ + virtual mkldnn::impl::status_t init() = 0; + + /** returns memory's engine */ + mkldnn::impl::engine_t *engine() const { return engine_; } + /** returns memory's description */ + const mkldnn::impl::memory_desc_t *md() const { return &md_; } + + /** returns data handle */ + virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0; + + /** sets data handle */ + virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0; + + /** zeros padding */ + virtual mkldnn::impl::status_t zero_pad() const + { return mkldnn::impl::status::success; } + +protected: + mkldnn::impl::engine_t *engine_; + const mkldnn::impl::memory_desc_t md_; + +private: + mkldnn_memory() = delete; + mkldnn_memory(const mkldnn_memory &) = delete; + mkldnn_memory(mkldnn_memory &&) = delete; + mkldnn_memory &operator=(const mkldnn_memory &) = delete; + mkldnn_memory &operator=(mkldnn_memory &&) = delete; +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp new file mode 100644 index 0000000000..8a99be33f3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp @@ -0,0 +1,212 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +status_t fill_blocked(memory_desc_t &md, + std::initializer_list perm, + std::initializer_list inner_blks, + std::initializer_list inner_idxs) { + const bool ok = true + && perm.size() == (size_t)md.ndims + && inner_blks.size() == inner_idxs.size(); + if (!ok) return status::invalid_arguments; + + md.offset0 = 0; + + blocking_desc_t &blk = md.format_desc.blocking; + + dim_t block_size = 1; + dims_t blocks = {0}; + utils::array_set(blocks, 1, md.ndims); + + blk.inner_nblks = (int)inner_blks.size(); + + int iblk = 0; + for (const auto &b: inner_idxs) + blk.inner_idxs[iblk++] = b; + + iblk = 0; + for (const auto &b: inner_blks) { + int dim = blk.inner_idxs[iblk]; + block_size *= b; + blocks[dim] *= b; + blk.inner_blks[iblk++] = b; + } + + utils::array_set(md.padded_offsets, 0, md.ndims); + for (int d = 0; d < md.ndims; ++d) + md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); + + dim_t stride = block_size; + // if only we use C++14, the initializer_list would have rbegin()/rend()... + for (int d = 0; d < md.ndims; ++d) + stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d]; + + for (const auto &d: perm) { + if (md.padded_dims[d] == 0) { + blk.strides[d] = 1; + continue; + } + stride /= md.padded_dims[d] / blocks[d]; + blk.strides[d] = stride; + } + + assert(stride == block_size); + + return status::success; +} + +status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc, + format_tag_t tag) +{ + using namespace format_tag; + + if (memory_desc.ndims == 0) return status::invalid_arguments; + +# define C(tag, ... /* perm, inner_blks, inner_idxs */) \ + case tag: return fill_blocked(memory_desc, __VA_ARGS__) + + switch (tag) { + C(a, {0}, {}, {}); + C(ab, {0, 1}, {}, {}); + C(abc, {0, 1, 2}, {}, {}); + C(abcd, {0, 1, 2, 3}, {}, {}); + C(abcde, {0, 1, 2, 3, 4}, {}, {}); + C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {}); + C(abdec, {0, 1, 3, 4, 2}, {}, {}); + C(acb, {0, 2, 1}, {}, {}); + C(acbde, {0, 2, 1, 3, 4}, {}, {}); + C(acdb, {0, 2, 3, 1}, {}, {}); + C(acdeb, {0, 2, 3, 4, 1}, {}, {}); + C(ba, {1, 0}, {}, {}); + C(bac, {1, 0, 2}, {}, {}); + C(bacd, {1, 0, 2, 3}, {}, {}); + C(bcda, {1, 2, 3, 0}, {}, {}); + C(cba, {2, 1, 0}, {}, {}); + C(cdba, {2, 3, 1, 0}, {}, {}); + C(cdeba, {2, 3, 4, 1, 0}, {}, {}); + C(decab, {3, 4, 2, 0, 1}, {}, {}); + + C(Abc4a, {0, 1, 2}, {4}, {0}); + C(aBc4b, {0, 1, 2}, {4}, {1}); + C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1}); + C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0}); + C(Abcd4a, {0, 1, 2, 3}, {4}, {0}); + C(aBcd4b, {0, 1, 2, 3}, {4}, {1}); + C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0}); + C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2}); + C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); + C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0}); + C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1}); + C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0}); + C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); + C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1}); + C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1}); + C(aBdc4b, {0, 1, 3, 2}, {4}, {1}); + C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1}); + C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1}); + C(Acb4a, {0, 2, 1}, {4}, {0}); + C(Acdb4a, {0, 2, 3, 1}, {4}, {0}); + C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0}); + + C(Abc16a, {0, 1, 2}, {16}, {0}); + C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1}); + C(aBc16b, {0, 1, 2}, {16}, {1}); + C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0}); + C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0}); + C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1}); + C(aBc8b, {0, 1, 2}, {8}, {1}); + C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1}); + C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0}); + C(Abcd16a, {0, 1, 2, 3}, {16}, {0}); + C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1}); + C(aBcd16b, {0, 1, 2, 3}, {16}, {1}); + C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0}); + C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2}); + C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1}); + C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1}); + C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0}); + C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1}); + C(aBcd8b, {0, 1, 2, 3}, {8}, {1}); + C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1}); + C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1}); + C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0}); + C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2}); + C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2}); + C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1}); + C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0}); + C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1}); + C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1}); + C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0}); + C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2}); + C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1}); + C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2}); + C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2}); + C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2}); + C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0}); + C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1}); + C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1}); + C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1}); + C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1}); + C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0}); + C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2}); + C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2}); + C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1}); + C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1}); + C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2}); + C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1}); + C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2}); + C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2}); + C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1}); + C(aBdc16b, {0, 1, 3, 2}, {16}, {1}); + C(aBdc8b, {0, 1, 3, 2}, {8}, {1}); + C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1}); + C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1}); + C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1}); + C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1}); + C(Acb16a, {0, 2, 1}, {16}, {0}); + C(Acb8a, {0, 2, 1}, {8}, {0}); + C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2}); + C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2}); + C(Acdb16a, {0, 2, 3, 1}, {16}, {0}); + C(Acdb8a, {0, 2, 3, 1}, {8}, {0}); + C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0}); + C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0}); + C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1}); + C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1}); + default: break; + } + +#undef C + + return status::invalid_arguments; +} + +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp new file mode 100644 index 0000000000..1758f9078a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp @@ -0,0 +1,400 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_DESC_WRAPPER_HPP +#define MEMORY_DESC_WRAPPER_HPP + +#include + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +/** thin wrapper class over \struct memory_desc_t which allows easy + * manipulations with underlying C structure, which is taken by reference */ +struct memory_desc_wrapper: public c_compatible { + const memory_desc_t *md_; + + /** constructor which takes a reference to a constant underlying C memory + * descriptor \param md */ + memory_desc_wrapper(const memory_desc_t *md): md_(md) {} + memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {} + + /* implementing attributes */ + int ndims() const { return md_->ndims; } + const dims_t &dims() const { return md_->dims; } + data_type_t data_type() const { return md_->data_type; } + + const dims_t &padded_dims() const { return md_->padded_dims; } + const dims_t &padded_offsets() const { return md_->padded_offsets; } + dim_t offset0() const { return md_->offset0; } + + format_kind_t format_kind() const { return md_->format_kind; } + + bool is_blocking_desc() const + { return format_kind() == format_kind::blocked; } + bool is_wino_desc() const + { return format_kind() == format_kind::wino; } + bool is_rnn_packed_desc() const + { return format_kind() == format_kind::rnn_packed; } + + const blocking_desc_t &blocking_desc() const { + assert(is_blocking_desc()); + return md_->format_desc.blocking; + } + const wino_desc_t &wino_desc() const { + assert(is_wino_desc()); + return md_->format_desc.wino_desc; + } + const rnn_packed_desc_t &rnn_packed_desc() const { + assert(is_rnn_packed_desc()); + return md_->format_desc.rnn_packed_desc; + } + + const memory_extra_desc_t &extra() const { return md_->extra; } + + /* some useful function */ + + /** returns the number of elements including padding if \param with_padding + * is true, and the number of data elements otherwise */ + dim_t nelems(bool with_padding = false) const { + if (is_zero()) return 0; + return utils::array_product( + with_padding ? padded_dims() : dims(), ndims()); + } + + /** returns true if memory descriptor is zero */ + bool is_zero() const { return ndims() == 0; } + + /** returns true if memory descriptor contains zero as one of its dim */ + bool has_zero_dim() const { return nelems() == 0; } + + /** return the size of data type (a shortcut) */ + size_t data_type_size() const + { return types::data_type_size(data_type()); } + + /** return the size of data type of additional buffer */ + size_t additional_buffer_data_size() const { + if (extra().flags & memory_extra_flags::compensation_conv_s8s8) + return sizeof(int32_t); + return 0; + } + + /** return true if memory format has additional buffer */ + bool is_additional_buffer() const { + return (extra().flags & memory_extra_flags::compensation_conv_s8s8); + } + + /** returns the size of additional buffer */ + size_t additional_buffer_size() const { + if (extra().flags & memory_extra_flags::compensation_conv_s8s8) { + int cmask = extra().compensation_mask; + assert(cmask == 1 || cmask == 3); + dim_t prod = 1; + for (int d = 0; d < ndims(); ++d) + if (cmask & (1<(max_size, + padded_dims()[d] / blocks[d] * bd.strides[d]); + + if (max_size == 1 && bd.inner_nblks != 0) { + max_size = utils::array_product(bd.inner_blks, bd.inner_nblks); + } + + return max_size * data_type_size() + additional_buffer_size(); + } + } + + /** returns true if data is dense in memory */ + bool is_dense(bool with_padding = false) const { + if (utils::one_of(format_kind(), format_kind::undef, format_kind::any)) + return false; + return nelems(with_padding) * data_type_size() == size(); + } + + /** returns true if memory desc is fully defined */ + bool is_defined() const { return format_kind() != format_kind::any; } + + /** returns true if the only (potentially) padded dim is \param dim */ + bool only_padded_dim(int dim) const { + for (int d = 0; d < ndims(); ++d) + if (d != dim && dims()[d] != padded_dims()[d]) + return false; + return true; + } + + /** returns true if memory desc has blocked layout and block dims are 1s */ + bool is_plain() const { + if (!is_blocking_desc()) return false; + return blocking_desc().inner_nblks == 0; + } + + /** returns overall block sizes */ + void compute_blocks(dims_t blocks) const { + if (!is_blocking_desc()) { + utils::array_set(blocks, 0, ndims()); + return; + } + + utils::array_set(blocks, 1, ndims()); + + const auto &bd = blocking_desc(); + for (int iblk = 0; iblk < bd.inner_nblks; ++iblk) + blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk]; + } + + /* comparison section */ + + bool operator==(const memory_desc_wrapper &rhs) const + { return *this->md_ == *rhs.md_; } + bool operator!=(const memory_desc_wrapper &rhs) const + { return !operator==(rhs); } + bool operator==(const memory_desc_t &rhs) const + { return operator==(memory_desc_wrapper(rhs)); } + bool operator!=(const memory_desc_t &rhs) const + { return !operator==(rhs); } + + /** returns true if data (w/o padding if with_padding == false and w/ + * padding otherwise) have the same physical structure, i.e. dimensions, + * strides, and blocked structure. Depending on with_data_type flag + * data_type is taken or not taken into account. dim_start allows to check + * similarity for the logical part of data [dim_start .. ndims()]. + * CAUTION: format kind any and undef are not similar to whatever, hence the + * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */ + /* TODO: revise */ + bool similar_to(const memory_desc_wrapper &rhs, + bool with_padding = true, bool with_data_type = true, + int dim_start = 0) const; + + /** returns true if one memory can be reordered to another */ + bool consistent_with(const memory_desc_wrapper &rhs) const; + + /** returns true if the memory desc corresponds to the given format tag and + * strides. + * @sa memory_desc_matches_tag */ + bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const { + return memory_desc_matches_tag(*md_, tag, strides); + } + + /** returns matching tag (or undef if match is not found) + * XXX: This is a workaround that eventually should go away! */ + template + format_tag_t matches_one_of_tag(Tags ...tags) const { + for (const auto tag: {tags...}) { + if (memory_desc_matches_tag(*md_, tag)) + return tag; + } + return format_tag::undef; + } + + /* offset section */ + + /** returns physical offset by logical one. logical offset is represented by + * an array \param pos. if \param is_pos_padded is true \param pos + * represents the position in already padded area */ + dim_t off_v(const dims_t pos, bool is_pos_padded = false) const { + assert(is_blocking_desc()); + const blocking_desc_t &blk = blocking_desc(); + + dims_t pos_copy = {0}; + for (int d = 0; d < ndims(); ++d) + pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]); + + dim_t phys_offset = offset0(); + + if (blk.inner_nblks > 0) { + dim_t blk_stride = 1; + for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) { + const int d = blk.inner_idxs[iblk]; + const dim_t p = pos_copy[d] % blk.inner_blks[iblk]; + + phys_offset += p * blk_stride; + + pos_copy[d] /= blk.inner_blks[iblk]; + + blk_stride *= blk.inner_blks[iblk]; + } + } + + for (int d = 0; d < ndims(); ++d) { + const dim_t p = pos_copy[d]; + phys_offset += p * blk.strides[d]; + } + + return phys_offset; + } + + /** returns physical offset by logical one. logical offset is represented by + * a scalar \param l_offset. if \param is_pos_padded is true, \param + * l_offset represents logical offset in already padded area */ + dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const { + assert(is_blocking_desc()); + dims_t pos; + for (int rd = 0; rd < ndims(); ++rd) { + const int d = ndims() - 1 - rd; + const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d]; + pos[d] = l_offset % cur_dim; + l_offset /= cur_dim; + } + return off_v(pos, is_pos_padded); + } + + /** returns physical offset by logical one. logical offset is represented by + * a tuple of indices (\param xn, ..., \param x1, \param x0) */ + template + dim_t off(Args... args) const { + assert(sizeof...(args) == ndims()); + dims_t pos = { args... }; + return off_v(pos, false); + } + + /** returns physical offset by logical one. logical offset is represented by + * a tuple of indices (\param xn, ..., \param x1, \param x0) in already + * padded area */ + template + dim_t off_padding(Args... args) const { + assert(sizeof...(args) == ndims()); + dims_t pos = { args... }; + return off_v(pos, true); + } + + /** returns physical offset by logical one. Logical offset is represented by + * a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a + * user responsibility to adjust the result to get offset within blocks */ + template + dim_t blk_off(Args... args) const { + return _blk_off(args...); + } + + template + dim_t blk_off(T xn, Args... args) const { + return skip_first + ? blk_off(args...) + : blk_off(xn, args...); + } + + /* static functions section */ + /* TODO: replace with non-static, once md_ becomes non-const ref */ + + static status_t compute_blocking(memory_desc_t &memory_desc, + format_tag_t tag); + +private: + /* TODO: put logical_offset in utils */ + template + dim_t logical_offset(T x0) const { return x0; } + + template + dim_t logical_offset(T xn, Args... args) const { + const size_t n_args = sizeof...(args); + return xn * utils::array_product( + &dims()[ndims() - n_args]) + logical_offset(args...); + } + + template + dim_t _blk_off() const { return offset0(); } + + template + dim_t _blk_off(T xc, Args ...args) const { + assert(is_blocking_desc()); + constexpr int dc = ORIG_LEN - sizeof...(args) - 1; + return xc * blocking_desc().strides[dc] + + _blk_off(args...); + } +}; + +inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs, + bool with_padding, bool with_data_type, int dim_start) const { + using namespace utils; + + if (one_of(format_kind(), format_kind::undef, format_kind::any)) + return false; + if (is_wino_desc() || is_rnn_packed_desc()) + return false; + + const int ds = dim_start; + const auto &blk = blocking_desc(); + const auto &r_blk = rhs.blocking_desc(); + + return ndims() == rhs.ndims() + && dim_start <= ndims() /* guard */ + && format_kind() == rhs.format_kind() + && IMPLICATION(with_data_type, data_type() == rhs.data_type()) + && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds) + && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds) + && blk.inner_nblks == r_blk.inner_nblks + && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks) + && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks) + && IMPLICATION(with_padding, true + && array_cmp(padded_dims() + ds, rhs.padded_dims() + ds, + ndims() - ds) + && array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds, + ndims() - ds)); +} + +inline bool memory_desc_wrapper::consistent_with( + const memory_desc_wrapper &rhs) const { + if (ndims() == rhs.ndims()) { + for (int d = 0; d < ndims(); ++d) { + if (dims()[d] != rhs.dims()[d]) return false; + } + return true; + } else { + /* TODO: revise. + * is the following possible? + * [1, a, b] <--reorder--> [a, b] + * [a, 1, b] <--reorder--> [a, b] + * not, at least for now */ + return false; + } +} + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp new file mode 100644 index 0000000000..ec077b308c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp @@ -0,0 +1,295 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_TRACKING_HPP +#define MEMORY_TRACKING_HPP + +#include +#include + +#include "nstl.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace memory_tracking { + +/* Memory tracking capabilities + * + * The main purpose of this header file is to provide uniform way to register + * required memory for a scratchpad at a primitive descriptor creation time + * and then easily access it having only the base address of the scratchpad. + * + * Primitives might contain multiple disjoint parts that require temporary + * buffers (known as scratchpad) during their execution. A primitive descriptor + * should summarize all the needs into one single number -- the buffer size + * that would be requested from a user. At execution time, the corresponding + * primitive will receive a base pointer to a scratchpad. It then needs to + * provide each part of algorithm the corresponding piece of memory. Three main + * challenges here are: + * 1. Track correct offset (from the base scratchpad address) for each piece + * 2. Algorithm might require that different memory pieces to be aligned, so + * the scratchpad size is no more just a sum of size of the corresponding + * subparts. + * 3. While a primitive is responsible for its scratchpad, the implementation + * might use some other basic blocks (e.g. cpu_reducer) that also require + * scratchpad memory. So there should be a simple way of passing the + * information back and force between the main algorithm (a primitive) and + * auxiliary stuff that lives completely separately from it (e.g. reducer). + * + * To address these challenges this header file provides 3 structures: + * 1. registry_t -- the class the stores the information about requested + * memory. The information includes required size and desired + * alignment for each piece. This class is also responsible + * for computing the right offset to a given piece using the + * base pointer. + * This class is basically a ledger with all entries. + * Lives in primitive descriptors. + * + * 2. registrar_t -- the interface to a registry_t to book memory. Used at + * primitive descriptor creation time only. Contains a + * reference to the corresponding *mutable* registry. + * Always modifiable. + * Allows chaining (using prefixes). + * + * 3. grantor_t -- the interface to a registry_t to access memory. Used at + * primitive execution time only. Contains a reference to + * the corresponding *constant* registry and base pointer. + * Always constant. + * Allows chaining (using prefixes). + * + * Both registrar_t and grantor_t allow chaining with extra prefix provided. + * The feature is useful when a primitive offload a part of computations to + * some other primitives which require their own scratchpad space + * (e.g. reducer). Prefixes are used to avoid key collision in cases when + * multiple sub-primitive (e.g. multiple reducers) are used. + * + * A short example below demonstrates how to use aforementioned classes. In it + * the main primitive is convolution that uses scratchpad for keeping padded + * bias. It also needs a reducer, that needs its own space as well. + * + * ``` c++ + * struct reducer_t { + * static void init(registrar_t &scratchpad) { + * // preserve space for the reduction (one page aligned) + * scratchpad.book(key_space, sizeof(float) * 980 * 1024, 4096); + * } + * + * void exec(const grantor_t &scratchpad) { + * // get the pointer to preserved space. scratchpad came from + * // upper primitive (convolution in this example) + * auto space = scratchpad.get(key_reducer_space); + * + * space[:] += ...; + * } + * }; + * + * struct conv_t { + * struct pd_t { + * void init() { + * registrar_t scratchpad(scratchpad_registry_); + * + * // preserve a space for padded bias (using default alignment) + * scratchpad.book(key_conv_padded_bias, 128); + * + * // create a proxy registrar for the reducer All entries made + * // by reducer would live in convolution's registry, but would + * // have their own `prefix`, so no interference with conv's + * // buffers. + * registrar_t reducer_scratchpad(scratchpad, prefix_reducer); + * + * reducer_t::init(reducer_scratchpad); + * } + * + * registry_t scratchpad_registry_; + * } + * + * void exec() { + * // get the base pointer to a scratchpad memory from a user + * void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD); + * + * // create a grantor to the scratchpad (and provide the base + * // pointer). + * grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr); + * + * // access the padded_bias (need only key name and the grantor) + * auto padded_bias = scratchpad.get(key_conv_padded_bias); + * + * // to give the `right` grantor to reducer we need to add the + * // corresponding prefix, so that reducer would be able to access + * // its keys. The call is very similar to the one in pd_t::init + * // with only difference in types: grantor_t vs registrar_t. + * grantor_t reducer_scratchpad(scratchpad, prefix_reducer); + * reducer->exec(reducer_scratchpad); + * } + * }; + * ``` + */ + + +/* namespace with common keys and prefixes */ +namespace names { +enum { + key_none = 0, + key_bnorm_tmp_mean, + key_bnorm_tmp_var, + key_bnorm_tmp_diff_ss, + key_bnorm_tmp_stats, + key_bnorm_reduction, + key_concat_iptrs, + key_concat_istrides, + key_concat_nelems, + key_concat_optrs, + key_conv_adjusted_scales, + key_conv_bia_reduction, + key_conv_gemm_col, + key_conv_gemm_imtr, + key_conv_int_dat_in_acc_dt, + key_conv_padded_bias, + key_conv_rtus_space, + key_conv_tr_diff_dst, + key_conv_tr_diff_dst_bctx, + key_conv_tr_src, + key_conv_tr_src_bctx, + key_conv_wei_reduction, + key_conv_wei_bia_reduction, + key_conv_wei_bia_reduction_bctx, + key_iprod_int_dat_in_acc_dt, + key_reducer_space, + key_reducer_space_bctx, + key_reorder_wino_plain, + key_reorder_wino_transform_space, + key_reorder_rnn_weights_quantization, + key_reorder_rnn_weights_reduction, + key_rnn_space, + key_rnn_ptrs_bia, + key_rnn_ptrs_wei_layer, + key_rnn_ptrs_wei_iter, + key_softmax_reduction, + key_wino_U, + key_wino_V, + key_wino_M, + key_barrier, +}; + +enum { + prefix_none = 0, + prefix_reducer_bia, + prefix_reducer_wei, +}; +} + +// level 0: 00 00 00 xxx +// level 1: 00 00 aa xxx +// level 2: 00 aa bb xxx +// level 3: aa bb cc xxx +// max # of levels: 3 + 1 (base_level) +// here: +// xxx : [1 .. MAX_KEY) : key +// aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3 + +using key_t = uint32_t; +enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), }; + +/// generates global key based on a prefix and a local key +inline key_t make_key(key_t prefix, key_t key) { return prefix + key; } + +/// generates global prefix based on the global parent and the local ones +inline key_t make_prefix(key_t parent_prefix, key_t prefix) +{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; } + +struct registrar_t; +struct grantor_t; + +struct registry_t { + void book(const key_t &key, size_t size, size_t alignment) { + if (size == 0) return; + assert(offset_map_.count(key) == 0); + + size = utils::rnd_up(size, minimal_alignment); + alignment = nstl::max(alignment, minimal_alignment); + offset_map_[key] = entry_t{size_, size, alignment}; + + size_ += size + alignment - minimal_alignment; + } + + void *get(const key_t &key, void *base_ptr) const { + if (base_ptr == nullptr) { assert(size() == 0); return nullptr; } + if (offset_map_.count(key) != 1) return nullptr; + + const auto &e = offset_map_.at(key); + base_ptr = utils::align_ptr(base_ptr, minimal_alignment); + char *ptr = (char *)base_ptr + e.offset; + return utils::align_ptr(ptr, e.alignment); + } + + size_t size() const + { return size_ > 0 ? size_ + minimal_alignment - 1 : 0; } + + registrar_t registrar(); + grantor_t grantor(void *base_ptr) const; + +protected: + enum { minimal_alignment = 64 }; + struct entry_t { size_t offset, size, alignment; }; + + std::unordered_map offset_map_; + size_t size_ = 0; +}; + +struct registrar_t { + enum { default_alignment = 64 }; + + registrar_t(registry_t ®istry): registry_(registry), prefix_(0) {} + registrar_t(registrar_t &parent, const key_t &prefix) + : registry_(parent.registry_) + , prefix_(make_prefix(parent.prefix_, prefix)) {} + + void book(const key_t &key, size_t size, + size_t alignment = default_alignment) + { registry_.book(make_key(prefix_, key), size, alignment); } + +protected: + registry_t ®istry_; + const key_t prefix_; +}; + +struct grantor_t { + grantor_t(const registry_t ®istry, void *base_ptr) + : registry_(registry), prefix_(0), base_ptr_(base_ptr) {} + grantor_t(const grantor_t &parent, const key_t &prefix) + : registry_(parent.registry_) + , prefix_(make_prefix(parent.prefix_, prefix)) + , base_ptr_(parent.base_ptr_) {} + + template T *get(const key_t &key) const + { return (T *)registry_.get(make_key(prefix_, key), base_ptr_); } + +protected: + const registry_t ®istry_; + const key_t prefix_; + void *base_ptr_; +}; + +inline registrar_t registry_t::registrar() { return registrar_t(*this); } +inline grantor_t registry_t::grantor(void *base_ptr) const +{ return grantor_t(*this, base_ptr); } + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp new file mode 100644 index 0000000000..2ef4a8fddc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include +#include + +#include "mkldnn_debug.h" +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#define DPRINT(...) do { \ + int l = snprintf(str + written_len, str_len, __VA_ARGS__); \ + if (l < 0) return l; \ + if ((size_t)l >= str_len) return -1; \ + written_len += l; str_len -= l; \ +} while(0) + +int mkldnn_md2fmt_str(char *str, size_t str_len, + const mkldnn_memory_desc_t *mdesc) { + using namespace mkldnn::impl; + + if (str == nullptr || str_len <= 1u) + return -1; + + int written_len = 0; + + if (mdesc == nullptr) { + DPRINT("%s::%s::", + mkldnn_dt2str(data_type::undef), + mkldnn_fmt_kind2str(format_kind::undef)); + return written_len; + } + + memory_desc_wrapper md(mdesc); + + DPRINT("%s:", mkldnn_dt2str(md.data_type())); + + bool padded_dims = false, padded_offsets = false; + for (int d = 0; d < md.ndims(); ++d) { + if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true; + if (md.padded_offsets()[d] != 0) padded_offsets = true; + } + bool offset0 = md.offset0(); + DPRINT("%s%s%s:", + padded_dims ? "p" : "", + padded_offsets ? "o" : "", + offset0 ? "0" : ""); + + DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind())); + + if (!md.is_blocking_desc()) { + /* TODO: extend */ + DPRINT("%s:", ""); + } else { + const auto &blk = md.blocking_desc(); + + dims_t blocks; + md.compute_blocks(blocks); + + char dim_chars[MKLDNN_MAX_NDIMS + 1]; + + bool plain = true; + for (int d = 0; d < md.ndims(); ++d) { + dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + (char)d; + if (blocks[d] != 1) plain = false; + } + + dims_t strides; + utils::array_copy(strides, blk.strides, md.ndims()); + utils::simultaneous_sort(strides, dim_chars, md.ndims(), + [](dim_t a, dim_t b) { return b - a; }); + + dim_chars[md.ndims()] = '\0'; + DPRINT("%s", dim_chars); + + if (!plain) { + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { + DPRINT("%d%c", (int)blk.inner_blks[iblk], + 'a' + (char)blk.inner_idxs[iblk]); + } + } + + DPRINT("%s", ":"); + } + + DPRINT("f%lx", (long)md.extra().flags); + + return written_len; +} + +int mkldnn_md2dim_str(char *str, size_t str_len, + const mkldnn_memory_desc_t *mdesc) { + using namespace mkldnn::impl; + + if (str == nullptr || str_len <= 1) + return -1; + + int written_len = 0; + + if (mdesc == nullptr || mdesc->ndims == 0) { + DPRINT("%s", ""); + return written_len; + } + + memory_desc_wrapper md(mdesc); + + for (int d = 0; d < md.ndims() - 1; ++d) + DPRINT("%" PRId64 "x", md.dims()[d]); + DPRINT("%" PRId64, md.dims()[md.ndims() - 1]); + + return written_len; +} + +#undef DPRINT diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp new file mode 100644 index 0000000000..16a8f7ea5e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp @@ -0,0 +1,365 @@ +/******************************************************************************* +* Copyright 2018-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* DO NOT EDIT, AUTO-GENERATED */ + +#include + +#include "mkldnn_debug.h" +#include "mkldnn_types.h" + +const char *mkldnn_status2str(mkldnn_status_t v) { + if (v == mkldnn_success) return "success"; + if (v == mkldnn_out_of_memory) return "out_of_memory"; + if (v == mkldnn_try_again) return "try_again"; + if (v == mkldnn_invalid_arguments) return "invalid_arguments"; + if (v == mkldnn_not_ready) return "not_ready"; + if (v == mkldnn_unimplemented) return "unimplemented"; + if (v == mkldnn_iterator_ends) return "iterator_ends"; + if (v == mkldnn_runtime_error) return "runtime_error"; + if (v == mkldnn_not_required) return "not_required"; + assert(!"unknown status"); + return "unknown status"; +} + +const char *mkldnn_dt2str(mkldnn_data_type_t v) { + if (v == mkldnn_data_type_undef) return "undef"; + if (v == mkldnn_f32) return "f32"; + if (v == mkldnn_s32) return "s32"; + if (v == mkldnn_s8) return "s8"; + if (v == mkldnn_u8) return "u8"; + assert(!"unknown dt"); + return "unknown dt"; +} + +const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) { + if (v == mkldnn_format_kind_undef) return "undef"; + if (v == mkldnn_format_kind_any) return "any"; + if (v == mkldnn_blocked) return "blocked"; + if (v == mkldnn_format_kind_wino) return "wino"; + if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed"; + assert(!"unknown fmt_kind"); + return "unknown fmt_kind"; +} + +const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) { + if (v == mkldnn_format_tag_undef) return "undef"; + if (v == mkldnn_format_tag_any) return "format_tag_any"; + if (v == mkldnn_a) return "a"; + if (v == mkldnn_ab) return "ab"; + if (v == mkldnn_abc) return "abc"; + if (v == mkldnn_abcd) return "abcd"; + if (v == mkldnn_abcde) return "abcde"; + if (v == mkldnn_abcdef) return "abcdef"; + if (v == mkldnn_abdec) return "abdec"; + if (v == mkldnn_acb) return "acb"; + if (v == mkldnn_acbde) return "acbde"; + if (v == mkldnn_acdb) return "acdb"; + if (v == mkldnn_acdeb) return "acdeb"; + if (v == mkldnn_ba) return "ba"; + if (v == mkldnn_bac) return "bac"; + if (v == mkldnn_bacd) return "bacd"; + if (v == mkldnn_bcda) return "bcda"; + if (v == mkldnn_cba) return "cba"; + if (v == mkldnn_cdba) return "cdba"; + if (v == mkldnn_cdeba) return "cdeba"; + if (v == mkldnn_decab) return "decab"; + if (v == mkldnn_Abc16a) return "Abc16a"; + if (v == mkldnn_ABc16a16b) return "ABc16a16b"; + if (v == mkldnn_aBc16b) return "aBc16b"; + if (v == mkldnn_ABc16b16a) return "ABc16b16a"; + if (v == mkldnn_Abc4a) return "Abc4a"; + if (v == mkldnn_aBc4b) return "aBc4b"; + if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b"; + if (v == mkldnn_ABc4b4a) return "ABc4b4a"; + if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a"; + if (v == mkldnn_ABc8a8b) return "ABc8a8b"; + if (v == mkldnn_aBc8b) return "aBc8b"; + if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b"; + if (v == mkldnn_ABc8b8a) return "ABc8b8a"; + if (v == mkldnn_Abcd16a) return "Abcd16a"; + if (v == mkldnn_ABcd16a16b) return "ABcd16a16b"; + if (v == mkldnn_aBcd16b) return "aBcd16b"; + if (v == mkldnn_ABcd16b16a) return "ABcd16b16a"; + if (v == mkldnn_aBCd16b16c) return "aBCd16b16c"; + if (v == mkldnn_aBCd16c16b) return "aBCd16c16b"; + if (v == mkldnn_Abcd4a) return "Abcd4a"; + if (v == mkldnn_aBcd4b) return "aBcd4b"; + if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b"; + if (v == mkldnn_ABcd4b4a) return "ABcd4b4a"; + if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c"; + if (v == mkldnn_aBCd4c4b) return "aBCd4c4b"; + if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a"; + if (v == mkldnn_ABcd8a8b) return "ABcd8a8b"; + if (v == mkldnn_aBcd8b) return "aBcd8b"; + if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b"; + if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b"; + if (v == mkldnn_ABcd8b8a) return "ABcd8b8a"; + if (v == mkldnn_aBCd8b8c) return "aBCd8b8c"; + if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c"; + if (v == mkldnn_aBCd8c8b) return "aBCd8c8b"; + if (v == mkldnn_Abcde16a) return "Abcde16a"; + if (v == mkldnn_ABcde16a16b) return "ABcde16a16b"; + if (v == mkldnn_aBcde16b) return "aBcde16b"; + if (v == mkldnn_ABcde16b16a) return "ABcde16b16a"; + if (v == mkldnn_aBCde16b16c) return "aBCde16b16c"; + if (v == mkldnn_aBCde16c16b) return "aBCde16c16b"; + if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c"; + if (v == mkldnn_Abcde4a) return "Abcde4a"; + if (v == mkldnn_aBcde4b) return "aBcde4b"; + if (v == mkldnn_ABcde4b4a) return "ABcde4b4a"; + if (v == mkldnn_aBCde4b4c) return "aBCde4b4c"; + if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c"; + if (v == mkldnn_aBCde4c4b) return "aBCde4c4b"; + if (v == mkldnn_Abcde8a) return "Abcde8a"; + if (v == mkldnn_ABcde8a8b) return "ABcde8a8b"; + if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b"; + if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b"; + if (v == mkldnn_ABcde8b8a) return "ABcde8b8a"; + if (v == mkldnn_aBCde8b8c) return "aBCde8b8c"; + if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c"; + if (v == mkldnn_aBCde8c8b) return "aBCde8c8b"; + if (v == mkldnn_aBcdef16b) return "aBcdef16b"; + if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c"; + if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b"; + if (v == mkldnn_aBcdef4b) return "aBcdef4b"; + if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b"; + if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c"; + if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c"; + if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b"; + if (v == mkldnn_aBdc16b) return "aBdc16b"; + if (v == mkldnn_aBdc4b) return "aBdc4b"; + if (v == mkldnn_aBdc8b) return "aBdc8b"; + if (v == mkldnn_aBdec16b) return "aBdec16b"; + if (v == mkldnn_aBdec4b) return "aBdec4b"; + if (v == mkldnn_aBdec8b) return "aBdec8b"; + if (v == mkldnn_aBdefc16b) return "aBdefc16b"; + if (v == mkldnn_aBdefc4b) return "aBdefc4b"; + if (v == mkldnn_aBdefc8b) return "aBdefc8b"; + if (v == mkldnn_Acb16a) return "Acb16a"; + if (v == mkldnn_Acb4a) return "Acb4a"; + if (v == mkldnn_Acb8a) return "Acb8a"; + if (v == mkldnn_aCBd16b16c) return "aCBd16b16c"; + if (v == mkldnn_aCBde16b16c) return "aCBde16b16c"; + if (v == mkldnn_Acdb16a) return "Acdb16a"; + if (v == mkldnn_Acdb4a) return "Acdb4a"; + if (v == mkldnn_Acdb8a) return "Acdb8a"; + if (v == mkldnn_Acdeb16a) return "Acdeb16a"; + if (v == mkldnn_Acdeb4a) return "Acdeb4a"; + if (v == mkldnn_Acdeb8a) return "Acdeb8a"; + if (v == mkldnn_BAc16a16b) return "BAc16a16b"; + if (v == mkldnn_BAcd16a16b) return "BAcd16a16b"; + if (v == mkldnn_format_tag_last) return "format_tag_last"; + if (v == mkldnn_x) return "x"; + if (v == mkldnn_nc) return "nc"; + if (v == mkldnn_cn) return "cn"; + if (v == mkldnn_ncw) return "ncw"; + if (v == mkldnn_nwc) return "nwc"; + if (v == mkldnn_nchw) return "nchw"; + if (v == mkldnn_nhwc) return "nhwc"; + if (v == mkldnn_chwn) return "chwn"; + if (v == mkldnn_ncdhw) return "ncdhw"; + if (v == mkldnn_ndhwc) return "ndhwc"; + if (v == mkldnn_oi) return "oi"; + if (v == mkldnn_io) return "io"; + if (v == mkldnn_oiw) return "oiw"; + if (v == mkldnn_wio) return "wio"; + if (v == mkldnn_oihw) return "oihw"; + if (v == mkldnn_hwio) return "hwio"; + if (v == mkldnn_ihwo) return "ihwo"; + if (v == mkldnn_iohw) return "iohw"; + if (v == mkldnn_oidhw) return "oidhw"; + if (v == mkldnn_dhwio) return "dhwio"; + if (v == mkldnn_goiw) return "goiw"; + if (v == mkldnn_goihw) return "goihw"; + if (v == mkldnn_hwigo) return "hwigo"; + if (v == mkldnn_giohw) return "giohw"; + if (v == mkldnn_goidhw) return "goidhw"; + if (v == mkldnn_tnc) return "tnc"; + if (v == mkldnn_ntc) return "ntc"; + if (v == mkldnn_ldsnc) return "ldsnc"; + if (v == mkldnn_ldigo) return "ldigo"; + if (v == mkldnn_ldgoi) return "ldgoi"; + if (v == mkldnn_ldgo) return "ldgo"; + if (v == mkldnn_nCdhw16c) return "nCdhw16c"; + if (v == mkldnn_nCdhw4c) return "nCdhw4c"; + if (v == mkldnn_nCdhw8c) return "nCdhw8c"; + if (v == mkldnn_nChw16c) return "nChw16c"; + if (v == mkldnn_nChw4c) return "nChw4c"; + if (v == mkldnn_nChw8c) return "nChw8c"; + if (v == mkldnn_nCw16c) return "nCw16c"; + if (v == mkldnn_nCw4c) return "nCw4c"; + if (v == mkldnn_nCw8c) return "nCw8c"; + if (v == mkldnn_IOw16o16i) return "IOw16o16i"; + if (v == mkldnn_OIw16i16o) return "OIw16i16o"; + if (v == mkldnn_OIw16o16i) return "OIw16o16i"; + if (v == mkldnn_Oiw16o) return "Oiw16o"; + if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i"; + if (v == mkldnn_OIw4i4o) return "OIw4i4o"; + if (v == mkldnn_Oiw4o) return "Oiw4o"; + if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i"; + if (v == mkldnn_OIw8i8o) return "OIw8i8o"; + if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o"; + if (v == mkldnn_OIw8o8i) return "OIw8o8i"; + if (v == mkldnn_Owi16o) return "Owi16o"; + if (v == mkldnn_Owi4o) return "Owi4o"; + if (v == mkldnn_Owi8o) return "Owi8o"; + if (v == mkldnn_IOhw16o16i) return "IOhw16o16i"; + if (v == mkldnn_Ohwi16o) return "Ohwi16o"; + if (v == mkldnn_Ohwi4o) return "Ohwi4o"; + if (v == mkldnn_Ohwi8o) return "Ohwi8o"; + if (v == mkldnn_OIhw16i16o) return "OIhw16i16o"; + if (v == mkldnn_OIhw16o16i) return "OIhw16o16i"; + if (v == mkldnn_Oihw16o) return "Oihw16o"; + if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i"; + if (v == mkldnn_OIhw4i4o) return "OIhw4i4o"; + if (v == mkldnn_Oihw4o) return "Oihw4o"; + if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i"; + if (v == mkldnn_OIhw8i8o) return "OIhw8i8o"; + if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o"; + if (v == mkldnn_OIhw8o8i) return "OIhw8o8i"; + if (v == mkldnn_Odhwi16o) return "Odhwi16o"; + if (v == mkldnn_Odhwi4o) return "Odhwi4o"; + if (v == mkldnn_Odhwi8o) return "Odhwi8o"; + if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o"; + if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i"; + if (v == mkldnn_Oidhw16o) return "Oidhw16o"; + if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o"; + if (v == mkldnn_Oidhw4o) return "Oidhw4o"; + if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i"; + if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o"; + if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i"; + if (v == mkldnn_Goiw16g) return "Goiw16g"; + if (v == mkldnn_gIOw16o16i) return "gIOw16o16i"; + if (v == mkldnn_gOIw16i16o) return "gOIw16i16o"; + if (v == mkldnn_gOIw16o16i) return "gOIw16o16i"; + if (v == mkldnn_gOiw16o) return "gOiw16o"; + if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i"; + if (v == mkldnn_gOIw4i4o) return "gOIw4i4o"; + if (v == mkldnn_gOiw4o) return "gOiw4o"; + if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i"; + if (v == mkldnn_gOIw8i8o) return "gOIw8i8o"; + if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o"; + if (v == mkldnn_gOIw8o8i) return "gOIw8o8i"; + if (v == mkldnn_gOwi16o) return "gOwi16o"; + if (v == mkldnn_gOwi4o) return "gOwi4o"; + if (v == mkldnn_gOwi8o) return "gOwi8o"; + if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i"; + if (v == mkldnn_gOhwi16o) return "gOhwi16o"; + if (v == mkldnn_gOhwi4o) return "gOhwi4o"; + if (v == mkldnn_gOhwi8o) return "gOhwi8o"; + if (v == mkldnn_Goihw16g) return "Goihw16g"; + if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o"; + if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i"; + if (v == mkldnn_gOihw16o) return "gOihw16o"; + if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i"; + if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i"; + if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o"; + if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i"; + if (v == mkldnn_gOihw4o) return "gOihw4o"; + if (v == mkldnn_Goihw8g) return "Goihw8g"; + if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i"; + if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o"; + if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o"; + if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i"; + if (v == mkldnn_gOdhwi16o) return "gOdhwi16o"; + if (v == mkldnn_gOdhwi4o) return "gOdhwi4o"; + if (v == mkldnn_gOdhwi8o) return "gOdhwi8o"; + if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o"; + if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i"; + if (v == mkldnn_gOidhw16o) return "gOidhw16o"; + if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o"; + if (v == mkldnn_gOidhw4o) return "gOidhw4o"; + if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i"; + if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o"; + if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i"; + assert(!"unknown fmt_tag"); + return "unknown fmt_tag"; +} + +const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) { + if (v == mkldnn_prop_kind_undef) return "undef"; + if (v == mkldnn_forward_training) return "forward_training"; + if (v == mkldnn_forward_inference) return "forward_inference"; + if (v == mkldnn_forward_scoring) return "forward_scoring"; + if (v == mkldnn_forward) return "forward"; + if (v == mkldnn_backward) return "backward"; + if (v == mkldnn_backward_data) return "backward_data"; + if (v == mkldnn_backward_weights) return "backward_weights"; + if (v == mkldnn_backward_bias) return "backward_bias"; + assert(!"unknown prop_kind"); + return "unknown prop_kind"; +} + +const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) { + if (v == mkldnn_undefined_primitive) return "undef"; + if (v == mkldnn_reorder) return "reorder"; + if (v == mkldnn_shuffle) return "shuffle"; + if (v == mkldnn_concat) return "concat"; + if (v == mkldnn_sum) return "sum"; + if (v == mkldnn_convolution) return "convolution"; + if (v == mkldnn_deconvolution) return "deconvolution"; + if (v == mkldnn_eltwise) return "eltwise"; + if (v == mkldnn_softmax) return "softmax"; + if (v == mkldnn_pooling) return "pooling"; + if (v == mkldnn_lrn) return "lrn"; + if (v == mkldnn_batch_normalization) return "batch_normalization"; + if (v == mkldnn_inner_product) return "inner_product"; + if (v == mkldnn_rnn) return "rnn"; + assert(!"unknown prim_kind"); + return "unknown prim_kind"; +} + +const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) { + if (v == mkldnn_alg_kind_undef) return "undef"; + if (v == mkldnn_convolution_direct) return "convolution_direct"; + if (v == mkldnn_convolution_winograd) return "convolution_winograd"; + if (v == mkldnn_convolution_auto) return "convolution_auto"; + if (v == mkldnn_deconvolution_direct) return "deconvolution_direct"; + if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd"; + if (v == mkldnn_eltwise_relu) return "eltwise_relu"; + if (v == mkldnn_eltwise_tanh) return "eltwise_tanh"; + if (v == mkldnn_eltwise_elu) return "eltwise_elu"; + if (v == mkldnn_eltwise_square) return "eltwise_square"; + if (v == mkldnn_eltwise_abs) return "eltwise_abs"; + if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt"; + if (v == mkldnn_eltwise_linear) return "eltwise_linear"; + if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu"; + if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu"; + if (v == mkldnn_eltwise_logistic) return "eltwise_logistic"; + if (v == mkldnn_pooling_max) return "pooling_max"; + if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding"; + if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding"; + if (v == mkldnn_pooling_avg) return "pooling_avg"; + if (v == mkldnn_lrn_across_channels) return "lrn_across_channels"; + if (v == mkldnn_lrn_within_channel) return "lrn_within_channel"; + if (v == mkldnn_vanilla_rnn) return "vanilla_rnn"; + if (v == mkldnn_vanilla_lstm) return "vanilla_lstm"; + if (v == mkldnn_vanilla_gru) return "vanilla_gru"; + if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset"; + assert(!"unknown alg_kind"); + return "unknown alg_kind"; +} + +const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) { + if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right"; + if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left"; + if (v == mkldnn_bidirectional_concat) return "bidirectional_concat"; + if (v == mkldnn_bidirectional_sum) return "bidirectional_sum"; + if (v == mkldnn_unidirectional) return "unidirectional"; + assert(!"unknown rnn_direction"); + return "unknown rnn_direction"; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp new file mode 100644 index 0000000000..7e5789e2c3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_THREAD_HPP +#define MKLDNN_THREAD_HPP + +#include "utils.hpp" +#include "z_magic.hpp" + +#define MKLDNN_THR_SEQ 0 +#define MKLDNN_THR_OMP 1 +#define MKLDNN_THR_TBB 2 + +/* Ideally this condition below should never happen (if the library is built + * using regular cmake). For the 3rd-party projects that build the library + * from the sources on their own try to guess the right threading... */ +#if !defined(MKLDNN_THR) +# define MKLDNN_THR MKLDNN_THR_TBB +#endif + +#if MKLDNN_THR == MKLDNN_THR_SEQ +#define MKLDNN_THR_SYNC 1 +inline int mkldnn_get_max_threads() { return 1; } +inline int mkldnn_get_num_threads() { return 1; } +inline int mkldnn_get_thread_num() { return 0; } +inline int mkldnn_in_parallel() { return 0; } +inline void mkldnn_thr_barrier() {} + +#define PRAGMA_OMP(...) + +#elif MKLDNN_THR == MKLDNN_THR_OMP +#include +#define MKLDNN_THR_SYNC 1 + +inline int mkldnn_get_max_threads() { return omp_get_max_threads(); } +inline int mkldnn_get_num_threads() { return omp_get_num_threads(); } +inline int mkldnn_get_thread_num() { return omp_get_thread_num(); } +inline int mkldnn_in_parallel() { return omp_in_parallel(); } +inline void mkldnn_thr_barrier() { +# pragma omp barrier +} + +#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__)) + +#elif MKLDNN_THR == MKLDNN_THR_TBB +#include "tbb/task_arena.h" +#include "tbb/parallel_for.h" +#define MKLDNN_THR_SYNC 0 + +inline int mkldnn_get_max_threads() +{ return tbb::this_task_arena::max_concurrency(); } +inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); } +inline int mkldnn_get_thread_num() +{ return tbb::this_task_arena::current_thread_index(); } +inline int mkldnn_in_parallel() { return 0; } +inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); } + +#define PRAGMA_OMP(...) + +#endif + +/* MSVC still supports omp 2.0 only */ +#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER) +# define collapse(x) +# define PRAGMA_OMP_SIMD(...) +#else +# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__)) +#endif // defined(_MSC_VER) && !defined(__INTEL_COMPILER) + +namespace mkldnn { +namespace impl { + +inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; } + +template +inline void balance211(T n, U team, U tid, T &n_start, T &n_end) { + T n_min = 1; + T &n_my = n_end; + if (team <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else if (n_min == 1) { + // team = T1 + T2 + // n = T1*n1 + T2*n2 (n1 - n2 = 1) + T n1 = utils::div_up(n, (T)team); + T n2 = n1 - 1; + T T1 = n - n2 * (T)team; + n_my = (T)tid < T1 ? n1 : n2; + n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2; + } + + n_end += n_start; +} + +} // namespace impl +} // namespace mkldnn + +#include "mkldnn_thread_parallel_nd.hpp" + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp new file mode 100644 index 0000000000..50f9b29622 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp @@ -0,0 +1,277 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP +#define MKLDNN_THREAD_PARALLEL_ND_HPP + +/* This header must be included by mkldnn_thread.hpp only */ + +/* Functions: + * - parallel(nthr, f) - executes f in parallel using at most + * nthr threads. If nthr equals 0 + * mkldnn_get_max_threads() threads is + * used + * - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already + * created threads + * - parallel_nd(dims..., f) - creates a parallel section and then + * calls for_nd + * - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then + * calls for_nd (mostly for convenience) + */ + +namespace mkldnn { +namespace impl { + +/* general parallelization */ +template +void parallel(int nthr, F f) { + if (nthr == 0) nthr = mkldnn_get_max_threads(); +#if MKLDNN_THR == MKLDNN_THR_SEQ + assert(nthr == 1); + f(0, 1); +#elif MKLDNN_THR == MKLDNN_THR_OMP + if (nthr == 1) { f(0, 1); return; } +# pragma omp parallel num_threads(nthr) + f(mkldnn_get_thread_num(), mkldnn_get_num_threads()); +#elif MKLDNN_THR == MKLDNN_THR_TBB + if (nthr == 1) { f(0, 1); return; } + tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner()); +#endif +} + +/* for_nd section */ + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, F f) { + T0 start{0}, end{0}; + balance211(D0, nthr, ithr, start, end); + for (T0 d0 = start; d0 < end; ++d0) f(d0); +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1); + utils::nd_iterator_step(d0, D0, d1, D1); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, + d5, D5); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } +} + +// Skip a lambda function in the parameter pack. +template +constexpr size_t get_work_amount(const T &v) { return 1; } +template +constexpr size_t get_work_amount(const T &v, Args &&...args) +{ return (size_t)v * get_work_amount(utils::forward(args)...); } + +/* parallel_nd and parallel_nd_in_omp section */ + +#if MKLDNN_THR != MKLDNN_THR_TBB +template +void parallel_nd(Args &&...args) { +#if MKLDNN_THR == MKLDNN_THR_SEQ + for_nd(0, 1, utils::forward(args)...); +#elif MKLDNN_THR == MKLDNN_THR_OMP + const bool do_parallel = get_work_amount(utils::forward(args)...) > 1; +# pragma omp parallel if (do_parallel) + { + const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads(); + const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num(); + for_nd(ithr, nthr, utils::forward(args)...); + } +#endif +} +#else // MKLDNN_THR != MKLDNN_THR_TBB + +// gcc 4.8 has a bug with passing parameter pack to lambdas. +// So have to explicitly instantiate all the cases. + +template +void parallel_nd(const T0 &D0, F f) { + const size_t work_amount = (size_t)D0; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(T0(iwork)); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, F f) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1); + utils::nd_iterator_step(d0, D0, d1, D1); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3, d4); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, + d5, D5); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } + }, tbb::static_partitioner()); +} +#endif + +template +void parallel_nd_in_omp(Args &&...args) { +#if MKLDNN_THR == MKLDNN_THR_SEQ + for_nd(0, 1, utils::forward(args)...); +#elif MKLDNN_THR == MKLDNN_THR_OMP + for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(), + utils::forward(args)...); +#elif MKLDNN_THR == MKLDNN_THR_TBB + assert(!"unsupported parallel_nd_in_omp()"); +#endif +} + +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp new file mode 100644 index 0000000000..aa671a0b6e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp @@ -0,0 +1,77 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_TRAITS_HPP +#define MKLDNN_TRAITS_HPP + +#include +#include + +#include "mkldnn.h" +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +template struct prec_traits {}; /* ::type -> float */ +template struct data_traits {}; /* ::data_type -> f32 */ +template struct typesize_traits {}; /* ::data_type_size -> f32 */ +template struct pkind_traits {}; /* ::desc_type, ::query_d */ + +template <> struct prec_traits { typedef float type; }; +template <> struct prec_traits { typedef int32_t type; }; +template <> struct prec_traits { typedef int8_t type; }; +template <> struct prec_traits { typedef uint8_t type; }; + +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::f32; }; +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::s32; }; +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::s8; }; +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::u8; }; + +template <> struct typesize_traits<4> { typedef float type; }; +template <> struct typesize_traits<2> { typedef int16_t type; }; +template <> struct typesize_traits<1> { typedef uint8_t type; }; + +#define PKIND_TRAITS_INST(op) \ +template <> struct pkind_traits { \ + typedef CONCAT2(op, _desc_t) desc_type; \ + static constexpr query_t query_d = query::CONCAT2(op, _d); \ +} +PKIND_TRAITS_INST(convolution); +PKIND_TRAITS_INST(deconvolution); +PKIND_TRAITS_INST(shuffle); +PKIND_TRAITS_INST(eltwise); +PKIND_TRAITS_INST(softmax); +PKIND_TRAITS_INST(pooling); +PKIND_TRAITS_INST(lrn); +PKIND_TRAITS_INST(batch_normalization); +PKIND_TRAITS_INST(inner_product); +PKIND_TRAITS_INST(rnn); +#undef PKIND_TRAITS_INST + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp new file mode 100644 index 0000000000..f89ea999e2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp @@ -0,0 +1,193 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef NSTL_HPP +#define NSTL_HPP + +#include +#include +#include + +#include +#include + +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +void *malloc(size_t size, int alignment); +void free(void *p); + +struct c_compatible { + enum { default_alignment = 64 }; + static void *operator new(size_t sz) { + return malloc(sz, default_alignment); + } + static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; } + static void *operator new[](size_t sz) { + return malloc(sz, default_alignment); + } + static void operator delete(void *p) { free(p); } + static void operator delete[](void *p) { free(p); } +}; + +namespace nstl { + +template +inline const T abs(const T& a) { + return a >= 0 ? a : -a; +} + +template +inline const T& max(const T& a, const T& b) { + return a > b ? a : b; +} + +template +inline const T& min(const T& a, const T& b) { + return a < b ? a : b; +} + +template void swap(T& t1, T& t2) { + T tmp(t1); + t1 = t2; + t2 = tmp; +} + +// Rationale: MKL-DNN needs numeric limits implementation that does not +// generate dependencies on C++ run-time libraries. + +template struct numeric_limits; + +template<> struct numeric_limits { + static constexpr float lowest() { return -FLT_MAX; } + static constexpr float max() { return FLT_MAX; } +}; + +template<> struct numeric_limits { + static constexpr int lowest() { return INT32_MIN; } + static constexpr int max() { return INT32_MAX; } +}; + +template<> struct numeric_limits { + static constexpr int16_t lowest() { return INT16_MIN; } + static constexpr int16_t max() { return INT16_MAX; } +}; + +template<> struct numeric_limits { + static constexpr int8_t lowest() { return INT8_MIN; } + static constexpr int8_t max() { return INT8_MAX; } +}; + +template<> struct numeric_limits { + static constexpr uint8_t lowest() { return 0; } + static constexpr uint8_t max() { return UINT8_MAX; } +}; + +template struct is_integral +{ static constexpr bool value = false; }; +template<> struct is_integral { static constexpr bool value = true; }; +template<> struct is_integral { static constexpr bool value = true; }; +template<> struct is_integral { static constexpr bool value = true; }; +template<> struct is_integral { static constexpr bool value = true; }; + +template struct is_same +{ static constexpr bool value = false; }; +template struct is_same +{ static constexpr bool value = true; }; + +// Rationale: MKL-DNN needs container implementations that do not generate +// dependencies on C++ run-time libraries. +// +// Implementation philosophy: caller is responsible to check if the operation +// is valid. The only functions that have to return status are those that +// depend on memory allocation or similar operations. +// +// This means that e.g. an operator [] does not have to check for boundaries. +// The caller should have checked the boundaries. If it did not we crash and +// burn: this is a bug in MKL-DNN and throwing an exception would not have been +// recoverable. +// +// On the other hand, insert() or resize() or a similar operation needs to +// return a status because the outcome depends on factors external to the +// caller. The situation is probably also not recoverable also, but MKL-DNN +// needs to be nice and report "out of memory" to the users. + +enum nstl_status_t { + success = 0, + out_of_memory +}; + +template class vector: public c_compatible { +private: + std::vector _impl; +public: + typedef typename std::vector::iterator iterator; + typedef typename std::vector::const_iterator const_iterator; + typedef typename std::vector::size_type size_type; + vector() {} + vector(size_type n): _impl(n) {} + vector(size_type n, const T &value): _impl(n, value) {} + template + vector(input_iterator first, input_iterator last): _impl(first, last) {} + ~vector() {} + size_type size() const { return _impl.size(); } + T& operator[] (size_type i) { return _impl[i]; } + const T& operator[] (size_type i) const { return _impl[i]; } + iterator begin() { return _impl.begin(); } + const_iterator begin() const { return _impl.begin(); } + iterator end() { return _impl.end(); } + const_iterator end() const { return _impl.end(); } + template + nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end) + { + _impl.insert(pos, begin, end); + return success; + } + void clear() { _impl.clear(); } + void push_back(const T& t) { _impl.push_back(t); } + void resize(size_type count) { _impl.resize(count); } + void reserve(size_type count) { _impl.reserve(count); } +}; + +template class map: public c_compatible { +private: + std::map _impl; +public: + typedef typename std::map::iterator iterator; + typedef typename std::map::const_iterator const_iterator; + typedef typename std::map::size_type size_type; + map() {} + ~map() {} + size_type size() const { return _impl.size(); } + T& operator[](const Key &k) { return _impl[k]; } + const T& operator[](const Key &k) const { return _impl[k]; } + iterator begin() { return _impl.begin(); } + const_iterator begin() const { return _impl.begin(); } + iterator end() { return _impl.end(); } + const_iterator end() const { return _impl.end(); } + template + void clear() { _impl.clear(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp new file mode 100644 index 0000000000..be96e654ff --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp @@ -0,0 +1,114 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t pooling_desc_init(pooling_desc_t *pool_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t kernel, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l) + && one_of(alg_kind, pooling_max, + pooling_avg_include_padding, + pooling_avg_exclude_padding) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) return invalid_arguments; + + if (padding_r == nullptr) padding_r = padding_l; + + auto pd = pooling_desc_t(); + pd.primitive_kind = primitive_kind::pooling; + pd.prop_kind = prop_kind; + pd.alg_kind = alg_kind; + pd.src_desc.ndims = src_desc->ndims; + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + + pd.diff_src_desc = pd.src_desc = zero_md(); + pd.diff_dst_desc = pd.dst_desc = zero_md(); + + (is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc; + (is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(pd.strides, strides, sp_dims); + utils::array_copy(pd.kernel, kernel, sp_dims); + utils::array_copy(pd.padding[0], padding_l, sp_dims); + utils::array_copy(pd.padding[1], padding_r, sp_dims); + + pd.padding_kind = padding_kind; + if (one_of(alg_kind, pooling_max, pooling_avg_include_padding, + pooling_avg_exclude_padding)) { + pd.accum_data_type = types::default_accum_data_type( + src_desc->data_type, dst_desc->data_type); + } else { + pd.accum_data_type = dst_desc->data_type; + } + + bool consistency = true + && utils::one_of(src_desc->ndims, 4, 5) + && utils::one_of(dst_desc->ndims, 4, 5) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == dst_desc->dims[1]; + for (int i = 2; i < src_desc->ndims; ++i) + consistency = consistency && ( + (src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2] + + padding_r[i - 2]) / strides[i - 2] + 1 + == dst_desc->dims[i]); + if (!consistency) return invalid_arguments; + + *pool_desc = pd; + return success; +} +} + +status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t kernel, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc, + dst_desc, strides, kernel, padding_l, padding_r, padding_kind); +} + +status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc, + alg_kind_t alg_kind, const memory_desc_t *diff_src_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t kernel, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind, + diff_src_desc, diff_dst_desc, strides, kernel, padding_l, + padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp new file mode 100644 index 0000000000..4c9c009412 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp @@ -0,0 +1,238 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef POOLING_PD_HPP +#define POOLING_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +struct pooling_fwd_pd_t; + +struct pooling_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::pooling; + + pooling_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , ws_md_() + {} + + const pooling_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::pooling_d: + *(const pooling_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common pooling aux functions */ + + dim_t MB() const { return src_desc().dims[0]; } + dim_t C() const { return src_desc().dims[1]; } + + dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; } + dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; } + dim_t IW() const { return src_desc().dims[ndims() - 1]; } + + dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; } + dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; } + dim_t OW() const { return dst_desc().dims[ndims() - 1]; } + + dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; } + dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; } + dim_t KW() const { return desc_.kernel[ndims() - 3]; } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t padFront() const + { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const + { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const + { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const + { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + int ndims() const { return src_desc().ndims; } + bool is_3d() const { return ndims() == 5; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(src_desc()).has_zero_dim(); } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + pooling_desc_t desc_; + const pooling_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t ws_md_; + + void init_default_ws() { + ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md(); + ws_md_.data_type = indices_data_type(); + } + + data_type_t indices_data_type() const { + /* the simplest way to express 256... */ + const int u8_max = nstl::numeric_limits< + typename prec_traits::type>::max(); + return utils::array_product(desc()->kernel, ndims()) <= u8_max + ? data_type::u8 : data_type::s32; + } + +private: + const memory_desc_t &src_desc() const + { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; } + const memory_desc_t &dst_desc() const + { return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; } +}; + +struct pooling_fwd_pd_t: public pooling_pd_t { + typedef pooling_fwd_pd_t base_class; + typedef pooling_fwd_pd_t hint_class; + + pooling_fwd_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } + +protected: + memory_desc_t src_md_; + memory_desc_t dst_md_; + + virtual status_t set_default_params() { + if (dst_md()->format_kind != format_kind::any) + return status::success; + + if (src_md()->format_kind != format_kind::blocked) + return status::unimplemented; + + return memory_desc_init_by_blocking_desc(dst_md_, + src_md_.format_desc.blocking); + } +}; + +struct pooling_bwd_pd_t: public pooling_pd_t { + typedef pooling_bwd_pd_t base_class; + typedef pooling_fwd_pd_t hint_class; + + pooling_bwd_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_DIFF_DST) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override + { return 1 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t diff_dst_md_; + + virtual status_t set_default_params() { + if (diff_src_md()->format_kind != format_kind::any) + return status::success; + + if (diff_dst_md()->format_kind != format_kind::blocked) + return status::unimplemented; + + return memory_desc_init_by_blocking_desc(diff_src_md_, + diff_dst_md_.format_desc.blocking); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp new file mode 100644 index 0000000000..fdf6522f62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "primitive.hpp" +#include "type_helpers.hpp" +#include "stream.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::primitive_kind; + +namespace { +// XXX: this is a huge hammer. This disables all and any msan checks on +// primitives outputs. +// +// A proper approach would be an implementation-specific unpoisoning. +void unpoison_outputs(const exec_args_t &args) { + for(const auto &arg: args) { + if (arg.second.is_const) continue; + auto *mem = arg.second.mem; + void *p; + mem->get_data_handle(&p); + size_t s = memory_desc_wrapper(*mem->md()).size(); + msan_unpoison(p, s); + } +} +} + +status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) { + if (primitive_desc) delete primitive_desc; + return success; +} + +status_t mkldnn_primitive_create(primitive_t **primitive, + const primitive_desc_t *primitive_desc) { + if (utils::any_null(primitive, primitive_desc)) + return invalid_arguments; + return primitive_desc->create_primitive(primitive); +} + +status_t mkldnn_primitive_execute(const primitive_t *primitive, + stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) { + bool ok = true + && !utils::any_null(primitive, stream) + && primitive->engine() == stream->engine() + && IMPLICATION(nargs > 0, c_args != nullptr); + if (!ok) return invalid_arguments; + + exec_args_t args; + status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args); + if (status != status::success) return status; + + exec_ctx_t ctx(stream, std::move(args)); + + if (mkldnn_verbose()->level) { + double ms = get_msec(); + status = primitive->execute(ctx); + ms = get_msec() - ms; + printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms); + fflush(0); + } else { + status = primitive->execute(ctx); + } + + if (msan_enabled) unpoison_outputs(ctx.args()); + + return status; +} + +status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive, + const primitive_desc_t **primitive_desc) { + if (utils::any_null(primitive, primitive_desc)) + return invalid_arguments; + return safe_ptr_assign(*primitive_desc, + primitive->pd()); +} + +status_t mkldnn_primitive_destroy(primitive_t *primitive) { + if (primitive != nullptr) + delete primitive; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp new file mode 100644 index 0000000000..3b506d6d1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp @@ -0,0 +1,76 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_HPP +#define PRIMITIVE_HPP + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "primitive_exec_types.hpp" + +/** \brief A pure virtual primitive class + * + * Primitive contains links to its inputs & outputs, though it does not track + * their readiness on execution step. + * + * @remark @b Rational. + * Dependencies are essential through-out the whole MKL-DNN library, so it + * makes sense to include them on the very low level. On the other hand, + * tracking them should be a task for corresponding essence, like scheduler, + * stream or whatever. Primitive itself should know nothing about the + * environment it is running in. + * + * @note + * To make user experience better we should provide API which allows + * achieving the best (or good enough) performance when creating primitives + * in natural order: i.e. from bottom to top for forward pass and from top to + * bottom for backward pass. Please consider restriction [1] in Level 0. + */ +struct mkldnn_primitive: public mkldnn::impl::c_compatible { + mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd) + : pd_(pd->clone()) {} + virtual ~mkldnn_primitive() { delete pd_; } + + /** returns primitive's engine */ + mkldnn::impl::engine_t *engine() const { return pd_->engine(); } + /** returns primitive's inputs */ + const mkldnn::impl::primitive_desc_t *pd() const { return pd_; } + /** returns primitive's kind */ + mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); } + + /** executes primitive with execution context @p ctx */ + virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx) + const = 0; + +protected: + const mkldnn::impl::primitive_desc_t *pd_; + +private: + mkldnn_primitive() = delete; + mkldnn_primitive(const mkldnn_primitive &) = delete; + mkldnn_primitive(mkldnn_primitive &&) = delete; + mkldnn_primitive &operator=(const mkldnn_primitive &) = delete; + mkldnn_primitive &operator=(mkldnn_primitive &&) = delete; +}; + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp new file mode 100644 index 0000000000..9fd638842c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp @@ -0,0 +1,290 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +namespace mkldnn { +namespace impl { + +status_t scales_t::set(dim_t count, int mask, const float *scales) { + cleanup(); + + count_ = count; + mask_ = mask; + + if (count_ == 1) { + scales_ = scales_buf_; + utils::array_set(scales_, scales[0], scales_buf_size); + } else { + scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64); + if (scales_ == nullptr) + return status::out_of_memory; + + for (dim_t c = 0; c < count_; ++c) + scales_[c] = scales[c]; + } + + return status::success; +} + +} +} + +status_t post_ops_t::append_sum(float scale) { + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::sum; + entry_[len_].sum.scale = scale; + + len_++; + + return success; +} + +status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha, + float beta) { + using namespace mkldnn::impl::alg_kind; + bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic); + if (!known_alg) + return invalid_arguments; + + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::eltwise; + entry_[len_].eltwise.scale = scale; + entry_[len_].eltwise.alg = alg; + entry_[len_].eltwise.alpha = alpha; + entry_[len_].eltwise.beta = beta; + + len_++; + + return success; +} + +status_t primitive_attr_t::set_scratchpad_mode( + scratchpad_mode_t scratchpad_mode) { + using namespace mkldnn::impl::scratchpad_mode; + + const bool ok = one_of(scratchpad_mode, library, user); + if (!ok) + return invalid_arguments; + + scratchpad_mode_ = scratchpad_mode; + return success; +} + +status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) { + this->post_ops_ = post_ops; + return success; +} + +/* Public C API */ + +status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) { + if (attr == nullptr) + return invalid_arguments; + + return safe_ptr_assign(*attr, + new mkldnn_primitive_attr); +} + +status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr, + const primitive_attr_t *existing_attr) { + if (any_null(attr, existing_attr)) + return invalid_arguments; + + return safe_ptr_assign(*attr, + existing_attr->clone()); +} + +status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) { + if (attr) + delete attr; + + return success; +} + +status_t mkldnn_primitive_attr_get_scratchpad_mode( + const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) { + if (any_null(attr, scratchpad_mode)) + return invalid_arguments; + + *scratchpad_mode = attr->scratchpad_mode_; + + return success; +} + +status_t mkldnn_primitive_attr_set_scratchpad_mode( + primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) { + if (any_null(attr)) + return invalid_arguments; + + return attr->set_scratchpad_mode(scratchpad_mode); +} + +status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr, + dim_t *count, int *mask, const float **scales) { + if (any_null(attr, count, mask, scales)) + return invalid_arguments; + + *count = attr->output_scales_.count_; + *mask = attr->output_scales_.mask_; + *scales = attr->output_scales_.scales_; + + return success; +} + +status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr, + dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->output_scales_.set(count, mask, scales); +} + +status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr, + const post_ops_t **post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + *post_ops = &attr->post_ops_; + return success; +} + +status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr, + const post_ops_t *post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + return attr->set_post_ops(*post_ops); +} + +status_t mkldnn_post_ops_create(post_ops_t **post_ops) { + if (post_ops == nullptr) + return invalid_arguments; + + return safe_ptr_assign(*post_ops, new mkldnn_post_ops); +} + +status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) { + if (post_ops) + delete post_ops; + + return success; +} + +int mkldnn_post_ops_len(const post_ops_t *post_ops) { + if (post_ops) + return post_ops->len_; + + return 0; +} + +primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops, + int index) { + bool ok = post_ops && 0 <= index && index < post_ops->len_; + if (!ok) + return primitive_kind::undefined; + + return post_ops->entry_[index].kind; +} + +status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_sum(scale); +} + +namespace { +bool simple_get_params_check(const post_ops_t *post_ops, int index, + primitive_kind_t kind) { + bool ok = true + && post_ops != nullptr + && 0 <= index + && index < post_ops->len_ + && post_ops->entry_[index].kind == kind; + return ok; +} +} + +status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index, + float *scale) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::sum) + && !any_null(scale); + if (!ok) + return invalid_arguments; + + *scale = post_ops->entry_[index].sum.scale; + return success; +} + +status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale, + alg_kind_t kind, float alpha, float beta) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_eltwise(scale, kind, alpha, beta); +} + +status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops, + int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::eltwise) + && !any_null(scale, alpha, beta); + if (!ok) + return invalid_arguments; + + const auto &e = post_ops->entry_[index].eltwise; + *scale = e.scale; + *alg = e.alg; + *alpha = e.alpha; + *beta = e.beta; + + return success; +} + +status_t mkldnn_primitive_attr_set_rnn_data_qparams( + primitive_attr_t *attr, const float scale, const float shift) { + if (attr == nullptr) + return invalid_arguments; + + return attr->rnn_data_qparams_.set(scale, shift); +} + +status_t mkldnn_primitive_attr_set_rnn_weights_qparams( + primitive_attr_t *attr, dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->rnn_weights_qparams_.set(count, mask, scales); +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp new file mode 100644 index 0000000000..e2130c7ab1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp @@ -0,0 +1,183 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_ATTR_HPP +#define PRIMITIVE_ATTR_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct rnn_data_qparams_t : public c_compatible { + rnn_data_qparams_t() : scale_(1.), shift_(0.) {} + bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); } + + status_t set(float scale, float shift) { + scale_ = scale; + shift_ = shift; + return status::success; + } + + float scale_; + float shift_; +}; + +struct scales_t: public c_compatible { + scales_t(): count_(1), mask_(0), scales_(scales_buf_) + { set(1.); } + + scales_t(const scales_t &rhs): scales_t() + { set(rhs.count_, rhs.mask_, rhs.scales_); } + + ~scales_t() { cleanup(); } + + scales_t &operator=(const scales_t &rhs) { + if (&rhs == this) + return *this; + status_t status = set(rhs.count_, rhs.mask_, rhs.scales_); + assert(status == status::success); + (void)status; + return *this; + } + + bool has_default_values() const { + for (dim_t c = 0; c < count_; ++c) { + if(scales_[c] != 1.) return false; + } + return true; + } + + status_t set(dim_t count, int mask, const float *scales); + status_t set(float single_scale) { return this->set(1, 0, &single_scale); } + + dim_t count_; + int mask_; + float *scales_; + +private: + enum { scales_buf_size = 16 }; + float scales_buf_[scales_buf_size]; + + void cleanup() { + if (scales_ != scales_buf_ && scales_ != nullptr) + impl::free(scales_); + + count_ = 1; + mask_ = 0; + scales_ = scales_buf_; + } +}; + +} +} + +struct mkldnn_post_ops: public mkldnn::impl::c_compatible { + struct entry_t { + struct eltwise_t { + mkldnn::impl::alg_kind_t alg; + float scale, alpha, beta; + }; + + mkldnn::impl::primitive_kind_t kind; + union { + struct { float scale; } sum; + eltwise_t eltwise; + }; + + bool is_eltwise(bool require_scale_one = true) const { + using namespace mkldnn::impl; + return kind == primitive_kind::eltwise + && IMPLICATION(require_scale_one, eltwise.scale == 1.f); + } + + bool is_relu(bool require_scale_one = true, + bool require_nslope_zero = true) const { + using namespace mkldnn::impl; + return is_eltwise(require_scale_one) + && eltwise.alg == alg_kind::eltwise_relu + && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f); + } + + bool is_sum(bool require_scale_one = true) const { + using namespace mkldnn::impl; + return kind == primitive_kind::sum + && IMPLICATION(require_scale_one, sum.scale == 1.f); + } + }; + + mkldnn_post_ops(): len_(0) {} + + mkldnn::impl::status_t append_sum(float scale); + mkldnn::impl::status_t append_eltwise(float scale, + mkldnn::impl::alg_kind_t alg, float alpha, float beta); + + int find(mkldnn::impl::primitive_kind_t kind, int start = 0, + int stop = -1) const { + if (stop == -1) stop = len_; + stop = mkldnn::impl::nstl::min(stop, len_); + for (int idx = start; idx < stop; ++idx) + if (entry_[idx].kind == kind) return idx; + return -1; + } + + bool has_default_values() const { return len_ == 0; } + + bool contain(mkldnn::impl::primitive_kind_t kind, int index) const + { return find(kind, index, index + 1) == index; } + + enum { capacity = 4 }; + + int len_; + entry_t entry_[capacity]; +}; + +struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible { + mkldnn_primitive_attr() + : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library) + {} + + mkldnn_primitive_attr *clone() const + { return new mkldnn_primitive_attr(*this); } + + /** Returns true if the attributes have default values. + * + * @note The scratchpad_mode_ is not take into account */ + bool has_default_values() const { + return true + && output_scales_.has_default_values() + && post_ops_.has_default_values() + && rnn_data_qparams_.has_default_values() + && rnn_weights_qparams_.has_default_values(); + } + + mkldnn::impl::status_t set_scratchpad_mode( + mkldnn::impl::scratchpad_mode_t scratchpad_mode); + mkldnn::impl::status_t set_post_ops( + const mkldnn::impl::post_ops_t &post_ops); + + mkldnn::impl::scratchpad_mode_t scratchpad_mode_; + mkldnn::impl::scales_t output_scales_; + mkldnn::impl::post_ops_t post_ops_; + mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_; + mkldnn::impl::scales_t rnn_weights_qparams_; +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp new file mode 100644 index 0000000000..723c41e05a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +status_t primitive_desc_t::query(query_t what, int idx, void *result) const { + auto safe_ret_md = [&](const memory_desc_t *_) { + if (_ == nullptr) return not_required; + *(const memory_desc_t **)result = _; + return success; + }; + + switch (what) { + case query::engine: *(engine_t**)result = engine(); break; + case query::primitive_kind: *(primitive_kind_t*)result = kind(); break; + + case query::scratchpad_engine: + *(engine_t**)result = scratchpad_engine(); break; + + case query::memory_consumption_s64: + *(dim_t *)result = scratchpad_size(scratchpad_mode::library); break; + + case query::op_d: + if (idx != 0 || op_desc() == nullptr) return invalid_arguments; + *(const_c_op_desc_t *)result + = static_cast(op_desc()); break; + + case query::src_md: return safe_ret_md(src_md(idx)); + case query::diff_src_md: return safe_ret_md(diff_src_md(idx)); + case query::dst_md: return safe_ret_md(dst_md(idx)); + case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx)); + case query::weights_md: return safe_ret_md(weights_md(idx)); + case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx)); + case query::workspace_md: + if (idx != 0) return status::invalid_arguments; + return safe_ret_md(workspace_md(idx)); + case query::scratchpad_md: + if (idx != 0) return status::invalid_arguments; + return safe_ret_md(scratchpad_md(idx)); + + case query::num_of_inputs_s32: *(int*)result = n_inputs(); break; + case query::num_of_outputs_s32: *(int*)result = n_outputs(); break; + + case query::impl_info_str: *(const char **)result = name(); break; + + default: return unimplemented; + } + return success; +} + +status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc, + const primitive_attr_t **attr) { + if (utils::any_null(primitive_desc, attr)) + return invalid_arguments; + + *attr = primitive_desc->attr(); + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp new file mode 100644 index 0000000000..536dcfa1d0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp @@ -0,0 +1,174 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_DESC_HPP +#define PRIMITIVE_DESC_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "primitive_attr.hpp" +#include "verbose.hpp" + +struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible { + using md_t = mkldnn::impl::memory_desc_t; + + mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::primitive_kind_t kind) + : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; } + + mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, + mkldnn::impl::primitive_kind_t kind) + : engine_(engine), kind_(kind) { info_[0] = '\0'; } + + virtual mkldnn_primitive_desc *clone() const = 0; + virtual ~mkldnn_primitive_desc() {} + + const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; } + mkldnn::impl::engine_t *engine() const { return engine_; } + mkldnn::impl::primitive_kind_t kind() const { return kind_; } + + virtual void init_info() {} + const char *info() const { return info_; } + + mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() + { return scratchpad_registry_; } + const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const + { return scratchpad_registry_; } + virtual mkldnn::impl::engine_t *scratchpad_engine() const + { return engine_; } + + virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; } + + enum class arg_usage_t { unused, input, output }; + virtual arg_usage_t arg_usage( + mkldnn::impl::primitive_arg_index_t arg) const { + using mkldnn::impl::types::is_zero_md; + if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md())) + return arg_usage_t::output; + return arg_usage_t::unused; + } + +# define DECLARE_MD_STUB(stub) \ + virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \ + { return nullptr; } + + DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md); + DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md); + DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md); + DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md); + DECLARE_MD_STUB(workspace_md); +# undef DECLARE_MD_STUB + + const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const { + return idx == 0 ? &scratchpad_md_ : nullptr; + } + + virtual void init_scratchpad_md() { + auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user); + mkldnn::impl::dims_t dims = { size }; + mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims, + mkldnn::impl::data_type::u8, mkldnn_x); + } + + /** returns the scratchpad size for the given scratchpad mode. */ + mkldnn::impl::dim_t scratchpad_size( + mkldnn::impl::scratchpad_mode_t mode) const { + if (mode != attr_.scratchpad_mode_) return 0; + return scratchpad_registry().size(); + } + + virtual int n_inputs() const { return 0; } + virtual int n_outputs() const { return 0; } + + virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx, + void *result) const; + + virtual mkldnn::impl::status_t create_primitive( + mkldnn::impl::primitive_t **primitive) const = 0; + + virtual const char *name() const { return "mkldnn_primitive_desc"; } + + /* static magic */ + + template + static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd, + const mkldnn::impl::op_desc_t *adesc, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_desc_t *hint_fwd) { + using namespace mkldnn::impl; + using namespace mkldnn::impl::status; + using pd_op_desc_t = typename pkind_traits::desc_type; + if (adesc->kind != pd_t::base_pkind) return invalid_arguments; + assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true); + auto hint = + reinterpret_cast(hint_fwd); + auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + _pd->init_info(); + _pd->init_scratchpad_md(); + *pd = _pd; + return success; + } + +protected: + mkldnn::impl::engine_t *engine_; + mkldnn::impl::primitive_attr_t attr_; + mkldnn::impl::primitive_kind_t kind_; + + mkldnn::impl::memory_desc_t scratchpad_md_; + + char info_[MKLDNN_VERBOSE_BUF_LEN]; + + mkldnn::impl::memory_tracking::registry_t scratchpad_registry_; + +protected: + /** compares ws between fwd_pd and this (make sense to use for bwd_pd) + * Expectation: this already set workspace, and this workspace should + * exactly match the one from fwd_pd */ + bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const { + using namespace mkldnn::impl; + if (!workspace_md()) return true; // the impl lives fine w/o workspace + return fwd_pd && fwd_pd->workspace_md() + && *fwd_pd->workspace_md() == *workspace_md(); + } +}; + +#define DECLARE_COMMON_PD_t(impl_name, ...) \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual const char *name() const override { return impl_name; } +#define DECLARE_COMMON_PD_T(impl_name, ...) \ + DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__) + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp new file mode 100644 index 0000000000..43e5a31ef3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "memory.hpp" +#include "primitive.hpp" +#include "primitive_exec_types.hpp" + +namespace mkldnn { +namespace impl { + +status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs, + const mkldnn_exec_arg_t *c_args, exec_args_t &args) { + using namespace status; + + if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments; + + int n_inputs = 0; + int n_outputs = 0; + + for (int i = 0; i < nargs; ++i) { + primitive_arg_index_t arg = c_args[i].arg; + auto *mem = c_args[i].memory; + + switch (pd->arg_usage(arg)) { + case primitive_desc_t::arg_usage_t::input: + if (args.count(arg) != 0) return invalid_arguments; + args[arg] = {mem, true}; + n_inputs++; + break; + case primitive_desc_t::arg_usage_t::output: + if (args.count(arg) != 0) return invalid_arguments; + args[arg] = {mem, false}; + n_outputs++; + break; + case primitive_desc_t::arg_usage_t::unused: + break; + } + } + + bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md()); + + if (n_inputs != pd->n_inputs()) return invalid_arguments; + if (n_outputs != pd->n_outputs() + (scratchpad_required ? 1 : 0)) + return invalid_arguments; + + return success; +} + +const void *exec_ctx_t::input(primitive_arg_index_t arg) const { + if (args_.count(arg) != 1) return nullptr; + const auto ma = args_.at(arg); + assert(ma.is_const); + void *ptr; + status_t status = ma.mem->get_data_handle(&ptr); + assert(status == status::success); MAYBE_UNUSED(status); + return ptr; +} + +void *exec_ctx_t::output(primitive_arg_index_t arg) const { + if (args_.count(arg) != 1) return nullptr; + const auto ma = args_.at(arg); + assert(!ma.is_const); + void *ptr; + status_t status = ma.mem->get_data_handle(&ptr); + assert(status == status::success); MAYBE_UNUSED(status); + return ptr; +} + +const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const { + assert(args_.count(arg) == 1); + const auto ma = args_.at(arg); + assert(!ma.is_const); + return ma.mem; +} + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp new file mode 100644 index 0000000000..0645891da7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_EXEC_TYPES_HPP +#define PRIMITIVE_EXEC_TYPES_HPP + +#include + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "memory.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct memory_arg_t { + memory_t *mem; + bool is_const; +}; + +using exec_args_t = std::unordered_map; + +status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs, + const mkldnn_exec_arg_t *c_args, exec_args_t &args); + +/** Primitive execution context (helps passing stream, memories, and events. */ +struct exec_ctx_t { + exec_ctx_t(const exec_ctx_t &) = default; + exec_ctx_t(exec_ctx_t &&) = default; + + exec_ctx_t(stream_t *stream): stream_(stream) {} + exec_ctx_t(stream_t *stream, exec_args_t &&args) + : stream_(stream) + , args_(std::move(args)) {} + + stream_t *stream() const { return stream_; } + const exec_args_t &args() const { return args_; } + + /* tentative solution... TODO: replace with functions return memory_t */ + const void *input(primitive_arg_index_t arg) const; + void *output(primitive_arg_index_t arg) const; + + const memory_t *memory(primitive_arg_index_t arg) const; + +private: + stream_t *stream_; + exec_args_t args_; +}; + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp new file mode 100644 index 0000000000..5a1cd7d379 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "primitive_iterator.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +status_t mkldnn_primitive_desc_iterator_create( + primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc, + const primitive_attr_t *attr, engine_t *engine, + const primitive_desc_t *hint_fwd_pd) { + const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; + + auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd); + if (it == nullptr) return out_of_memory; + + ++(*it); + if (*it == it->end()) { + delete it; + return unimplemented; + } + + *iterator = it; + return success; +} + +status_t mkldnn_primitive_desc_iterator_next( + primitive_desc_iterator_t *iterator) { + if (iterator == nullptr) return invalid_arguments; + ++(*iterator); + return *iterator == iterator->end() ? iterator_ends : success; +} + +primitive_desc_t *mkldnn_primitive_desc_iterator_fetch( + const primitive_desc_iterator_t *iterator) { + if (iterator == nullptr) return nullptr; + return *(*iterator); +} + +status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc, + const primitive_desc_t *existing_primitive_desc) { + if (utils::any_null(primitive_desc, existing_primitive_desc)) + return invalid_arguments; + return safe_ptr_assign(*primitive_desc, + existing_primitive_desc->clone()); +} + +status_t mkldnn_primitive_desc_iterator_destroy( + primitive_desc_iterator_t *iterator) { + if (iterator != nullptr) + delete iterator; + return success; +} + +status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc, + const_c_op_desc_t c_op_desc, const primitive_attr_t *attr, + engine_t *engine, const primitive_desc_t *hint_fwd_pd) { + const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; + + mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd); + ++it; + if (it == it.end()) return unimplemented; + + return safe_ptr_assign(*primitive_desc, *it); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp new file mode 100644 index 0000000000..4e88ab3aa5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp @@ -0,0 +1,79 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#ifndef PRIMITIVE_ITERATOR_HPP +#define PRIMITIVE_ITERATOR_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible { + using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; + + mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc, + const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd) + : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc) + , attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd) + , impl_list_(engine_->get_implementation_list()), last_idx_(0) + { + while (impl_list_[last_idx_] != nullptr) ++last_idx_; + } + ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; } + + bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const + { return idx_ == rhs.idx_ && engine_ == rhs.engine_; } + bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const + { return !operator==(rhs); } + + mkldnn::impl::primitive_desc_iterator_t end() const + { return mkldnn_primitive_desc_iterator(engine_, last_idx_); } + + mkldnn::impl::primitive_desc_iterator_t &operator++() { + if (pd_) { delete pd_; pd_ = nullptr; } + while (++idx_ != last_idx_) { + auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_, + hint_fwd_pd_); + if (s == mkldnn::impl::status::success) break; + } + return *this; + } + + mkldnn::impl::primitive_desc_t *operator*() const { + if (*this == end() || pd_ == nullptr) return nullptr; + return pd_->clone(); + } + +protected: + int idx_; + mkldnn::impl::engine_t *engine_; + mkldnn::impl::primitive_desc_t *pd_; + const mkldnn::impl::op_desc_t *op_desc_; + const mkldnn::impl::primitive_attr_t attr_; + const mkldnn::impl::primitive_desc_t *hint_fwd_pd_; + const pd_create_f *impl_list_; + int last_idx_; + +private: + mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx) + : idx_(last_idx), engine_(engine), pd_(nullptr) + , op_desc_(nullptr), hint_fwd_pd_(nullptr) + , impl_list_(nullptr), last_idx_(last_idx) {} +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/query.cpp b/thirdparty/oidn/mkl-dnn/src/common/query.cpp new file mode 100644 index 0000000000..835cd73581 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/query.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc, + query_t what, int index, void *result) { + if (any_null(primitive_desc, result)) + return invalid_arguments; + + return primitive_desc->query(what, index, result); +} + +const memory_desc_t *mkldnn_primitive_desc_query_md( + const primitive_desc_t *primitive_desc, query_t what, int index) { + const memory_desc_t *res_md = nullptr; + bool args_ok = true + && primitive_desc != nullptr + && (what & query::some_md) == query::some_md + && what != query::some_md + && mkldnn_primitive_desc_query(primitive_desc, + what, index, &res_md) == success; + return args_ok ? res_md : nullptr; +} + +int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc, + query_t what, int index) { + int res_s32; + bool args_ok = primitive_desc != nullptr + && one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32) + && mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32) + == success; + return args_ok ? res_s32 : 0; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp new file mode 100644 index 0000000000..d11f1a0361 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "reorder_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_reorder_primitive_desc_create( + primitive_desc_t **reorder_pd, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md, + const primitive_attr_t *attr) { + if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md)) + return invalid_arguments; + + auto s_ek = src_engine->kind(); + auto d_ek = dst_engine->kind(); + if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek))) + return invalid_arguments; + + auto r_pd = reinterpret_cast(reorder_pd); + auto s_mdw = memory_desc_wrapper(*src_md); + auto d_mdw = memory_desc_wrapper(*dst_md); + + if (!s_mdw.consistent_with(d_mdw)) + return invalid_arguments; + + auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + for (auto r = e->get_reorder_implementation_list(); *r; ++r) { + if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md) + == success) { + (*r_pd)->init_info(); + (*r_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp new file mode 100644 index 0000000000..963cb0f58a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp @@ -0,0 +1,85 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REORDER_PD_HPP +#define REORDER_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct reorder_pd_t: public primitive_desc_t { + reorder_pd_t(engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) + : primitive_desc_t(engine, attr, primitive_kind::reorder) + , src_engine_(src_engine) + , dst_engine_(dst_engine) + , scratchpad_engine_(nullptr) + , src_md_(*src_md) + , dst_md_(*dst_md) + {} + + virtual const op_desc_t *op_desc() const override { return nullptr; } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_FROM) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_TO) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + float alpha() const { return attr()->output_scales_.scales_[0]; } + float beta() const { + const int sum_idx = attr()->post_ops_.find(primitive_kind::sum); + return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale; + } + virtual mkldnn::impl::engine_t *scratchpad_engine() const override + { return scratchpad_engine_; } + +protected: + engine_t *src_engine_; + engine_t *dst_engine_; + engine_t *scratchpad_engine_; + + memory_desc_t src_md_; + memory_desc_t dst_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp new file mode 100644 index 0000000000..36967431a6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp @@ -0,0 +1,400 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu/gemm/os_blas.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::types; +using namespace mkldnn::impl::utils; + +namespace { +memory_desc_t copy_maybe_null(const memory_desc_t *md) { + return md ? *md : zero_md(); +} + +rnn_desc_t zero_rnn_desc() { + auto rd = rnn_desc_t(); + rd.src_layer_desc = zero_md(); + rd.src_iter_desc = zero_md(); + rd.weights_layer_desc = zero_md(); + rd.weights_iter_desc = zero_md(); + rd.bias_desc = zero_md(); + rd.dst_layer_desc = zero_md(); + rd.dst_iter_desc = zero_md(); + rd.diff_src_layer_desc = zero_md(); + rd.diff_src_iter_desc = zero_md(); + rd.diff_weights_layer_desc = zero_md(); + rd.diff_weights_iter_desc = zero_md(); + rd.diff_bias_desc = zero_md(); + rd.diff_dst_layer_desc = zero_md(); + rd.diff_dst_iter_desc = zero_md(); + return rd; +} +} + +/* Public C Api */ + +status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc, + mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f, + unsigned int flags, float alpha, float clipping) { + using namespace mkldnn::impl::alg_kind; + + bool args_ok = true + && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru, + gru_linear_before_reset) + && IMPLICATION(cell_kind == vanilla_rnn, + one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic)); + if (!args_ok) + return invalid_arguments; + + auto rcd = mkldnn_rnn_cell_desc_t(); + + rcd.cell_kind = cell_kind; + rcd.activation_kind = act_f; + rcd.flags = flags; + rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0; + rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0; + + *rnn_cell_desc = rcd; + + return success; +} + +int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 3; + case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3; + case mkldnn::impl::alg_kind::vanilla_lstm: return 4; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 1; + case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1; + case mkldnn::impl::alg_kind::vanilla_lstm: return 2; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc, + prop_kind_t prop_kind, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + using namespace data_type; + data_type_t src_layer_dt = src_layer_desc->data_type; + data_type_t dst_layer_dt = dst_layer_desc->data_type; + data_type_t weights_iter_dt = weights_iter_desc->data_type; + data_type_t weights_layer_dt = weights_layer_desc->data_type; + + bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt, + weights_layer_dt) + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + +#if USE_MKL_PACKED_GEMM + bool is_u8u8u8 = src_layer_dt == u8 + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == u8) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == u8) + && one_of(dst_layer_dt, u8, f32) + && everyone_is(s8, weights_iter_dt, weights_layer_dt) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + + bool is_f32u8f32 = src_layer_dt == u8 + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == f32) + && one_of(dst_layer_dt, u8, f32) + && everyone_is(s8, weights_iter_dt, weights_layer_dt) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + + bool is_inference = prop_kind == prop_kind::forward_inference; + bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm; + + return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference)) + ? success + : unimplemented; +#else + return is_f32 ? success : unimplemented; +#endif +} + +status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc, + rnn_direction_t direction, int L, int D, int T, int N, int S, int G, + int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + bool args_ok; + + // * algorithm specific + args_ok = true + && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru, + DIC == SIC); + if (!args_ok) return invalid_arguments; + int extra_bias = + rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset; + + // * on num layers + args_ok = true + && L == weights_layer_desc->dims[0] + && L == weights_iter_desc->dims[0] + && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0]) + && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0]) + && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]); + if (!args_ok) return invalid_arguments; + + // * on num directions + args_ok = true + && D == weights_layer_desc->dims[1] + && D == weights_iter_desc->dims[1] + && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1]) + && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1]) + && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]); + if (!args_ok) return invalid_arguments; + + // * on num iterations + args_ok = true + && T == src_layer_desc->dims[0] + && T == dst_layer_desc->dims[0]; + if (!args_ok) return invalid_arguments; + + // * on mb + args_ok = true + && N == src_layer_desc->dims[1] + && N == dst_layer_desc->dims[1] + && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3]) + && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]); + if (!args_ok) return invalid_arguments; + + // * on num gates + args_ok = true + && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc) + && G == weights_layer_desc->dims[3] + && G == weights_iter_desc->dims[3] + && IMPLICATION(!is_zero_md(bias_desc), + G + extra_bias == bias_desc->dims[2]); + if (!args_ok) return invalid_arguments; + + // * on num states + args_ok = true + && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc) + && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2]) + && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]); + if (!args_ok) return invalid_arguments; + + // * on slc + args_ok = true + && SLC == weights_layer_desc->dims[2] + && SLC == src_layer_desc->dims[2]; + if (!args_ok) return invalid_arguments; + + // * on sic + args_ok = true + && SIC == weights_iter_desc->dims[2] + && IMPLICATION(!is_zero_md(src_iter_desc), + SIC == src_iter_desc->dims[4]); + if (!args_ok) return invalid_arguments; + + // * on dlc + int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1; + args_ok = true + && DLC == dlc_multiplier * DIC + && DLC == dst_layer_desc->dims[2]; + if (!args_ok) return invalid_arguments; + + // * on dic + args_ok = true + && DIC == weights_layer_desc->dims[4] + && DIC == weights_iter_desc->dims[4] + && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3]) + && IMPLICATION(!is_zero_md(dst_iter_desc), + DIC == dst_iter_desc->dims[4]); + if (!args_ok) return invalid_arguments; + + // * unrolling/fusion conditions + args_ok = true + && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC) + && IMPLICATION(T > 1, SIC == DIC); + if (!args_ok) return invalid_arguments; + + return success; +} + +status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + bool args_ok = true && rnn_cell_desc != nullptr + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc); + if (!args_ok) return invalid_arguments; + + //check dimensions consistency + int L = weights_layer_desc->dims[0]; + int T = src_layer_desc->dims[0]; + int N = src_layer_desc->dims[1]; + const int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); + int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); + int SLC = src_layer_desc->dims[2]; + int SIC = weights_iter_desc->dims[2]; + int DLC = dst_layer_desc->dims[2]; + int DIC = weights_layer_desc->dims[4]; + + CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, + weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, + dst_iter_desc)); + + CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind, + src_layer_desc, src_iter_desc, weights_layer_desc, + weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc)); + + // Create the descriptor + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + + *rnn_desc = rd; + + return success; +} + +status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, + const memory_desc_t *diff_src_layer_desc, + const memory_desc_t *diff_src_iter_desc, + const memory_desc_t *diff_weights_layer_desc, + const memory_desc_t *diff_weights_iter_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_layer_desc, + const memory_desc_t *diff_dst_iter_desc) { + bool args_ok = true + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc, diff_src_layer_desc, + diff_weights_layer_desc, diff_weights_iter_desc, + diff_dst_layer_desc); + if (!args_ok) + return invalid_arguments; + + auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) { + return is_zero_md(a_md) == is_zero_md(b_md); + }; + + args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc) + && xnor_md(dst_iter_desc, diff_dst_iter_desc) + && xnor_md(src_iter_desc, diff_src_iter_desc); + if (!args_ok) + return invalid_arguments; + + //check dimensions consistency + int L = weights_layer_desc->dims[0]; + int T = src_layer_desc->dims[0]; + int N = src_layer_desc->dims[1]; + const int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); + int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); + int SLC = src_layer_desc->dims[2]; + int SIC = weights_iter_desc->dims[2]; + int DLC = dst_layer_desc->dims[2]; + int DIC = weights_layer_desc->dims[4]; + + status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, + weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, + dst_iter_desc); + if (st != success) return st; + + st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc, + diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc, + diff_dst_layer_desc, diff_dst_iter_desc); + if (st != success) return st; + + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc); + rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc); + rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc); + rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc); + rd.diff_bias_desc = copy_maybe_null(diff_bias_desc); + rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc); + rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc); + + *rnn_desc = rd; + + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp new file mode 100644 index 0000000000..1ee2ba1114 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef RNN_PD_HPP +#define RNN_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +struct rnn_fwd_pd_t; + +struct rnn_pd_t : public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::rnn; + + rnn_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , src_layer_md_(desc_.src_layer_desc) + , src_iter_md_(desc_.src_iter_desc) + , weights_layer_md_(desc_.weights_layer_desc) + , weights_iter_md_(desc_.weights_iter_desc) + , bias_md_(desc_.bias_desc) + , dst_layer_md_(desc_.dst_layer_desc) + , dst_iter_md_(desc_.dst_iter_desc) + , ws_md_() + {} + + const rnn_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::rnn_d: *(const rnn_desc_t **)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + virtual const memory_desc_t *src_md(int index = 0) const override { + if (index == 0) return &src_layer_md_; + if (index == 1 && with_src_iter()) return &src_iter_md_; + return nullptr; + } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_layer_md_; + if (index == 1) return &weights_iter_md_; + if (index == 2 && with_bias()) return &bias_md_; + return nullptr; + } + virtual const memory_desc_t *dst_md(int index = 0) const override { + if (index == 0) return &dst_layer_md_; + if (index == 1 && with_dst_iter()) return &dst_iter_md_; + return nullptr; + } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + /* common pooling aux functions */ + + bool is_training() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::backward); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + dim_t T() const { return desc_.src_layer_desc.dims[0]; } + dim_t MB() const { return desc_.src_layer_desc.dims[1]; } + + dim_t L() const { return desc_.weights_layer_desc.dims[0]; } + dim_t D() const { return desc_.weights_layer_desc.dims[1]; } + + dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; } + + dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; } + dim_t G() const { return desc_.weights_layer_desc.dims[3]; } + dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; } + + dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; } + + bool with_bias() const + { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); } + + bool with_src_iter() const + { return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); } + + bool with_dst_iter() const + { return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); } + + mkldnn::impl::alg_kind_t cell_kind() const + { return desc_.cell_desc.cell_kind; } + mkldnn::impl::alg_kind_t activation_kind() const + { return desc_.cell_desc.activation_kind; } + + bool is_lbr() const + { return cell_kind() == mkldnn_gru_linear_before_reset; } + + mkldnn_rnn_direction_t direction() const { return desc_.direction; } + +protected: + rnn_desc_t desc_; + const rnn_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t src_layer_md_; + memory_desc_t src_iter_md_; + memory_desc_t weights_layer_md_; + memory_desc_t weights_iter_md_; + memory_desc_t bias_md_; + memory_desc_t dst_layer_md_; + memory_desc_t dst_iter_md_; + + memory_desc_t ws_md_; +}; + +struct rnn_fwd_pd_t: public rnn_pd_t { + typedef rnn_fwd_pd_t base_class; + typedef rnn_fwd_pd_t hint_class; + + rnn_fwd_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC_LAYER) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter()) + return arg_usage_t::input; + + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, + MKLDNN_ARG_WEIGHTS_ITER)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST_LAYER) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter()) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && is_training()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual int n_inputs() const override + { return 3 + with_bias() + with_src_iter(); } + virtual int n_outputs() const override + { return 1 + with_dst_iter() + is_training(); } +}; + +struct rnn_bwd_pd_t : public rnn_pd_t { + typedef rnn_bwd_pd_t base_class; + typedef rnn_fwd_pd_t hint_class; + + rnn_bwd_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_layer_md_(desc_.diff_src_layer_desc) + , diff_src_iter_md_(desc_.diff_src_iter_desc) + , diff_weights_layer_md_(desc_.diff_weights_layer_desc) + , diff_weights_iter_md_(desc_.diff_weights_iter_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_layer_md_(desc_.diff_dst_layer_desc) + , diff_dst_iter_md_(desc_.diff_dst_iter_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER, + MKLDNN_ARG_DIFF_DST_LAYER)) + return arg_usage_t::input; + + if (with_src_iter()) { + if (arg == MKLDNN_ARG_SRC_ITER) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC_ITER) + return arg_usage_t::output; + } + + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, + MKLDNN_ARG_WEIGHTS_ITER)) + return arg_usage_t::input; + + if (with_bias()) { + if (arg == MKLDNN_ARG_BIAS) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_BIAS) + return arg_usage_t::output; + } + + if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER) + && with_dst_iter()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE) + return arg_usage_t::input; + + if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER, + MKLDNN_ARG_DIFF_WEIGHTS_LAYER, + MKLDNN_ARG_DIFF_WEIGHTS_ITER)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override { + if (index == 0) return &diff_src_layer_md_; + if (index == 1 && with_src_iter()) return &diff_src_iter_md_; + return nullptr; + } + virtual const memory_desc_t *diff_weights_md( + int index = 0) const override { + if (index == 0) return &diff_weights_layer_md_; + if (index == 1) return &diff_weights_iter_md_; + if (index == 2 && with_bias()) return &diff_bias_md_; + return nullptr; + } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override { + if (index == 0) return &diff_dst_layer_md_; + if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_; + return nullptr; + } + + virtual int n_inputs() const override + { return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); } + virtual int n_outputs() const override + { return 3 + with_src_iter() + with_bias(); } + +protected: + memory_desc_t diff_src_layer_md_; + memory_desc_t diff_src_iter_md_; + memory_desc_t diff_weights_layer_md_; + memory_desc_t diff_weights_iter_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_layer_md_; + memory_desc_t diff_dst_iter_md_; +}; + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp new file mode 100644 index 0000000000..6bc14fc72a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "scratchpad.hpp" + +namespace mkldnn { +namespace impl { + +/* Allocating memory buffers on a page boundary to reduce TLB/page misses */ +const size_t page_size = 2097152; + +/* + Implementation of the scratchpad_t interface that is compatible with + a concurrent execution +*/ +struct concurent_scratchpad_t : public scratchpad_t { + concurent_scratchpad_t(size_t size) { + size_ = size; + scratchpad_ = (char *) malloc(size, page_size); + assert(scratchpad_ != nullptr); + } + + ~concurent_scratchpad_t() { + free(scratchpad_); + } + + virtual char *get() const { + return scratchpad_; + } + +private: + char *scratchpad_; + size_t size_; +}; + +/* + Implementation of the scratchpad_t interface that uses a global + scratchpad +*/ + +struct global_scratchpad_t : public scratchpad_t { + global_scratchpad_t(size_t size) { + if (size > size_) { + if (scratchpad_ != nullptr) free(scratchpad_); + size_ = size; + scratchpad_ = (char *) malloc(size, page_size); + assert(scratchpad_ != nullptr); + } + reference_count_++; + } + + ~global_scratchpad_t() { + reference_count_--; + if (reference_count_ == 0) { + free(scratchpad_); + scratchpad_ = nullptr; + size_ = 0; + } + } + + virtual char *get() const { + return scratchpad_; + } + +private: + /* + Using thread-local here is unnecessary and even buggy! All threads + actually share the same scratchpad, which is created and queried only + on the main thread. If the scratchpad is queried on some thread other + than the one it was created on (e.g. the application calls the API from + multiple threads), thread-local causes a segfault because the scratchpad + is uninitialized on the current thread. + */ + /*thread_local*/ static char *scratchpad_; + /*thread_local*/ static size_t size_; + /*thread_local*/ static unsigned int reference_count_; +}; + +/*thread_local*/ char *global_scratchpad_t::scratchpad_ = nullptr; +/*thread_local*/ size_t global_scratchpad_t::size_ = 0; +/*thread_local*/ unsigned int global_scratchpad_t::reference_count_ = 0; + + +/* + Scratchpad creation routine +*/ +scratchpad_t *create_scratchpad(size_t size) { +#ifndef MKLDNN_ENABLE_CONCURRENT_EXEC + return new global_scratchpad_t(size); +#else + return new concurent_scratchpad_t(size); +#endif +} + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp new file mode 100644 index 0000000000..f7a246bc99 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_SCRATCHPAD_HPP +#define COMMON_SCRATCHPAD_HPP + +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct scratchpad_t { + virtual ~scratchpad_t() {} + virtual char *get() const = 0; +}; + +scratchpad_t *create_scratchpad(size_t size); + +} +} +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp new file mode 100644 index 0000000000..e32e735224 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp @@ -0,0 +1,72 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t shuffle_desc_init(shuffle_desc_t *shuffle_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, int axis, dim_t group_size) { + bool args_ok = true + && !any_null(shuffle_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward, backward_data) + && axis >= 0 && axis < data_desc->ndims + && group_size > 0 && group_size <= data_desc->dims[axis]; + if (!args_ok) return invalid_arguments; + + auto sd = shuffle_desc_t(); + sd.primitive_kind = primitive_kind::shuffle; + sd.prop_kind = prop_kind; + sd.data_desc = *data_desc; + sd.axis = axis; + sd.group_size = group_size; + + bool consistency = true + && sd.data_desc.dims[axis] % sd.group_size == 0; + if (!consistency) return invalid_arguments; + + *shuffle_desc = sd; + return success; +} +} + +status_t mkldnn_shuffle_forward_desc_init(shuffle_desc_t *shuffle_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, int axis, + dim_t group_size) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return shuffle_desc_init(shuffle_desc, prop_kind, data_desc, axis, + group_size); +} + +status_t mkldnn_shuffle_backward_desc_init(shuffle_desc_t *shuffle_desc, + const memory_desc_t *diff_data_desc, int axis, dim_t group_size) { + return shuffle_desc_init(shuffle_desc, backward_data, diff_data_desc, axis, + group_size); +} + +// vim: et ts=5 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp new file mode 100644 index 0000000000..cc5553fe7f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp @@ -0,0 +1,121 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SHUFFLE_PD_HPP +#define SHUFFLE_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct shuffle_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::shuffle; + + typedef shuffle_pd_t base_class; + typedef shuffle_pd_t hint_class; + + shuffle_pd_t(engine_t *engine, + const shuffle_desc_t *adesc, + const primitive_attr_t *attr, + const shuffle_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const shuffle_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::shuffle_d: + *(const shuffle_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (is_fwd()) { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + } else { + if (arg == MKLDNN_ARG_DIFF_DST) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + } + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 && is_fwd() ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 && is_fwd() ? &data_md_ : nullptr; } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + /* shuffle aux functions */ + + dim_t MB() const { return data_md()->dims[0]; } + dim_t C() const { return ndims() >= 2 ? data_md()->dims[1] : 1; } + dim_t D() const { return ndims() >= 5 ? data_md()->dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_md()->dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_md()->dims[ndims() - 1] : 1; } + + int ndims() const { return data_md()->ndims; } + + int axis() const { return desc_.axis; } + dim_t group_size() const { return desc_.group_size; } + dim_t axis_size() const { return data_md()->dims[axis()]; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + const memory_desc_t *data_md() const { return &data_md_; } + +protected: + shuffle_desc_t desc_; + const shuffle_pd_t *hint_fwd_pd_; + memory_desc_t data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp new file mode 100644 index 0000000000..82848e3d1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t softmax_desc_init(softmax_desc_t *softmax_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, const memory_desc_t *diff_desc, int softmax_axis) { + bool args_ok = true + && !any_null(softmax_desc, data_desc) + && 0 <= softmax_axis + && softmax_axis < data_desc->ndims; + if (!args_ok) return invalid_arguments; + + auto sd = softmax_desc_t(); + sd.primitive_kind = primitive_kind::softmax; + sd.prop_kind = prop_kind; + + bool is_bwd = (sd.prop_kind == backward_data); + sd.data_desc = *data_desc; + sd.diff_desc = is_bwd ? *diff_desc : zero_md(); + sd.softmax_axis = softmax_axis; + + *softmax_desc = sd; + return success; +} +} + +status_t mkldnn_softmax_forward_desc_init(softmax_desc_t *softmax_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, + int softmax_axis) { + if (!one_of(prop_kind, forward_inference, forward_training)) + return invalid_arguments; + return softmax_desc_init(softmax_desc, prop_kind, data_desc, nullptr, softmax_axis); +} + +status_t mkldnn_softmax_backward_desc_init(softmax_desc_t *softmax_desc, + const memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, + int softmax_axis) { + return softmax_desc_init(softmax_desc, prop_kind::backward_data, + data_desc, diff_desc, softmax_axis); +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp new file mode 100644 index 0000000000..8a16ce901c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp @@ -0,0 +1,161 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SOFTMAX_PD_HPP +#define SOFTMAX_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct softmax_fwd_pd_t; + +struct softmax_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::softmax; + + softmax_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const softmax_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::softmax_d: + *(const softmax_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common softmax aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + softmax_desc_t desc_; + const softmax_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct softmax_fwd_pd_t: public softmax_pd_t { + typedef softmax_fwd_pd_t base_class; + typedef softmax_fwd_pd_t hint_class; + + softmax_fwd_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } +}; + +struct softmax_bwd_pd_t: public softmax_pd_t { + typedef softmax_bwd_pd_t base_class; + typedef softmax_fwd_pd_t hint_class; + + softmax_bwd_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_DST, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual int n_inputs() const override + { return 2 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.cpp b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp new file mode 100644 index 0000000000..00af8935c0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "stream.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +/* API */ + +status_t mkldnn_stream_create(stream_t **stream, engine_t *engine, + unsigned flags) { + bool args_ok = true + && !utils::any_null(stream, engine) + && flags == stream_flags::default_flags; + if (!args_ok) + return invalid_arguments; + + return safe_ptr_assign(*stream, new stream_t(engine, flags)); +} + +status_t mkldnn_stream_destroy(stream_t *stream) { + delete stream; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.hpp b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp new file mode 100644 index 0000000000..f010e5f6ed --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp @@ -0,0 +1,44 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef STREAM_HPP +#define STREAM_HPP + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" + +struct mkldnn_stream: public mkldnn::impl::c_compatible { + mkldnn_stream(mkldnn::impl::engine_t *engine, unsigned flags) + : engine_(engine), flags_(flags) {} + virtual ~mkldnn_stream() {} + + /** returns stream's engine */ + mkldnn::impl::engine_t *engine() const { return engine_; } + + /** returns stream's kind */ + unsigned flags() const { return flags_; } + +protected: + mkldnn::impl::engine_t *engine_; + unsigned flags_; +}; + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum.cpp b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp new file mode 100644 index 0000000000..365663c0f8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp @@ -0,0 +1,79 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "sum_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_sum_primitive_desc_create(primitive_desc_t **sum_pd, + const memory_desc_t *dst_md, int n, const float *scales, + const memory_desc_t *src_mds, const primitive_attr_t *attr, + engine_t *engine) { + bool args_ok = !any_null(sum_pd, src_mds, scales) && n > 0; + if (!args_ok) return invalid_arguments; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + const int ndims = src_mds[0].ndims; + const dims_t &dims = src_mds[0].dims; + const data_type_t dt = src_mds[0].data_type; + + for (int i = 1; i < n; ++i) { + if (src_mds[i].ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (src_mds[i].dims[d] != dims[d]) + return invalid_arguments; + } + if (src_mds[i].data_type != dt) return invalid_arguments; + } + + memory_desc_t dummy_dst_md; + if (dst_md) { + if (dst_md->ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (dst_md->dims[d] != dims[d]) + return invalid_arguments; + } + } else { + dummy_dst_md = src_mds[0]; + dummy_dst_md.format_kind = format_kind::any; + dst_md = &dummy_dst_md; + } + + auto s_pd = reinterpret_cast(sum_pd); + + for (auto s = engine->get_sum_implementation_list(); *s; ++s) { + if ((*s)(s_pd, engine, attr, dst_md, n, scales, src_mds) == success) { + (*s_pd)->init_info(); + (*s_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp new file mode 100644 index 0000000000..80254667df --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp @@ -0,0 +1,143 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SUM_PD_HPP +#define SUM_PD_HPP + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct sum_pd_t: public primitive_desc_t { + sum_pd_t(engine_t *engine, const primitive_attr_t *attr, + const memory_desc_t *dst_md, int n, const float *scales, + const memory_desc_t *src_mds) + : primitive_desc_t(engine, attr, primitive_kind::sum) + , n_(n), dst_md_(*dst_md) + { + scales_.reserve(n_); + for (int i = 0; i < n_; ++i) scales_.push_back(scales[i]); + src_mds_.reserve(n_); + for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); + } + + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg >= MKLDNN_ARG_MULTIPLE_SRC + && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index < n_inputs() ? &src_mds_[index] : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return n_; } + virtual int n_outputs() const override { return 1; } + + const float *scales() const { return &scales_[0]; } + +protected: + int n_; + nstl::vector scales_; + memory_desc_t dst_md_; + nstl::vector src_mds_; + +protected: + /* inits dst_md_ in simple cases. The call may fail. */ + status_t init() { + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(&src_mds_[i]); + if (!src_d.is_blocking_desc() || src_d.is_additional_buffer()) + return status::unimplemented; + } + bool ok = true + && set_default_params() == status::success + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + status_t set_default_params() { + if (dst_md_.format_kind != format_kind::any) + return status::success; + + /* The stupidest ever heuristics (but not the same as we had before): + * - Pick the first non-plain format; + * - If all formats are plain, pick the format of the first input + */ + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (!src_d.is_plain() && src_d.is_blocking_desc()) { + return memory_desc_init_by_blocking_desc(dst_md_, + src_d.blocking_desc()); + } + } + + if (src_mds_[0].format_kind != format_kind::blocked) + return status::unimplemented; + + dst_md_ = src_mds_[0]; + + return status::success; + } +}; + +#define DECLARE_SUM_PD_t(impl_name, ...) \ + static status_t create(sum_pd_t **sum_pd, \ + engine_t *engine, const primitive_attr_t *attr, \ + const memory_desc_t *dst_md, int n, const float *scales, \ + const memory_desc_t *src_mds) { \ + using namespace status; \ + auto _pd = new pd_t(engine, attr, dst_md, n, scales, src_mds); \ + if (_pd == nullptr) return out_of_memory; \ + if (_pd->init() != success) { delete _pd; return unimplemented; } \ + return safe_ptr_assign(*sum_pd, _pd); \ + } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual const char *name() const override { return impl_name; } \ + +#define DECLARE_SUM_PD_T(impl_name, ...) \ + DECLARE_SUM_PD_t(impl_name, __VA_ARGS__) + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp new file mode 100644 index 0000000000..a408f45980 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp @@ -0,0 +1,200 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TAG_TRAITS_HPP +#define TAG_TRAITS_HPP + +#include + +#include "c_types_map.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +enum class block_dim_t { + _, + _A, _B, + _AB, _BC, +}; + +enum class inner_blk_t { + _, + _4a, _4b, + _8a, _8b, + _16a, _16b, + + _4b4a, _4b4c, _4c4b, + _8a8b, _8b8a, _8b8c, _8c8b, + _16a16b, _16a4b, _16b16a, _16b4c, _16b16c, _16c16b, + + _2c8b4c, _8a16b2a, _4b16a4b, _8b16a2b, _8b16c2b, _4c16b4c, _8c16b2c, +}; + +/** returns the offset within the block for weights blocked over oc and ic */ +template +constexpr int AB_or_BC_blk_off(int x0, int x1) { + using ib = inner_blk_t; + static_assert(utils::one_of(f, ib::_4b4a, ib::_4b4c, ib::_4c4b, ib::_8a8b, + ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b, ib::_16a4b, + ib::_16b16a, ib::_16b4c, ib::_16b16c, ib::_16c16b, ib::_2c8b4c, + ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, + ib::_4c16b4c, ib::_8c16b2c), + "unexpected inner_blk format"); + return false ? 0 + : (f == ib::_4b4c) ? 4 * x0 + x1 + : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0 + : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1 + : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0 + : (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1 + : (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0 + : (f == ib::_16a4b || f == ib::_16b4c) ? 4 * x0 + x1 + : (f == ib::_8a16b2a || f == ib::_8b16c2b) ? (x0 / 2) * 32 + x1 * 2 + x0 % 2 + : (f == ib::_4b16a4b || f == ib::_4c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4 + : (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2 + : (f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4 + : INT_MIN; +} + +template struct inner_blk_traits { + using ib = inner_blk_t; +}; + +template struct tag_traits { + // block_dim_t block_dims; + // inner_blk_t inner_blks; + // int ndims; +}; + +#define DECL_TRAITS(_tag, _blk_fmt, _inner_blk, _ndims) \ +template <> struct tag_traits { \ + static constexpr block_dim_t block_dims = block_dim_t::_blk_fmt; \ + static constexpr inner_blk_t inner_blks = inner_blk_t::_inner_blk; \ + static constexpr int ndims = _ndims; \ +} + +DECL_TRAITS(a, _, _, 1); +DECL_TRAITS(ab, _, _, 2); +DECL_TRAITS(abc, _, _, 3); +DECL_TRAITS(abcd, _, _, 4); +DECL_TRAITS(abcde, _, _, 5); +DECL_TRAITS(abcdef, _, _, 6); +DECL_TRAITS(abdec, _, _, 5); +DECL_TRAITS(acb, _, _, 3); +DECL_TRAITS(acbde, _, _, 5); +DECL_TRAITS(acdb, _, _, 4); +DECL_TRAITS(acdeb, _, _, 5); +DECL_TRAITS(ba, _, _, 2); +DECL_TRAITS(bac, _, _, 3); +DECL_TRAITS(bacd, _, _, 4); +DECL_TRAITS(bcda, _, _, 4); +DECL_TRAITS(cba, _, _, 3); +DECL_TRAITS(cdba, _, _, 4); +DECL_TRAITS(cdeba, _, _, 5); +DECL_TRAITS(decab, _, _, 5); + +DECL_TRAITS(Abc4a, _A, _4a, 3); +DECL_TRAITS(aBc4b, _B, _4b, 3); +DECL_TRAITS(ABc4b16a4b, _AB, _4b16a4b, 3); +DECL_TRAITS(ABc4b4a, _AB, _4b4a, 3); +DECL_TRAITS(Abcd4a, _A, _4a, 4); +DECL_TRAITS(aBcd4b, _B, _4b, 4); +DECL_TRAITS(ABcd4b4a, _AB, _4b4a, 4); +DECL_TRAITS(aBCd4c16b4c, _BC, _4c16b4c, 4); +DECL_TRAITS(aBCd4c4b, _BC, _4c4b, 4); +DECL_TRAITS(Abcde4a, _A, _4a, 5); +DECL_TRAITS(aBcde4b, _B, _4b, 5); +DECL_TRAITS(ABcde4b4a, _AB, _4b4a, 5); +DECL_TRAITS(aBCde4c4b, _BC, _4c4b, 5); +DECL_TRAITS(aBcdef4b, _B, _4b, 6); +DECL_TRAITS(aBCdef4c4b, _BC, _4c4b, 6); +DECL_TRAITS(aBdc4b, _B, _4b, 4); +DECL_TRAITS(aBdec4b, _B, _4b, 5); +DECL_TRAITS(aBdefc4b, _B, _4b, 6); +DECL_TRAITS(Acb4a, _A, _4a, 3); +DECL_TRAITS(Acdb4a, _A, _4a, 4); +DECL_TRAITS(Acdeb4a, _A, _4a, 5); + +DECL_TRAITS(Abc16a, _A, _16a, 3); +DECL_TRAITS(ABc16a16b, _AB, _16a16b, 3); +DECL_TRAITS(aBc16b, _B, _16b, 3); +DECL_TRAITS(ABc16b16a, _AB, _16b16a, 3); +DECL_TRAITS(ABc8a16b2a, _AB, _8a16b2a, 3); +DECL_TRAITS(ABc8a8b, _AB, _8a8b, 3); +DECL_TRAITS(aBc8b, _B, _8b, 3); +DECL_TRAITS(ABc8b16a2b, _AB, _8b16a2b, 3); +DECL_TRAITS(ABc8b8a, _AB, _8b8a, 3); +DECL_TRAITS(Abcd16a, _A, _16a, 4); +DECL_TRAITS(ABcd16a16b, _AB, _16a16b, 4); +DECL_TRAITS(aBcd16b, _B, _16b, 4); +DECL_TRAITS(ABcd16b16a, _AB, _16b16a, 4); +DECL_TRAITS(aBCd16b16c, _BC, _16b16c, 4); +DECL_TRAITS(aBCd16c16b, _BC, _16c16b, 4); +DECL_TRAITS(ABcd4b16a4b, _AB, _4b16a4b, 4); +DECL_TRAITS(ABcd8a16b2a, _AB, _8a16b2a, 4); +DECL_TRAITS(ABcd8a8b, _AB, _8a8b, 4); +DECL_TRAITS(aBcd8b, _B, _8b, 4); +DECL_TRAITS(ABcd8b16a2b, _AB, _8b16a2b, 4); +DECL_TRAITS(aBCd8b16c2b, _BC, _8b16c2b, 4); +DECL_TRAITS(ABcd8b8a, _AB, _8b8a, 4); +DECL_TRAITS(aBCd8b8c, _BC, _8b8c, 4); +DECL_TRAITS(aBCd8c16b2c, _BC, _8c16b2c, 4); +DECL_TRAITS(aBCd8c8b, _BC, _8c8b, 4); +DECL_TRAITS(Abcde16a, _A, _16a, 5); +DECL_TRAITS(ABcde16a16b, _AB, _16a16b, 5); +DECL_TRAITS(aBcde16b, _B, _16b, 5); +DECL_TRAITS(ABcde16b16a, _AB, _16b16a, 5); +DECL_TRAITS(aBCde16b16c, _BC, _16b16c, 5); +DECL_TRAITS(aBCde16c16b, _BC, _16c16b, 5); +DECL_TRAITS(aBCde4c16b4c, _BC, _4c16b4c, 5); +DECL_TRAITS(Abcde8a, _A, _8a, 5); +DECL_TRAITS(ABcde8a8b, _AB, _8a8b, 5); +DECL_TRAITS(aBcde8b, _B, _8b, 5); +DECL_TRAITS(ABcde8b16a2b, _AB, _8b16a2b, 5); +DECL_TRAITS(aBCde8b16c2b, _BC, _8b16c2b, 5); +DECL_TRAITS(ABcde8b8a, _AB, _8b8a, 5); +DECL_TRAITS(aBCde8b8c, _BC, _8b8c, 5); +DECL_TRAITS(aBCde2c8b4c, _BC, _2c8b4c, 5); +DECL_TRAITS(aBCde8c16b2c, _BC, _8c16b2c, 5); +DECL_TRAITS(aBCde4b4c, _BC, _4b4c, 5); +DECL_TRAITS(aBCde8c8b, _BC, _8c8b, 5); +DECL_TRAITS(aBcdef16b, _B, _16b, 6); +DECL_TRAITS(aBCdef16b16c, _BC, _16b16c, 6); +DECL_TRAITS(aBCdef16c16b, _BC, _16c16b, 6); +DECL_TRAITS(aBCdef8b8c, _BC, _8b8c, 6); +DECL_TRAITS(aBCdef8c16b2c, _BC, _8c16b2c, 6); +DECL_TRAITS(aBCdef8c8b, _BC, _8c8b, 6); +DECL_TRAITS(aBdc16b, _B, _16b, 4); +DECL_TRAITS(aBdc8b, _B, _8b, 4); +DECL_TRAITS(aBdec16b, _B, _16b, 5); +DECL_TRAITS(aBdec8b, _B, _8b, 5); +DECL_TRAITS(aBdefc16b, _B, _16b, 6); +DECL_TRAITS(aBdefc8b, _B, _8b, 6); +DECL_TRAITS(Acb16a, _A, _16a, 3); +DECL_TRAITS(Acb8a, _A, _8a, 3); +DECL_TRAITS(aCBd16b16c, _BC, _16b16c, 4); +DECL_TRAITS(aCBde16b16c, _BC, _16b16c, 5); +DECL_TRAITS(Acdb16a, _A, _16a, 4); +DECL_TRAITS(Acdb8a, _A, _8a, 4); +DECL_TRAITS(Acdeb16a, _A, _16a, 5); +DECL_TRAITS(Acdeb8a, _A, _8a, 5); +DECL_TRAITS(BAc16a16b, _AB, _16a16b, 3); +DECL_TRAITS(BAcd16a16b, _AB, _16a16b, 4); + +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp new file mode 100644 index 0000000000..4f06368738 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp @@ -0,0 +1,348 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TYPE_HELPERS_HPP +#define TYPE_HELPERS_HPP + +#include +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "mkldnn_traits.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "math_utils.hpp" + +namespace mkldnn { +namespace impl { + +template +status_t safe_ptr_assign(T * &lhs, T* rhs) { + if (rhs == nullptr) return status::out_of_memory; + lhs = rhs; + return status::success; +} + +template struct is_subset +{ static constexpr bool value = false; }; +template struct is_subset +{ static constexpr bool value = true; }; +template struct is_subset::value, float>::type> +{ static constexpr bool value = true; }; +#define ISSPEC(t1, t2) template <> \ + struct is_subset { static constexpr bool value = true; } +ISSPEC(int16_t, int32_t); +ISSPEC(int8_t, int32_t); +ISSPEC(uint8_t, int32_t); +ISSPEC(int8_t, int16_t); +ISSPEC(uint8_t, int16_t); +#undef ISSPEC + +inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs); + +namespace types { + +inline size_t data_type_size(data_type_t data_type) { + using namespace data_type; + switch (data_type) { + case f32: return sizeof(prec_traits::type); + case s32: return sizeof(prec_traits::type); + case s8: return sizeof(prec_traits::type); + case u8: return sizeof(prec_traits::type); + case data_type::undef: + default: assert(!"unknown data_type"); + } + return 0; /* not supposed to be reachable */ +} + +inline format_kind_t format_tag_to_kind(format_tag_t tag) { + switch (tag) { + case format_tag::undef: return format_kind::undef; + case format_tag::any: return format_kind::any; + case format_tag::last: return format_kind::undef; + default: return format_kind::blocked; + } + + assert(!"unreachable"); + return format_kind::undef; +} + +inline bool memory_extra_desc_is_equal(const memory_extra_desc_t &lhs, + const memory_extra_desc_t &rhs) { + return true + && lhs.flags == rhs.flags + && IMPLICATION(lhs.flags & memory_extra_flags::compensation_conv_s8s8, + lhs.compensation_mask == rhs.compensation_mask) + && IMPLICATION(lhs.flags & memory_extra_flags::scale_adjust, + lhs.scale_adjust == rhs.scale_adjust); +} + +inline bool blocking_desc_is_equal(const blocking_desc_t &lhs, + const blocking_desc_t &rhs, int ndims = MKLDNN_MAX_NDIMS) { + using mkldnn::impl::utils::array_cmp; + return true + && lhs.inner_nblks == rhs.inner_nblks + && array_cmp(lhs.strides, rhs.strides, ndims) + && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks) + && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks); +} + +inline bool wino_desc_is_equal(const wino_desc_t &lhs, + const wino_desc_t &rhs) { + return lhs.wino_format == rhs.wino_format + && lhs.alpha == rhs.alpha + && lhs.ic == rhs.ic + && lhs.oc == rhs.oc + && lhs.ic_block == rhs.ic_block + && lhs.oc_block == rhs.oc_block + && lhs.ic2_block == rhs.ic2_block + && lhs.oc2_block == rhs.oc2_block + && lhs.r == rhs.r; +} + +inline bool rnn_packed_desc_is_equal( + const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) { + bool ok = true + && lhs.format == rhs.format + && lhs.n_parts == rhs.n_parts + && lhs.offset_compensation == rhs.offset_compensation + && lhs.size == rhs.size + && lhs.n == rhs.n; + if (!ok) + return false; + + for (int i = 0; i < rhs.n_parts; i++) + ok = ok && lhs.parts[i] == rhs.parts[i]; + for (int i = 0; i < rhs.n_parts; i++) + ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i]; + return ok; +} + +inline memory_desc_t zero_md() { + auto zero = memory_desc_t(); + return zero; +} + +inline bool is_zero_md(const memory_desc_t *md) { + return md == nullptr || *md == zero_md(); +} + +inline data_type_t default_accum_data_type(data_type_t src_dt, + data_type_t dst_dt) { + using namespace utils; + using namespace data_type; + + if (one_of(f32, src_dt, dst_dt)) return f32; + if (one_of(s32, src_dt, dst_dt)) return s32; + + if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32; + + assert(!"unimplemented use-case: no default parameters available"); + return dst_dt; +} + +inline data_type_t default_accum_data_type(data_type_t src_dt, + data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) { + using namespace utils; + using namespace data_type; + using namespace prop_kind; + + /* prop_kind doesn't matter */ + if (everyone_is(f32, src_dt, wei_dt, dst_dt)) return f32; + + if (one_of(prop_kind, forward_training, forward_inference)) { + if ((src_dt == u8 || src_dt == s8) + && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8)) + return s32; + } else if (prop_kind == backward_data) { + if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 && + one_of(dst_dt, s8, u8)) + return s32; + } + + assert(!"unimplemented use-case: no default parameters available"); + return dst_dt; +} + +} + +inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) { + using namespace mkldnn::impl::utils; + bool base_equal = true + && lhs.ndims == rhs.ndims + && array_cmp(lhs.dims, rhs.dims, lhs.ndims) + && lhs.data_type == rhs.data_type + && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims) + && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims) + && lhs.offset0 == rhs.offset0 + && lhs.format_kind == rhs.format_kind; + if (!base_equal) return false; + if (!types::memory_extra_desc_is_equal(lhs.extra, rhs.extra)) return false; + if (lhs.format_kind == format_kind::blocked) + return types::blocking_desc_is_equal(lhs.format_desc.blocking, + rhs.format_desc.blocking, lhs.ndims); + else if (lhs.format_kind == format_kind::wino) + return types::wino_desc_is_equal(lhs.format_desc.wino_desc, + rhs.format_desc.wino_desc); + else if (lhs.format_kind == format_kind::rnn_packed) + return types::rnn_packed_desc_is_equal(lhs.format_desc.rnn_packed_desc, + rhs.format_desc.rnn_packed_desc); + return true; +} + +inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) { + return !operator==(lhs, rhs); +} + +inline status_t memory_desc_init_by_strides(memory_desc_t &md, + const dims_t strides) { + return mkldnn_memory_desc_init_by_strides( + &md, md.ndims, md.dims, md.data_type, strides); +} + +inline status_t memory_desc_init_by_tag(memory_desc_t &md, format_tag_t tag, + const dims_t strides = nullptr) { + status_t status = mkldnn_memory_desc_init_by_tag( + &md, md.ndims, md.dims, md.data_type, tag); + if (status != status::success || strides == nullptr) + return status; + + /* TODO: add consistency check */ + + for (int d = 0; d < md.ndims; ++d) + md.format_desc.blocking.strides[d] = strides[d]; + + return status::success; +} + +/** inits memory descriptor based on logical dimensions kept in @p md, and the + * blocking structure @p blk. + * + * @note blk.strides represent the order only (from smaller to bigger) + * + * TODO: move md related functions to one single place + */ +inline status_t memory_desc_init_by_blocking_desc(memory_desc_t &md, + const blocking_desc_t &blk) { + dims_t blocks = {0}; + utils::array_set(blocks, 1, md.ndims); + dim_t block_size = 1; + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { + blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk]; + block_size *= blk.inner_blks[iblk]; + } + + for (int d = 0; d < md.ndims; ++d) { + md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); + md.padded_offsets[d] = 0; + } + md.offset0 = 0; + + md.format_kind = format_kind::blocked; + auto &mblk = md.format_desc.blocking; + mblk = blk; + + const int ndims = nstl::min(MKLDNN_MAX_NDIMS, md.ndims); // make GCC 5 happy + utils::array_copy(mblk.strides, blk.strides, ndims); + + int perm[MKLDNN_MAX_NDIMS]; + for (int d = 0; d < ndims; ++d) perm[d] = d; + + utils::simultaneous_sort(mblk.strides, perm, ndims, + [](stride_t a, stride_t b) { return b - a; }); + + dim_t stride = block_size; + for (int _d = ndims - 1; _d >= 0; --_d) { + const int d = perm[_d]; + md.format_desc.blocking.strides[d] = stride; + stride *= md.padded_dims[d] / blocks[d]; + } + + md.extra = utils::zero(); + + return status::success; +} + +/** returns true if memory desc @p md corresponds to the given format tag and + * strides. + * If strides are not passed (or passed as nullptr) the dense structure is + * assumed (i.e. the one that mkldnn_memory_desc_init_by_tag() returns). + * Strides might contain `0` value, indicating the stride must match the one + * that mkldnn_memory_desc_init_by_tag() returns. + * Strides might contain `-1` values, that would be ignored during the + * comparison. For instance, this can be used if a stride along minibatch + * doesn't matter. */ +inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag, + const dims_t strides = nullptr) { + if (md.format_kind != types::format_tag_to_kind(tag)) + return false; + + memory_desc_t md_gold; + status_t status = mkldnn_memory_desc_init_by_tag( + &md_gold, md.ndims, md.dims, md.data_type, tag); + if (status != status::success) return false; + + if (md.format_kind != format_kind::blocked) + return false; // unimplemented yet + + const auto &blk = md.format_desc.blocking; + const auto &blk_gold = md_gold.format_desc.blocking; + + using utils::array_cmp; + bool same_blocks = true + && blk.inner_nblks == blk_gold.inner_nblks + && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks) + && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks); + + if (!same_blocks) + return false; + + if (strides == nullptr) + return array_cmp(blk.strides, blk_gold.strides, md.ndims); + + for (int d = 0; d < md.ndims; ++d) { + dim_t stride = strides[d]; + if (stride == -1) continue; + if (stride == 0) stride = blk_gold.strides[d]; + if (blk.strides[d] != stride) return false; + } + + return true; +} + +/** returns matching tag (or undef if match is not found) + * XXX: This is a workaround that eventually should go away! */ +template +format_tag_t memory_desc_matches_one_of_tag(const memory_desc_t &md, + Tags ...tags) { + for (const auto tag: {tags...}) { + if (memory_desc_matches_tag(md, tag)) + return tag; + } + return format_tag::undef; +} + +} +} + +#include "memory_desc_wrapper.hpp" + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.cpp b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp new file mode 100644 index 0000000000..d23f4682dc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp @@ -0,0 +1,135 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#ifdef _WIN32 +#include +#include +#endif +#include +#include +#include + +#include "mkldnn.h" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +int getenv(const char *name, char *buffer, int buffer_size) { + if (name == NULL || buffer_size < 0 || (buffer == NULL && buffer_size > 0)) + return INT_MIN; + + int result = 0; + int term_zero_idx = 0; + size_t value_length = 0; + +#ifdef _WIN32 + value_length = GetEnvironmentVariable(name, buffer, buffer_size); +#else + const char *value = ::getenv(name); + value_length = value == NULL ? 0 : strlen(value); +#endif + + if (value_length > INT_MAX) + result = INT_MIN; + else { + int int_value_length = (int)value_length; + if (int_value_length >= buffer_size) { + result = -int_value_length; + } else { + term_zero_idx = int_value_length; + result = int_value_length; +#ifndef _WIN32 + strncpy(buffer, value, value_length); +#endif + } + } + + if (buffer != NULL) + buffer[term_zero_idx] = '\0'; + return result; +} + +int getenv_int(const char *name, int default_value) +{ + int value = default_value; + // # of digits in the longest 32-bit signed int + sign + terminating null + const int len = 12; + char value_str[len]; + if (getenv(name, value_str, len) > 0) + value = atoi(value_str); + return value; +} + +FILE *fopen(const char *filename, const char *mode) { +#ifdef _WIN32 + FILE *fp = NULL; + return ::fopen_s(&fp, filename, mode) ? NULL : fp; +#else + return ::fopen(filename, mode); +#endif +} + +void *malloc(size_t size, int alignment) { + void *ptr; + +#ifdef _WIN32 + ptr = _aligned_malloc(size, alignment); + int rc = ptr ? 0 : -1; +#else + int rc = ::posix_memalign(&ptr, alignment, size); +#endif + + return (rc == 0) ? ptr : 0; +} + +void free(void *p) { +#ifdef _WIN32 + _aligned_free(p); +#else + ::free(p); +#endif +} + +// Atomic operations +int32_t fetch_and_add(int32_t *dst, int32_t val) { +#ifdef _WIN32 + return InterlockedExchangeAdd(reinterpret_cast(dst), val); +#else + return __sync_fetch_and_add(dst, val); +#endif +} + +static int jit_dump_flag = 0; +static bool jit_dump_flag_initialized = false; +bool jit_dump_enabled() { + if (!jit_dump_flag_initialized) { + jit_dump_flag = getenv_int("MKLDNN_JIT_DUMP"); + jit_dump_flag_initialized = true; + } + return jit_dump_flag != 0; +} + +} +} + +mkldnn_status_t mkldnn_set_jit_dump(int enabled) { + using namespace mkldnn::impl::status; + mkldnn::impl::jit_dump_flag = enabled; + mkldnn::impl::jit_dump_flag_initialized = true; + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp new file mode 100644 index 0000000000..d5a8ec5139 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp @@ -0,0 +1,370 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef UTILS_HPP +#define UTILS_HPP + +#include +#include +#include +#include +#include + +#if defined(__x86_64__) || defined(_M_X64) +#define MKLDNN_X86_64 +#endif + +#define MSAN_ENABLED 0 +#if defined(__has_feature) +#if __has_feature(memory_sanitizer) +#undef MSAN_ENABLED +#define MSAN_ENABLED 1 +#include +#endif +#endif + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +// Sanity check for 64 bits +static_assert(sizeof(void*) == 8, "Intel(R) MKL-DNN supports 64 bit only"); + +#define CHECK(f) do { \ + status_t status = f; \ + if (status != status::success) \ + return status; \ +} while (0) + +#define IMPLICATION(cause, effect) (!(cause) || !!(effect)) + +namespace utils { + +/* a bunch of std:: analogues to be compliant with any msvs version + * + * Rationale: msvs c++ (and even some c) headers contain special pragma that + * injects msvs-version check into object files in order to abi-mismatches + * during the static linking. This makes sense if e.g. std:: objects are passed + * through between application and library, which is not the case for mkl-dnn + * (since there is no any c++-rt dependent stuff, ideally...). */ + +/* SFINAE helper -- analogue to std::enable_if */ +template struct enable_if {}; +template struct enable_if { typedef T type; }; + +/* analogue std::conditional */ +template struct conditional {}; +template struct conditional +{ typedef T type; }; +template struct conditional +{ typedef F type; }; + +template struct conditional3 {}; +template +struct conditional3 { typedef T type; }; +template +struct conditional3 { typedef FT type; }; +template +struct conditional3 { typedef FF type; }; + +template struct conditional_v {}; +template struct conditional_v +{ static constexpr U value = t; }; +template struct conditional_v +{ static constexpr U value = f; }; + +template struct remove_reference { typedef T type; }; +template struct remove_reference { typedef T type; }; +template struct remove_reference { typedef T type; }; + +template +inline T&& forward(typename utils::remove_reference::type &t) +{ return static_cast(t); } +template +inline T&& forward(typename utils::remove_reference::type &&t) +{ return static_cast(t); } + +template +inline typename remove_reference::type zero() +{ auto zero = typename remove_reference::type(); return zero; } + +template +inline bool everyone_is(T val, P item) { return val == item; } +template +inline bool everyone_is(T val, P item, Args... item_others) { + return val == item && everyone_is(val, item_others...); +} + +template +constexpr bool one_of(T val, P item) { return val == item; } +template +constexpr bool one_of(T val, P item, Args... item_others) { + return val == item || one_of(val, item_others...); +} + +template +inline bool any_null(Args... ptrs) { return one_of(nullptr, ptrs...); } + +template +inline void array_copy(T *dst, const T *src, size_t size) { + for (size_t i = 0; i < size; ++i) dst[i] = src[i]; +} +template +inline bool array_cmp(const T *a1, const T *a2, size_t size) { + for (size_t i = 0; i < size; ++i) if (a1[i] != a2[i]) return false; + return true; +} +template +inline void array_set(T *arr, const U& val, size_t size) { + for (size_t i = 0; i < size; ++i) arr[i] = static_cast(val); +} + +namespace product_impl { +template struct int2type{}; + +template +constexpr int product_impl(const T *arr, int2type<0>) { return arr[0]; } + +template +inline T product_impl(const T *arr, int2type) { + return arr[0]*product_impl(arr+1, int2type()); } +} + +template +inline T array_product(const T *arr) { + return product_impl::product_impl(arr, product_impl::int2type()); +} + +template +inline R array_product(const T *arr, size_t size) { + R prod = 1; + for (size_t i = 0; i < size; ++i) prod *= arr[i]; + return prod; +} + +/** sorts an array of values using @p comparator. While sorting the array + * of value, the function permutes an array of @p keys accordingly. + * + * @note The arrays of @p keys can be omitted. In this case the function + * sorts the array of @vals only. + */ +template +inline void simultaneous_sort(T *vals, U *keys, size_t size, F comparator) { + if (size == 0) return; + + for (size_t i = 0; i < size - 1; ++i) { + bool swapped = false; + + for (size_t j = 0; j < size - i - 1; j++) { + if (comparator(vals[j], vals[j + 1]) > 0) { + nstl::swap(vals[j], vals[j + 1]); + if (keys) nstl::swap(keys[j], keys[j + 1]); + swapped = true; + } + } + + if (swapped == false) break; + } +} + +template +inline typename remove_reference::type div_up(const T a, const U b) { + assert(b); + return (a + b - 1) / b; +} + +template +inline typename remove_reference::type rnd_up(const T a, const U b) { + return div_up(a, b) * b; +} + +template +inline typename remove_reference::type rnd_dn(const T a, const U b) { + return (a / b) * b; +} + +template T *align_ptr(T *ptr, uintptr_t alignment) +{ return (T *)(((uintptr_t)ptr + alignment - 1) & ~(alignment - 1)); } + +template +inline U this_block_size(const T offset, const U max, const V block_size) { + assert(offset < max); + // TODO (Roma): can't use nstl::max() due to circular dependency... we + // need to fix this + const T block_boundary = offset + block_size; + if (block_boundary > max) + return max - offset; + else + return block_size; +} + +template +inline T nd_iterator_init(T start) { return start; } +template +inline T nd_iterator_init(T start, U &x, const W &X, Args &&... tuple) { + start = nd_iterator_init(start, utils::forward(tuple)...); + x = start % X; + return start / X; +} + +inline bool nd_iterator_step() { return true; } +template +inline bool nd_iterator_step(U &x, const W &X, Args &&... tuple) { + if (nd_iterator_step(utils::forward(tuple)...) ) { + x = (x + 1) % X; + return x == 0; + } + return false; +} + +template +inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X) +{ + U max_jump = end - cur; + U dim_jump = X - x; + if (dim_jump <= max_jump) { + x = 0; + cur += dim_jump; + return true; + } else { + cur += max_jump; + x += max_jump; + return false; + } +} +template +inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X, + Args &&... tuple) +{ + if (nd_iterator_jump(cur, end, utils::forward(tuple)...)) { + x = (x + 1) % X; + return x == 0; + } + return false; +} + +template +inline T pick(size_t i, const T &x0) { return x0; } +template +inline T pick(size_t i, const T &x0, Args &&... args) { + return i == 0 ? x0 : pick(i - 1, utils::forward(args)...); +} + +template +T pick_by_prop_kind(prop_kind_t prop_kind, const T &val_fwd_inference, + const T &val_fwd_training, const T &val_bwd_d, const T &val_bwd_w) { + switch (prop_kind) { + case prop_kind::forward_inference: return val_fwd_inference; + case prop_kind::forward_training: return val_fwd_training; + case prop_kind::backward_data: return val_bwd_d; + case prop_kind::backward_weights: return val_bwd_w; + default: assert(!"unsupported prop_kind"); + } + return T(); +} + +template +T pick_by_prop_kind(prop_kind_t prop_kind, + const T &val_fwd, const T &val_bwd_d, const T &val_bwd_w) +{ return pick_by_prop_kind(prop_kind, val_fwd, val_fwd, val_bwd_d, val_bwd_w); } + +template +struct array_offset_calculator { + template + array_offset_calculator(Telem *base, Targs... Fargs) : _dims{ Fargs... } + { + _base_ptr = base; + } + template + inline Telem &operator()(Targs... Fargs) + { + return *(_base_ptr + _offset(1, Fargs...)); + } + +private: + template + inline size_t _offset(size_t const dimension, size_t element) + { + return element; + } + + template + inline size_t _offset(size_t const dimension, size_t theta, size_t element) + { + return element + (_dims[dimension] * theta); + } + + template + inline size_t _offset(size_t const dimension, size_t theta, size_t element, + Targs... Fargs) + { + size_t t_prime = element + (_dims[dimension] * theta); + return _offset(dimension + 1, t_prime, Fargs...); + } + + Telem *_base_ptr; + const int _dims[Tdims]; +}; + +} + +int32_t fetch_and_add(int32_t *dst, int32_t val); +inline void yield_thread() {} + +// Reads an environment variable 'name' and stores its string value in the +// 'buffer' of 'buffer_size' bytes on success. +// +// - Returns the length of the environment variable string value (excluding +// the terminating 0) if it is set and its contents (including the terminating +// 0) can be stored in the 'buffer' without truncation. +// +// - Returns negated length of environment variable string value and writes +// "\0" to the buffer (if it is not NULL) if the 'buffer_size' is to small to +// store the value (including the terminating 0) without truncation. +// +// - Returns 0 and writes "\0" to the buffer (if not NULL) if the environment +// variable is not set. +// +// - Returns INT_MIN if the 'name' is NULL. +// +// - Returns INT_MIN if the 'buffer_size' is negative. +// +// - Returns INT_MIN if the 'buffer' is NULL and 'buffer_size' is greater than +// zero. Passing NULL 'buffer' with 'buffer_size' set to 0 can be used to +// retrieve the length of the environment variable value string. +// +int getenv(const char *name, char *buffer, int buffer_size); +// Reads an integer from the environment +int getenv_int(const char *name, int default_value = 0); +bool jit_dump_enabled(); +FILE *fopen(const char *filename, const char *mode); + +constexpr int msan_enabled = MSAN_ENABLED; +inline void msan_unpoison(void *ptr, size_t size) { +#if MSAN_ENABLED + __msan_unpoison(ptr, size); +#endif +} + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp new file mode 100644 index 0000000000..89a57772cf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp @@ -0,0 +1,665 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#ifndef _WIN32 +#include +#endif + +#include "mkldnn.h" +#include "mkldnn_version.h" +#include "c_types_map.hpp" +#include "verbose.hpp" +#include "cpu/cpu_isa_traits.hpp" + +#include "batch_normalization_pd.hpp" +#include "pooling_pd.hpp" +#include "concat_pd.hpp" +#include "reorder_pd.hpp" +#include "convolution_pd.hpp" +#include "rnn_pd.hpp" +#include "deconvolution_pd.hpp" +#include "shuffle_pd.hpp" +#include "eltwise_pd.hpp" +#include "softmax_pd.hpp" +#include "inner_product_pd.hpp" +#include "sum_pd.hpp" +#include "lrn_pd.hpp" + +/* MKL-DNN CPU ISA info */ +#define ISA_ANY "No instruction set specific optimizations" +#define SSE42 "Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2)" +#define AVX "Intel(R) Advanced Vector Extensions (Intel(R) AVX)" +#define AVX2 "Intel(R) Advanced Vector Extensions 2 (Intel(R) AVX2)" +#define AVX512_COMMON "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512)" +#define AVX512_CORE "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512BW, AVX512VL, and AVX512DQ extensions" +#define AVX512_CORE_VNNI "Intel(R) AVX512-Deep Learning Boost (Intel(R) " \ + "AVX512-DL Boost)" +#define AVX512_MIC "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512CD, AVX512ER, and AVX512PF extensions" +#define AVX512_MIC_4OPS "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512_4FMAPS and AVX512_4VNNIW extensions" + +namespace mkldnn { +namespace impl { + +static verbose_t verbose; +static bool initialized; +static bool version_printed = false; + +const verbose_t *mkldnn_verbose() { +#if !defined(DISABLE_VERBOSE) + if (!initialized) { + const int len = 2; + char val[len] = {0}; + if (getenv("MKLDNN_VERBOSE", val, len) == 1) + verbose.level = atoi(val); + initialized = true; + } + if (!version_printed && verbose.level > 0) { + printf("mkldnn_verbose,info," + "Intel(R) MKL-DNN v%d.%d.%d (Git Hash %s),%s\n", + mkldnn_version()->major, mkldnn_version()->minor, + mkldnn_version()->patch, mkldnn_version()->hash, + get_isa_info()); + version_printed = true; + } +#else + verbose.level = 0; +#endif + return &verbose; +} + +double get_msec() { +#ifdef _WIN32 + static LARGE_INTEGER frequency; + if (frequency.QuadPart == 0) + QueryPerformanceFrequency(&frequency); + LARGE_INTEGER now; + QueryPerformanceCounter(&now); + return 1e+3 * now.QuadPart / frequency.QuadPart; +#else + struct timeval time; + gettimeofday(&time, NULL); + return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec; +#endif +} + +const char *get_isa_info() { + using namespace mkldnn::impl::cpu; + if (mayiuse(avx512_mic_4ops)) return AVX512_MIC_4OPS; + if (mayiuse(avx512_mic)) return AVX512_MIC; + if (mayiuse(avx512_core_vnni)) return AVX512_CORE_VNNI; + if (mayiuse(avx512_core)) return AVX512_CORE; + if (mayiuse(avx512_common)) return AVX512_COMMON; + if (mayiuse(avx2)) return AVX2; + if (mayiuse(avx)) return AVX; + if (mayiuse(sse42)) return SSE42; + return ISA_ANY; +} + +/* init_info section */ +namespace { +#if !defined(DISABLE_VERBOSE) +#define MKLDNN_VERBOSE_DAT_LEN 256 +#define MKLDNN_VERBOSE_AUX_LEN 384 +#define MKLDNN_VERBOSE_PRB_LEN 384 + +#define DECL_DAT_AUX_PRB_STRS() \ + int dat_written = 0, aux_written = 0, prb_written = 0; \ + MAYBE_UNUSED((dat_written * aux_written * prb_written)); \ + char dat_str[MKLDNN_VERBOSE_DAT_LEN] = {'\0'}; MAYBE_UNUSED(dat_str); \ + char aux_str[MKLDNN_VERBOSE_AUX_LEN] = {'\0'}; MAYBE_UNUSED(aux_str); \ + char prb_str[MKLDNN_VERBOSE_PRB_LEN] = {'\0'}; MAYBE_UNUSED(prb_str) + +#define DFMT "%" PRId64 + +void clear_buf(char *buf, int &written) { + /* TODO: do it better */ + buf[0] = '#'; + buf[1] = '\0'; + written = 1; +} + +#define DPRINT(buf, buf_len, written, ...) do { \ + int l = snprintf(buf + written, buf_len - written, __VA_ARGS__); \ + if (l < 0 || written + l > buf_len) { \ + clear_buf(buf, written); \ + } else { \ + written += l; \ + } \ +} while(0) + +// XXX: Outputs strings corresponding to memory formats used for data tensors. +void format_prb_desc_str(char *str, int len, const memory_desc_t *md) { + const auto dims = md->dims; + int written = 0; + if (md->ndims == 1) + DPRINT(str, len, written, + "x" DFMT, dims[0]); + else if (md->ndims == 2) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT, dims[0], dims[1]); + else if (md->ndims == 3) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "iw" DFMT, + dims[0], dims[1], dims[2]); + else if (md->ndims == 4) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "ih" DFMT "iw" DFMT, + dims[0], dims[1], dims[2], dims[3]); + else if (md->ndims == 5) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "id" DFMT "ih" DFMT "iw" DFMT, + dims[0], dims[1], dims[2], dims[3], dims[4]); + else + mkldnn_md2dim_str(str, len, md); +} + +void verbose_templ(char *buffer, mkldnn_primitive_kind_t prim_kind, + const char *impl_str, mkldnn_prop_kind_t prop_kind, + const char *data_str, const char *aux_str, const char *prb_str) { + MAYBE_UNUSED(verbose_templ); + int written = 0; + DPRINT(buffer, MKLDNN_VERBOSE_BUF_LEN, written, "%s,%s,%s,%s,%s,%s", + mkldnn_prim_kind2str(prim_kind), impl_str, + mkldnn_prop_kind2str(prop_kind), data_str, aux_str, prb_str); +} + +template static void init_info_bnorm(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "flags:%u", s->desc()->flags); + + format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_conv(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->desc()->prop_kind == prop_kind::backward_data + ? s->diff_src_md() : s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md() : s->weights_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bia + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md(1) : s->weights_md(1); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + if (1) { // dst + auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + if (s->ndims() == 5) { + if (s->with_groups()) + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT + "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->G(), s->IC(), s->OC(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + else + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_ic" DFMT "oc" DFMT + "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->IC(), s->OC(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + } else { + if (s->with_groups()) + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->G(), s->IC(), s->OC(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + else + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_ic" DFMT "oc" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->IC(), s->OC(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + } + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_shuffle(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + auto md = s->is_fwd() ? s->src_md() : s->diff_dst_md(); + + if (1) { // data + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "axis:%d group_size:" DFMT, s->axis(), s->group_size()); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, md); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_eltwise(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_iprod(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->desc()->prop_kind == prop_kind::backward_data + ? s->diff_src_md() : s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md() : s->weights_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bia + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md(1) : s->weights_md(1); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + if (1) { // dst + auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "oc" DFMT, s->MB(), s->IC_total(), s->OC()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_lrn(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_mem(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst + auto md = s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "num:%d", s->n_inputs()); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); + + verbose_templ(buffer, s->kind(), s->name(), prop_kind::undef, dat_str, + aux_str, prb_str); +} + +template static void init_info_pool(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->is_fwd() ? s->src_md() : s->diff_src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst + auto md = s->is_fwd() ? s->dst_md() : s->diff_dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // ws + auto md = s->workspace_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " ws_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + if (s->is_3d()) { + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "_" + "id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "pd" DFMT "_" + "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" + "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT "", + s->MB(), s->C(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); + } else { + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "_" + "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" + "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT, + s->MB(), s->C(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); + } + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_softmax(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_rnn(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src layer + auto md = s->is_fwd() ? s->src_md(0) : s->diff_src_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // src iter + auto md = s->is_fwd() ? s->src_md(1) : s->diff_src_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_iter_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei_layer + auto md = s->is_fwd() ? s->weights_md(0) : s->diff_weights_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei_iter + auto md = s->is_fwd() ? s->weights_md(1) : s->diff_weights_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bias + auto md = s->is_fwd() ? s->weights_md(2) : s->diff_weights_md(2); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bias_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst layer + auto md = s->is_fwd() ? s->dst_md(0) : s->diff_dst_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst iter + auto md = s->is_fwd() ? s->dst_md(1) : s->diff_dst_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_iter_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + alg_kind_t alg_kind = s->cell_kind(); + rnn_direction_t rnn_dir = s->direction(); + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s_%s", mkldnn_alg_kind2str(alg_kind), + mkldnn_rnn_direction2str(rnn_dir)); + + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "l" DFMT "t" DFMT "mb" DFMT + "sic" DFMT "slc" DFMT "dic" DFMT "dlc" DFMT, + s->L(), s->T(), s->MB(), + s->SIC(), s->SLC(), s->DIC(), s->DLC()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +#undef DPRINT + +#else // !defined(DISABLE_VERBOSE) + +#define DEFINE_STUB(name) \ + template \ + static void CONCAT2(init_info_, name)(pd_t *s, char *buffer) \ + { UNUSED(s); UNUSED(buffer); } + +DEFINE_STUB(bnorm); +DEFINE_STUB(conv); +DEFINE_STUB(eltwise); +DEFINE_STUB(iprod); +DEFINE_STUB(lrn); +DEFINE_STUB(mem); +DEFINE_STUB(pool); +DEFINE_STUB(softmax); +DEFINE_STUB(rnn); +DEFINE_STUB(shuffle); +#undef DEFINE_STUB + +#endif // !defined(DISABLE_VERBOSE) +} + +void init_info(batch_normalization_pd_t *s, char *b) +{ init_info_bnorm(s, b); } +void init_info(concat_pd_t *s, char *b) +{ init_info_mem(s, b); } +void init_info(convolution_pd_t *s, char *b) +{ init_info_conv(s, b); } +void init_info(deconvolution_pd_t *s, char *b) +{ init_info_conv(s, b); } +void init_info(eltwise_pd_t *s, char *b) +{ init_info_eltwise(s, b); } +void init_info(inner_product_pd_t *s, char *b) +{ init_info_iprod(s, b); } +void init_info(lrn_pd_t *s, char *b) +{ init_info_lrn(s, b); } +void init_info(pooling_pd_t *s, char *b) +{ init_info_pool(s, b); } +void init_info(reorder_pd_t *s, char *b) +{ init_info_mem(s, b); } +void init_info(rnn_pd_t *s, char *b) +{ init_info_rnn(s, b); } +void init_info(shuffle_pd_t *s, char *b) +{ init_info_shuffle(s, b); } +void init_info(softmax_pd_t *s, char *b) +{ init_info_softmax(s, b); } +void init_info(sum_pd_t *s, char *b) +{ init_info_mem(s, b); } + +} +} + +mkldnn_status_t mkldnn_set_verbose(int level) { + using namespace mkldnn::impl::status; + if (level < 0 || level > 2) return invalid_arguments; + mkldnn::impl::verbose.level = level; + mkldnn::impl::initialized = true; + return success; +} + +const mkldnn_version_t *mkldnn_version() { + static mkldnn_version_t ver = { + MKLDNN_VERSION_MAJOR, + MKLDNN_VERSION_MINOR, + MKLDNN_VERSION_PATCH, + MKLDNN_VERSION_HASH}; + return &ver; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp new file mode 100644 index 0000000000..e3049750cb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp @@ -0,0 +1,62 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef VERBOSE_HPP +#define VERBOSE_HPP + +#include +#include + +#include "mkldnn_debug.h" +#include "c_types_map.hpp" +#include "utils.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +struct verbose_t { + int level; +}; + +const verbose_t *mkldnn_verbose(); +double get_msec(); +const char *get_isa_info(); + +#if !defined(DISABLE_VERBOSE) +#define MKLDNN_VERBOSE_BUF_LEN 1024 +#else +#define MKLDNN_VERBOSE_BUF_LEN 1 +#endif + +void init_info(batch_normalization_pd_t *s, char *buffer); +void init_info(concat_pd_t *s, char *buffer); +void init_info(convolution_pd_t *s, char *buffer); +void init_info(deconvolution_pd_t *s, char *buffer); +void init_info(eltwise_pd_t *s, char *buffer); +void init_info(inner_product_pd_t *s, char *buffer); +void init_info(lrn_pd_t *s, char *buffer); +void init_info(pooling_pd_t *s, char *buffer); +void init_info(reorder_pd_t *s, char *buffer); +void init_info(rnn_pd_t *s, char *buffer); +void init_info(shuffle_pd_t *s, char *buffer); +void init_info(softmax_pd_t *s, char *buffer); +void init_info(sum_pd_t *s, char *buffer); + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp new file mode 100644 index 0000000000..520bd4710b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef Z_MAGIC_HPP +#define Z_MAGIC_HPP + +#define CHAIn2(a,b) a b +#define CHAIN2(a,b) CHAIn2(a,b) + +#define CONCAt2(a,b) a ## b +#define CONCAT2(a,b) CONCAt2(a,b) + +#define STRINGIFy(s) #s +#define STRINGIFY(s) STRINGIFy(s) + +#ifdef _MSC_VER +# define PRAGMA_MACRo(x) __pragma(x) +# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) +#else +# define PRAGMA_MACRo(x) _Pragma(#x) +# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) +#endif + +#define UNUSED(x) ((void)x) +#define MAYBE_UNUSED(x) UNUSED(x) + +#if defined(_WIN32) && !defined(__GNUC__) +#define __PRETTY_FUNCTION__ __FUNCSIG__ +#endif + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp new file mode 100644 index 0000000000..7cf7822d90 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "cpu_barrier.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace simple_barrier { + +void generate(jit_generator &code, Xbyak::Reg64 reg_ctx, + Xbyak::Reg64 reg_nthr) { +# define BAR_CTR_OFF offsetof(ctx_t, ctr) +# define BAR_SENSE_OFF offsetof(ctx_t, sense) + using namespace Xbyak; + + Xbyak::Reg64 reg_tmp = [&]() { + /* returns register which is neither reg_ctx nor reg_nthr */ + Xbyak::Reg64 regs[] = { util::rax, util::rbx, util::rcx }; + for (size_t i = 0; i < sizeof(regs) / sizeof(regs[0]); ++i) + if (!utils::one_of(regs[i], reg_ctx, reg_nthr)) + return regs[i]; + return regs[0]; /* should not happen */ + }(); + + Label barrier_exit_label, barrier_exit_restore_label, spin_label; + + code.cmp(reg_nthr, 1); + code.jbe(barrier_exit_label); + + code.push(reg_tmp); + + /* take and save current sense */ + code.mov(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]); + code.push(reg_tmp); + code.mov(reg_tmp, 1); + + if (mayiuse(avx512_mic)) { + code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]); + code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]); + } + + code.lock(); code.xadd(code.ptr[reg_ctx + BAR_CTR_OFF], reg_tmp); + code.add(reg_tmp, 1); + code.cmp(reg_tmp, reg_nthr); + code.pop(reg_tmp); /* restore previous sense */ + code.jne(spin_label); + + /* the last thread {{{ */ + code.mov(code.qword[reg_ctx + BAR_CTR_OFF], 0); // reset ctx + + // notify waiting threads + code.not_(reg_tmp); + code.mov(code.ptr[reg_ctx + BAR_SENSE_OFF], reg_tmp); + code.jmp(barrier_exit_restore_label); + /* }}} the last thread */ + + code.CodeGenerator::L(spin_label); + code.pause(); + code.cmp(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]); + code.je(spin_label); + + code.CodeGenerator::L(barrier_exit_restore_label); + code.pop(reg_tmp); + + code.CodeGenerator::L(barrier_exit_label); +# undef BAR_CTR_OFF +# undef BAR_SENSE_OFF +} + +/** jit barrier generator */ +struct jit_t: public jit_generator { + void (*barrier)(ctx_t *ctx, size_t nthr); + + jit_t() { + generate(*this, abi_param1, abi_param2); + ret(); + barrier = reinterpret_cast(const_cast( + this->getCode())); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_t) +}; + +void barrier(ctx_t *ctx, int nthr) { + static jit_t j; /* XXX: constructed on load ... */ + j.barrier(ctx, nthr); +} + +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp new file mode 100644 index 0000000000..0f55e33aa8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp @@ -0,0 +1,60 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_BARRIER_HPP +#define CPU_BARRIER_HPP + +#include + +#include "jit_generator.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace simple_barrier { + +STRUCT_ALIGN(64, +struct ctx_t { + enum { CACHE_LINE_SIZE = 64 }; + volatile size_t ctr; + char pad1[CACHE_LINE_SIZE - 1 * sizeof(size_t)]; + volatile size_t sense; + char pad2[CACHE_LINE_SIZE - 1 * sizeof(size_t)]; +}); + +inline void ctx_init(ctx_t *ctx) { *ctx = utils::zero(); } +void barrier(ctx_t *ctx, int nthr); + +/** injects actual barrier implementation into another jitted code + * @params: + * code -- jit_generator object where the barrier is to be injected + * reg_ctx -- read-only register with pointer to the barrier context + * reg_nnthr -- read-only register with the # of synchronizing threads + */ +void generate(jit_generator &code, Xbyak::Reg64 reg_ctx, + Xbyak::Reg64 reg_nthr); + +} + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp new file mode 100644 index 0000000000..1ed5ad57b9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp @@ -0,0 +1,40 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_BATCH_NORMALIZATION_PD_HPP +#define CPU_BATCH_NORMALIZATION_PD_HPP + +#include "batch_normalization_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_batch_normalization_fwd_pd_t: public batch_normalization_fwd_pd_t { + using batch_normalization_fwd_pd_t::batch_normalization_fwd_pd_t; +}; + +struct cpu_batch_normalization_bwd_pd_t: public batch_normalization_bwd_pd_t { + using batch_normalization_bwd_pd_t::batch_normalization_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp new file mode 100644 index 0000000000..b8d5c4fcaf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp @@ -0,0 +1,140 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "cpu_batch_normalization_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace bnorm_utils { + +void cache_balance(size_t working_set_size, dim_t C_blks, + dim_t &C_blks_per_iter, int64_t &iters) { + int nthrs = mkldnn_get_max_threads(); + int l3_size = get_cache_size(3, true) * nthrs / 2; + + C_blks_per_iter = l3_size / working_set_size; + + if (C_blks_per_iter == 0) + C_blks_per_iter = 1; + if (C_blks_per_iter > C_blks) + C_blks_per_iter = C_blks; + + iters = (C_blks + C_blks_per_iter - 1) / C_blks_per_iter; +} + +bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr, + int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr, + dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s, + dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e) { + if (nthr <= C_blks || !mkldnn_thr_syncable()) { + C_ithr = ithr; C_nthr = nthr; + N_ithr = 0; N_nthr = 1; + S_ithr = 0; S_nthr = 1; + N_s = 0; N_e = N; S_s = 0; S_e = SP; + balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e); + } else { + if (do_blocking) { + N_nthr = (int)nstl::min(N, nthr); + C_nthr = (int)nstl::min(C_blks, nthr / N_nthr); + S_nthr = (int)nstl::min(SP, nthr / (C_nthr * N_nthr)); + } else { + C_nthr = (int)math::gcd((dim_t)nthr, C_blks); + N_nthr = (int)nstl::min(N, nthr / C_nthr); + S_nthr = (int)nstl::min(SP, nthr / (C_nthr * N_nthr)); + } + + if (!spatial_thr_allowed) + S_nthr = 1; + + if (S_nthr < 1) S_nthr = 1; + if (ithr < C_nthr * N_nthr * S_nthr) { + N_ithr = (ithr / S_nthr) % N_nthr ; + C_ithr = ithr / (N_nthr * S_nthr); + S_ithr = ithr % S_nthr; + balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e); + balance211(N, N_nthr, N_ithr, N_s, N_e); + balance211(SP, S_nthr, S_ithr, S_s, S_e); + } else { + S_ithr = N_ithr = C_ithr = -ithr; + S_s = S_e = N_s = N_e = C_blk_s = C_blk_e = -1; + } + } + + // spatial_thr_allowed is meant to help maintain + // consistent decisions about spatial threading + // between mutiple invocations of this routine. + // It is caller's responsibility to check the + // return value and pass it as a flag to the + // next call if needed. + if (S_nthr == 1) + spatial_thr_allowed = false; + + return spatial_thr_allowed; +} + +bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w, + int data_size) { + if (!mkldnn_thr_syncable()) return false; + + dim_t nthr = mkldnn_get_max_threads(); + dim_t SP = bdesc->W() * bdesc->D() * bdesc->H(); + dim_t C_PADDED = memory_desc_wrapper(bdesc->src_md()) + .padded_dims()[1]; + assert(C_PADDED % simd_w == 0); + + size_t data = bdesc->MB() * C_PADDED * SP * data_size; + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + bool do_blocking = (data >= l3_size_ / 2 && l3_size_ > 0); + dim_t C_blks_per_iter{ 1 }, iters{ 1 }; + dim_t C_blks = C_PADDED / simd_w; + + if (do_blocking) { + int num_tensors = bdesc->is_fwd() ? 1 : 2; + size_t working_set_size + = (bdesc->MB() * SP * simd_w * data_size) * num_tensors; + cache_balance(working_set_size, C_blks, C_blks_per_iter, iters); + } + + // Spatial threading decision made in this function shall be consistent + // with thread_balance() behavior. + C_blks = do_blocking ? C_blks_per_iter : C_blks; + + if (nthr <= C_blks) return false; + + dim_t S_nthr = 1; + if (do_blocking) { + dim_t N_nthr = nstl::min(bdesc->MB(), nthr); + dim_t C_nthr = nstl::min(C_blks, nthr / N_nthr); + S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr)); + } else { + dim_t C_nthr = math::gcd(nthr, C_blks); + dim_t N_nthr = nstl::min(bdesc->MB(), nthr / C_nthr); + S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr)); + } + + return S_nthr > 1; +} + +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp new file mode 100644 index 0000000000..0daef0716c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp @@ -0,0 +1,43 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_BATCH_NORMALIZATION_UTILS_HPP +#define CPU_BATCH_NORMALIZATION_UTILS_HPP + +#include "batch_normalization_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace bnorm_utils { + +void cache_balance(size_t working_set_size, dim_t C_blks, + dim_t &C_blks_per_iter, int64_t &iters); + +bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr, + int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr, + dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s, + dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e); + +bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w, + int data_size); + +} +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp new file mode 100644 index 0000000000..b926491202 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu_engine.hpp" + +/* +#include "cpu/ref_concat.hpp" +#include "cpu/simple_concat.hpp" +*/ + +namespace mkldnn { +namespace impl { +namespace cpu { + +using cpd_create_f = mkldnn::impl::engine_t::concat_primitive_desc_create_f; + +namespace { +#define INSTANCE(...) __VA_ARGS__::pd_t::create +static const cpd_create_f cpu_concat_impl_list[] = { + /* + INSTANCE(simple_concat_t), + INSTANCE(simple_concat_t), + INSTANCE(simple_concat_t), + INSTANCE(simple_concat_t), + INSTANCE(ref_concat_t), + */ + nullptr, +}; +#undef INSTANCE +} + +const cpd_create_f *cpu_engine_t::get_concat_implementation_list() const { + return cpu_concat_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp new file mode 100644 index 0000000000..0b01bcf163 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp @@ -0,0 +1,41 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_CONCAT_PD_HPP +#define CPU_CONCAT_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "concat_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_concat_pd_t: public concat_pd_t { + using concat_pd_t::concat_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp new file mode 100644 index 0000000000..52a38a2294 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_CONVOLUTION_PD_HPP +#define CPU_CONVOLUTION_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "convolution_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t { + using convolution_fwd_pd_t::convolution_fwd_pd_t; + + bool has_padded_dst() const { + memory_desc_wrapper dst_d(&dst_md_); + return OC() != dst_d.padded_dims()[1]; + } + + bool wants_padded_bias() const { + if (!with_bias()) return false; + return has_padded_dst(); + } + + bool wants_zero_pad_dst(bool jit_impl = true) const { + if (!has_padded_dst()) return false; + const auto &po = attr()->post_ops_; + int idx; + if ((idx = po.find(primitive_kind::eltwise)) == -1) return false; + return !math::eltwise_fwd_preserves_zero(po.entry_[idx].eltwise.alg, + jit_impl); + } +}; + +struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t { + using convolution_bwd_data_pd_t::convolution_bwd_data_pd_t; +}; + +struct cpu_convolution_bwd_weights_pd_t: public convolution_bwd_weights_pd_t { + using convolution_bwd_weights_pd_t::convolution_bwd_weights_pd_t; + + bool wants_padded_bias() const { + if (!with_bias()) return false; + memory_desc_wrapper diff_dst_d(&diff_dst_md_); + return OC() != diff_dst_d.padded_dims()[1]; + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp new file mode 100644 index 0000000000..164c8601d7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_DECONVOLUTION_PD_HPP +#define CPU_DECONVOLUTION_PD_HPP + +#include + +#include "deconvolution_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_deconvolution_fwd_pd_t: public deconvolution_fwd_pd_t { + using deconvolution_fwd_pd_t::deconvolution_fwd_pd_t; +}; + +struct cpu_deconvolution_bwd_data_pd_t: public deconvolution_bwd_data_pd_t { + using deconvolution_bwd_data_pd_t::deconvolution_bwd_data_pd_t; +}; + +struct cpu_deconvolution_bwd_weights_pd_t: public deconvolution_bwd_weights_pd_t { + using deconvolution_bwd_weights_pd_t::deconvolution_bwd_weights_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp new file mode 100644 index 0000000000..c52f00026e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp @@ -0,0 +1,45 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_ELTWISE_PD_HPP +#define CPU_ELTWISE_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "eltwise_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_eltwise_fwd_pd_t: public eltwise_fwd_pd_t { + using eltwise_fwd_pd_t::eltwise_fwd_pd_t; +}; + +struct cpu_eltwise_bwd_pd_t: public eltwise_bwd_pd_t { + using eltwise_bwd_pd_t::eltwise_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp new file mode 100644 index 0000000000..ce0a3667ad --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp @@ -0,0 +1,324 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "type_helpers.hpp" +#include "verbose.hpp" + +#include "cpu_engine.hpp" +#include "cpu_memory.hpp" + +//#include "cpu/rnn/ref_rnn.hpp" + +//#include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp" +//#include "cpu/jit_avx512_common_1x1_convolution.hpp" +#include "cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp" +#include "cpu/jit_avx512_common_convolution_winograd.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_convolution.hpp" +#include "cpu/jit_avx512_common_convolution.hpp" +//#include "cpu/jit_avx2_1x1_convolution.hpp" +//#include "cpu/jit_sse42_1x1_convolution.hpp" +#include "cpu/jit_avx2_convolution.hpp" +#include "cpu/jit_sse42_convolution.hpp" +//#include "cpu/gemm_convolution.hpp" +//#include "cpu/gemm_x8s8s32x_convolution.hpp" +//#include "cpu/ref_convolution.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp" +//#include "cpu/ref_deconvolution.hpp" +//#include "cpu/ref_shuffle.hpp" +//#include "cpu/jit_uni_eltwise.hpp" +//#include "cpu/ref_eltwise.hpp" +//#include "cpu/ref_softmax.hpp" +#include "cpu/jit_uni_pooling.hpp" +//#include "cpu/jit_uni_i8i8_pooling.hpp" +//#include "cpu/ref_pooling.hpp" +//#include "cpu/nchw_pooling.hpp" +//#include "cpu/nhwc_pooling.hpp" +//#include "cpu/jit_avx512_common_lrn.hpp" +//#include "cpu/jit_uni_lrn.hpp" +//#include "cpu/ref_lrn.hpp" +//#include "cpu/jit_uni_batch_normalization.hpp" +//#include "cpu/ref_batch_normalization.hpp" +//#include "cpu/ncsp_batch_normalization.hpp" +//#include "cpu/nspc_batch_normalization.hpp" +//#include "cpu/ref_inner_product.hpp" +//#include "cpu/gemm_inner_product.hpp" +//#include "cpu/gemm_x8s8s32x_inner_product.hpp" +//#include "cpu/jit_uni_dw_convolution.hpp" +//#include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp" +#include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +status_t cpu_engine_t::memory_create(memory_t **memory, + const memory_desc_t *md, void *handle) { + auto _memory = new cpu_memory_t(this, md, handle); + if (_memory == nullptr) + return status::out_of_memory; + + status_t status = _memory->init(); + if (status != status::success) { + delete _memory; + return status; + } + + return safe_ptr_assign(*memory, _memory); +} + +using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; + +namespace { +using namespace mkldnn::impl::data_type; + +#define INSTANCE(...) &primitive_desc_t::create<__VA_ARGS__::pd_t> +static const pd_create_f cpu_impl_list[] = { + /* RNN */ + /* + INSTANCE(ref_rnn_fwd_f32_t), + INSTANCE(ref_rnn_fwd_u8s8_t), + INSTANCE(ref_rnn_bwd_f32_t), + */ + /* conv */ + /* + INSTANCE(jit_avx512_common_dw_convolution_fwd_t), + INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t), + INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t), + INSTANCE(jit_avx512_common_1x1_convolution_fwd_f32_t), + INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_f32_t), + INSTANCE(jit_avx512_common_1x1_convolution_bwd_weights_t), + */ + INSTANCE(jit_avx512_core_fp32_wino_conv_2x3_fwd_t), + INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_fwd_t), + //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t), + //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t), + INSTANCE(jit_avx512_common_convolution_winograd_fwd_t), + //INSTANCE(jit_avx512_common_convolution_winograd_bwd_data_t), + //INSTANCE(jit_avx512_common_convolution_winograd_bwd_weights_t), + INSTANCE(jit_avx512_common_convolution_fwd_t), + //INSTANCE(jit_avx512_common_convolution_bwd_data_t), + //INSTANCE(jit_avx512_common_convolution_bwd_weights_t), + /* + INSTANCE(jit_avx2_dw_convolution_fwd_t), + INSTANCE(jit_avx2_dw_convolution_bwd_data_t), + INSTANCE(jit_avx2_dw_convolution_bwd_weights_t), + INSTANCE(jit_avx2_1x1_convolution_fwd_t), + INSTANCE(jit_avx2_1x1_convolution_bwd_data_t), + INSTANCE(jit_avx2_1x1_convolution_bwd_weights_t), + INSTANCE(jit_sse42_dw_convolution_fwd_t), + INSTANCE(jit_sse42_dw_convolution_bwd_data_t), + INSTANCE(jit_sse42_dw_convolution_bwd_weights_t), + INSTANCE(jit_sse42_1x1_convolution_fwd_t), + */ + INSTANCE(jit_avx2_convolution_fwd_t), + //INSTANCE(jit_avx2_convolution_bwd_data_t), + //INSTANCE(jit_avx2_convolution_bwd_weights_t), + INSTANCE(jit_sse42_convolution_fwd_t), + /* + INSTANCE(gemm_convolution_fwd_t), + INSTANCE(gemm_convolution_bwd_data_t), + INSTANCE(gemm_convolution_bwd_weights_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_weights_t), + */ + /* conv (int) */ + /* + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_data_t), + */ + /* deconv */ + /* + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(ref_deconvolution_bwd_weights_t), + INSTANCE(ref_deconvolution_bwd_data_t), + INSTANCE(ref_deconvolution_fwd_t), + */ + /* shuffle */ + /* + INSTANCE(ref_shuffle_t<4>), // f32 or s32 + INSTANCE(ref_shuffle_t<1>), // s8 or u8 + */ + /* eltwise */ + /* + INSTANCE(jit_uni_eltwise_fwd_t), + INSTANCE(jit_uni_eltwise_bwd_t), + INSTANCE(jit_uni_eltwise_fwd_t), + INSTANCE(jit_uni_eltwise_bwd_t), + INSTANCE(jit_uni_eltwise_fwd_t), + INSTANCE(jit_uni_eltwise_bwd_t), + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_bwd_t), + */ + /* eltwise (int) */ + /* + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_bwd_t), + */ + /* softmax */ + /* + INSTANCE(ref_softmax_fwd_t), + INSTANCE(ref_softmax_bwd_t), + */ + /* pool */ + INSTANCE(jit_uni_pooling_fwd_t), + //INSTANCE(jit_uni_pooling_bwd_t), + INSTANCE(jit_uni_pooling_fwd_t), + //INSTANCE(jit_uni_pooling_bwd_t), + INSTANCE(jit_uni_pooling_fwd_t), + //INSTANCE(jit_uni_pooling_bwd_t), + /* + INSTANCE(nchw_pooling_fwd_t), + INSTANCE(nchw_pooling_bwd_t), + INSTANCE(nhwc_pooling_fwd_t), + INSTANCE(nhwc_pooling_bwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_bwd_t), + */ + /* pool (int) */ + /* + INSTANCE(jit_uni_i8i8_pooling_fwd_t), + INSTANCE(jit_uni_i8i8_pooling_fwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_bwd_t), + */ + /* lrn */ + /* + INSTANCE(jit_avx512_common_lrn_fwd_t), + INSTANCE(jit_avx512_common_lrn_bwd_t), + INSTANCE(jit_uni_lrn_fwd_t), + INSTANCE(jit_uni_lrn_bwd_t), + INSTANCE(jit_uni_lrn_fwd_t), + INSTANCE(ref_lrn_fwd_t), + INSTANCE(ref_lrn_bwd_t), + */ + /* batch normalization */ + /* + INSTANCE(jit_uni_batch_normalization_fwd_t), + INSTANCE(jit_uni_batch_normalization_bwd_t), + INSTANCE(jit_uni_batch_normalization_fwd_t), + INSTANCE(jit_uni_batch_normalization_bwd_t), + INSTANCE(jit_uni_batch_normalization_fwd_t), + INSTANCE(jit_uni_batch_normalization_bwd_t), + INSTANCE(ncsp_batch_normalization_fwd_t), + INSTANCE(ncsp_batch_normalization_bwd_t), + INSTANCE(nspc_batch_normalization_fwd_t), + INSTANCE(nspc_batch_normalization_bwd_t), + INSTANCE(ref_batch_normalization_fwd_t), + INSTANCE(ref_batch_normalization_bwd_t), + INSTANCE(ref_batch_normalization_fwd_t), + */ + /* inner product */ + /* + INSTANCE(gemm_inner_product_fwd_t), + INSTANCE(gemm_inner_product_bwd_data_t), + INSTANCE(gemm_inner_product_bwd_weights_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_bwd_data_t), + INSTANCE(ref_inner_product_bwd_weights_t), + */ + /* inner product (int) */ + /* + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + */ + /* eol */ + nullptr, +}; +#undef INSTANCE +} + +const pd_create_f* cpu_engine_t::get_implementation_list() const { + return cpu_impl_list; +} + +cpu_engine_factory_t engine_factory; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp new file mode 100644 index 0000000000..e4c877ee05 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp @@ -0,0 +1,70 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_ENGINE_HPP +#define CPU_ENGINE_HPP + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "../common/engine.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +class cpu_engine_t: public engine_t { +public: + cpu_engine_t(): engine_t(engine_kind::cpu) {} + + /* implementation part */ + + virtual status_t memory_create(memory_t **memory, + const memory_desc_t *md, void *handle) override; + + virtual const concat_primitive_desc_create_f* + get_concat_implementation_list() const override; + virtual const reorder_primitive_desc_create_f* + get_reorder_implementation_list() const override; + virtual const sum_primitive_desc_create_f* + get_sum_implementation_list() const override; + virtual const primitive_desc_create_f* + get_implementation_list() const override; +}; + +class cpu_engine_factory_t: public engine_factory_t { +public: + virtual size_t count() const override { return 1; } + virtual engine_kind_t kind() const override { return engine_kind::cpu; } + virtual status_t engine_create(engine_t **engine, + size_t index) const override { + assert(index == 0); + *engine = new cpu_engine_t(); + return status::success; + }; +}; + +extern cpu_engine_factory_t engine_factory; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp new file mode 100644 index 0000000000..5880d3450c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp @@ -0,0 +1,84 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_INNER_PRODUCT_PD_HPP +#define CPU_INNER_PRODUCT_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "inner_product_pd.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { +inline bool dense_gemm_consitency_check(const memory_desc_wrapper &src_d, + const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) { + using namespace utils; + + auto strides_compatible = [&]() { + bool ok = true; + auto w_str = wei_d.blocking_desc().strides; + auto d_str = src_d.blocking_desc().strides; + for (int i = 1; i < src_d.ndims() - 1; i++) { + ok = ok && w_str[i] / d_str[i] == w_str[i + 1] / d_str[i + 1]; + } + return ok && one_of(w_str[1] / d_str[1], 1, wei_d.padded_dims()[0]); + }; + return true && src_d.is_blocking_desc() && wei_d.is_blocking_desc() + && src_d.ndims() == wei_d.ndims() + && src_d.blocking_desc().inner_nblks + == wei_d.blocking_desc().inner_nblks + && utils::one_of(src_d.blocking_desc().inner_nblks, 0, 1) + && array_cmp(src_d.blocking_desc().inner_blks, + wei_d.blocking_desc().inner_blks, + wei_d.blocking_desc().inner_nblks) + && array_cmp(src_d.blocking_desc().inner_idxs, + wei_d.blocking_desc().inner_idxs, + wei_d.blocking_desc().inner_nblks) + && strides_compatible() + && dst_d.matches_tag(format_tag::nc) + && src_d.only_padded_dim(1) + && wei_d.only_padded_dim(1) + && src_d.padded_dims()[1] == wei_d.padded_dims()[1] + && src_d.is_dense(true) + && dst_d.is_dense() + && wei_d.is_dense(true); +} +} + +struct cpu_inner_product_fwd_pd_t: public inner_product_fwd_pd_t { + using inner_product_fwd_pd_t::inner_product_fwd_pd_t; +}; + +struct cpu_inner_product_bwd_data_pd_t: public inner_product_bwd_data_pd_t { + using inner_product_bwd_data_pd_t::inner_product_bwd_data_pd_t; +}; + +struct cpu_inner_product_bwd_weights_pd_t: public inner_product_bwd_weights_pd_t { + using inner_product_bwd_weights_pd_t::inner_product_bwd_weights_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp new file mode 100644 index 0000000000..da6e9dac8e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp @@ -0,0 +1,151 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_ISA_TRAITS_HPP +#define CPU_ISA_TRAITS_HPP + +#include + +#define XBYAK64 +#define XBYAK_NO_OP_NAMES +/* in order to make selinux happy memory that would be marked with X-bit should + * be obtained with mmap */ +#define XBYAK_USE_MMAP_ALLOCATOR +#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) +/* turn off `size_t to other-type implicit casting` warning + * currently we have a lot of jit-generated instructions that + * take uint32_t, but we pass size_t (e.g. due to using sizeof). + * FIXME: replace size_t parameters with the appropriate ones */ +#pragma warning (disable: 4267) +#endif +#include "xbyak/xbyak.h" +#include "xbyak/xbyak_util.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef enum { + isa_any, + sse41, + sse42, + avx, + avx2, + avx512_common, + avx512_core, + avx512_core_vnni, + avx512_mic, + avx512_mic_4ops, +} cpu_isa_t; + +template struct cpu_isa_traits {}; /* ::vlen -> 32 (for avx2) */ + +template <> struct cpu_isa_traits { + typedef Xbyak::Xmm Vmm; + static constexpr int vlen_shift = 4; + static constexpr int vlen = 16; + static constexpr int n_vregs = 16; +}; +template <> struct cpu_isa_traits { + typedef Xbyak::Ymm Vmm; + static constexpr int vlen_shift = 5; + static constexpr int vlen = 32; + static constexpr int n_vregs = 16; +}; +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +template <> struct cpu_isa_traits { + typedef Xbyak::Zmm Vmm; + static constexpr int vlen_shift = 6; + static constexpr int vlen = 64; + static constexpr int n_vregs = 32; +}; +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +namespace { + +static Xbyak::util::Cpu cpu; +static inline bool mayiuse(const cpu_isa_t cpu_isa) { + using namespace Xbyak::util; + + switch (cpu_isa) { + case sse41: + case sse42: + // FIXME: SSE4.2 is actually NOT required + //return cpu.has(Cpu::tSSE42); + return cpu.has(Cpu::tSSE41); + case avx: + return cpu.has(Cpu::tAVX); + case avx2: + return cpu.has(Cpu::tAVX2); + case avx512_common: + return cpu.has(Cpu::tAVX512F); + case avx512_core: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512BW) + && cpu.has(Cpu::tAVX512VL) + && cpu.has(Cpu::tAVX512DQ); + case avx512_core_vnni: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512BW) + && cpu.has(Cpu::tAVX512VL) + && cpu.has(Cpu::tAVX512DQ) + && cpu.has(Cpu::tAVX512_VNNI); + case avx512_mic: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512CD) + && cpu.has(Cpu::tAVX512ER) + && cpu.has(Cpu::tAVX512PF); + case avx512_mic_4ops: + return true + && mayiuse(avx512_mic) + && cpu.has(Cpu::tAVX512_4FMAPS) + && cpu.has(Cpu::tAVX512_4VNNIW); + case isa_any: + return true; + } + return false; +} +} + +/* whatever is required to generate string literals... */ +#include "z_magic.hpp" +#define JIT_IMPL_NAME_HELPER(prefix, isa, suffix_if_any) \ + (isa == sse42 ? prefix STRINGIFY(sse42) : \ + (isa == avx ? prefix STRINGIFY(avx) : \ + (isa == avx2 ? prefix STRINGIFY(avx2) : \ + (isa == avx512_common ? prefix STRINGIFY(avx512_common) : \ + (isa == avx512_core ? prefix STRINGIFY(avx512_core) : \ + (isa == avx512_mic ? prefix STRINGIFY(avx512_mic) : \ + (isa == avx512_mic_4ops ? prefix STRINGIFY(avx512_mic_4ops) : \ + prefix suffix_if_any))))))) + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp new file mode 100644 index 0000000000..49988f4c2d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp @@ -0,0 +1,42 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_LRN_PD_HPP +#define CPU_LRN_PD_HPP + +#include + +#include "lrn_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_lrn_fwd_pd_t: public lrn_fwd_pd_t { + using lrn_fwd_pd_t::lrn_fwd_pd_t; +}; + +struct cpu_lrn_bwd_pd_t: public lrn_bwd_pd_t { + using lrn_bwd_pd_t::lrn_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp new file mode 100644 index 0000000000..3c0624cf46 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp @@ -0,0 +1,277 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn_traits.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl; +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::format_tag; + +enum blk_kind_t { a, b, c, ab, ba, bc, cb }; + +template +void typed_zero_pad_blk( + const memory_desc_wrapper &m_d, typename prec_traits
::type *data) { + using data_t = typename prec_traits
::type; + const auto &dims = m_d.dims(); + const auto &pdims = m_d.padded_dims(); + const auto &blk = m_d.blocking_desc(); + auto dim_is_blocked = [&](int dim) { + for (int i = 0; i < blk.inner_nblks; i++) + if (blk.inner_idxs[i] == dim) + return true; + return false; + }; + bool A_blocked = dim_is_blocked(0), B_blocked = dim_is_blocked(1), + C_blocked = dim_is_blocked(2); + + assert(blk.inner_nblks < 4); + assert((A_blocked || B_blocked || C_blocked) || (A_blocked && B_blocked) + || (C_blocked && B_blocked)); + + const int a_tail_s = A_blocked ? dims[0] % blksize : 0; + const int b_tail_s = B_blocked ? dims[1] % blksize : 0; + const int c_tail_s = C_blocked ? dims[2] % blksize : 0; + assert(a_tail_s || b_tail_s || c_tail_s); + + const int A = A_blocked ? pdims[0] / blksize : dims[0]; + const int B = B_blocked ? pdims[1] / blksize : dims[1]; + const int C = C_blocked ? pdims[2] / blksize : dims[2]; + const int D = m_d.ndims() > 3 ? dims[3] : 1; + const int E = m_d.ndims() > 4 ? dims[4] : 1; + const int F = m_d.ndims() > 5 ? dims[5] : 1; + const int inner_blk = blk.inner_nblks == 3 ? blk.inner_blks[2] : 1; + + auto zeroize_tail = [&](data_t *d, const int tail_s) { + for (int b = tail_s; b < blksize; ++b) + d[b] = 0; + }; + auto zeroize_tail_inner = [&](data_t *d, const int tail_s) { + for (int b1 = 0; b1 < blksize; ++b1) + for (int b2 = tail_s; b2 < blksize; ++b2) + d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2 + + b1 % inner_blk] + = 0; + }; + auto zeroize_tail_outer = [&](data_t *d, const int tail_s) { + for (int b1 = tail_s; b1 < blksize; ++b1) + for (int b2 = 0; b2 < blksize; ++b2) + d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2 + + b1 % inner_blk] + = 0; + }; + + if (c_tail_s) { + parallel_nd(A, B, D, E, F, [&](int a, int b, int d, int e, int f) { + auto x = &data[m_d.blk_off(a, b, C - 1, d, e, f)]; + if (blk_kind == c) + zeroize_tail(x, c_tail_s); + else if (blk_kind == bc) + zeroize_tail_inner(x, c_tail_s); + else if (blk_kind == cb) + zeroize_tail_outer(x, c_tail_s); + }); + } + + if (b_tail_s) { + parallel_nd(A, C, D, E, F, [&](int a, int c, int d, int e, int f) { + auto x = &data[m_d.blk_off(a, B - 1, c, d, e, f)]; + if (blk_kind == b) + zeroize_tail(x, b_tail_s); + else if (blk_kind == ab || blk_kind == cb) + zeroize_tail_inner(x, b_tail_s); + else if (blk_kind == ba || blk_kind == bc) + zeroize_tail_outer(x, b_tail_s); + }); + } + + if (a_tail_s) { + parallel_nd(B, C, D, E, F, [&](int b, int c, int d, int e, int f) { + auto x = &data[m_d.blk_off(A - 1, b, c, d, e, f)]; + if (blk_kind == a) + zeroize_tail(x, a_tail_s); + else if (blk_kind == ba) + zeroize_tail_inner(x, a_tail_s); + else if (blk_kind == ab) + zeroize_tail_outer(x, a_tail_s); + }); + } +} + +/* + * all + */ +template +void typed_zero_pad_generic_blocked( + const memory_desc_wrapper &m_d, typename prec_traits
::type *data) { + const int ndims = m_d.ndims(); + const auto &dims = m_d.dims(); + const auto &pdims = m_d.padded_dims(); + + const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true); + + /* [D_0] .. [D_k][D_k+1] .. [D_ndim - 1] + * | \ / + * | --------------------- + * has contiguous + * padding + * + * step <-- D_k+1 * ... * D_ndims-1 + * step_dim <-- k + */ + + ptrdiff_t step = 1; + int step_dim = ndims - 1; + for (; step_dim >= 0; --step_dim) { + if (dims[step_dim] != pdims[step_dim]) + break; + step *= dims[step_dim]; + } + + assert(step_dim >= 0 && "no zero padding is required"); + if (step_dim < 0) + return; + + parallel_nd(nelems / step, [&](ptrdiff_t e1) { + bool need_zero = false; + + ptrdiff_t idx = e1; + for (int d = step_dim; d >= 0; --d) { + if (idx % pdims[d] >= dims[d]) { + need_zero = true; + break; + } + idx /= pdims[d]; + } + + if (need_zero) { + for (ptrdiff_t e0 = 0; e0 < step; ++e0) + data[m_d.off_l(e1 * step + e0, true)] = 0; + } + }); +} + +template +status_t cpu_memory_t::typed_zero_pad() const { + const memory_desc_wrapper mdw(md()); + + if (mdw.format_kind() != format_kind::blocked) + return unimplemented; + + if (mdw.nelems(false) == mdw.nelems(true)) + return success; + + auto *data = (typename prec_traits
::type *)data_; + auto blk = mdw.blocking_desc(); + + auto get_blksize = [&](int ind) { + int blksize = 1; + for (int i = 0; i < blk.inner_nblks; i++) { + if (blk.inner_idxs[i] == ind) + blksize *= blk.inner_blks[i]; + } + return blksize; + }; + const int blksize = get_blksize(blk.inner_idxs[0]); + +# define CASE(blksize_, blk_kind) \ + do { \ + if (blksize == blksize_) { \ + typed_zero_pad_blk(mdw, data); \ + return success; \ + } \ + } while(0) + + switch (blk.inner_nblks) { + case 1: + if (blk.inner_idxs[0] == 0) { + CASE(4, a); + CASE(8, a); + CASE(16, a); + } else if (blk.inner_idxs[0] == 1) { + CASE(4, b); + CASE(8, b); + CASE(16, b); + } + break; + case 2: + case 3: + if (!IMPLICATION(blk.inner_nblks == 3, + blk.inner_idxs[0] == blk.inner_idxs[2])) + break; + + if (blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) { + CASE(4, ab); + CASE(8, ab); + CASE(16, ab); + } else if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0) { + CASE(4, ba); + CASE(8, ba); + CASE(16, ba); + } + if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) { + CASE(4, bc); + CASE(8, bc); + CASE(16, bc); + } else if (blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1) { + CASE(4, cb); + CASE(8, cb); + CASE(16, cb); + } + break; + default: break; + } + +# undef CASE + + // the last line of defence + typed_zero_pad_generic_blocked
(mdw, data); + return success; +} + +status_t cpu_memory_t::zero_pad() const { + memory_desc_wrapper mdw(md()); + const bool skip_zeroing = false + || data_ == nullptr + || mdw.is_zero() + || !mdw.is_blocking_desc(); + if (skip_zeroing) return success; + + switch (mdw.data_type()) { + case f32: return typed_zero_pad(); + case s32: return typed_zero_pad(); + case s8: return typed_zero_pad(); + case u8: return typed_zero_pad(); + default: assert(!"memory is undefined"); return unimplemented; + } + return unimplemented; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp new file mode 100644 index 0000000000..2c01bcc6af --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_MEMORY_HPP +#define CPU_MEMORY_HPP + +#include + +#include "c_types_map.hpp" +#include "memory.hpp" +#include "memory_desc_wrapper.hpp" + +#include "cpu_engine.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_memory_t: public memory_t { + cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md, void *handle) + : memory_t(engine, md) + , own_data_(handle == MKLDNN_NATIVE_HANDLE_ALLOCATE) + , data_((char *)handle) {} + + cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md) + : cpu_memory_t(engine, md, nullptr) {} + + ~cpu_memory_t() { if (own_data_) free(data_); } + + virtual status_t init() override { + if (own_data_) { + data_ = nullptr; + const size_t size = memory_desc_wrapper(this->md()).size(); + if (size) { + data_ = (char *)malloc(size, 64); + if (data_ == nullptr) + return status::out_of_memory; + } + } + return zero_pad(); + } + + cpu_engine_t *engine() const { return (cpu_engine_t *)memory_t::engine(); } + + virtual status_t get_data_handle(void **handle) const override { + *handle = static_cast(data_); + return status::success; + } + + virtual mkldnn::impl::status_t set_data_handle(void *handle) override { + if (own_data_) { free(data_); own_data_ = false; } + data_ = static_cast(handle); + return zero_pad(); + } + + virtual mkldnn::impl::status_t zero_pad() const override; + +private: + bool own_data_; + char *data_; + + template + mkldnn::impl::status_t typed_zero_pad() const; + + cpu_memory_t(const cpu_memory_t &) = delete; + cpu_memory_t &operator=(const cpu_memory_t &) = delete; + cpu_memory_t &operator=(cpu_memory_t &&) = delete; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp new file mode 100644 index 0000000000..ac2daa415e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp @@ -0,0 +1,40 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_POOLING_PD_HPP +#define CPU_POOLING_PD_HPP + +#include "pooling_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_pooling_fwd_pd_t: public pooling_fwd_pd_t { + using pooling_fwd_pd_t::pooling_fwd_pd_t; +}; + +struct cpu_pooling_bwd_pd_t: public pooling_bwd_pd_t { + using pooling_bwd_pd_t::pooling_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp new file mode 100644 index 0000000000..56127f36c2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp @@ -0,0 +1,83 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_PRIMITIVE_HPP +#define CPU_PRIMITIVE_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "primitive.hpp" +#include "scratchpad.hpp" + +#define CTX_IN_MEM(type, arg) static_cast(ctx.input(arg)) +#define CTX_OUT_MEM(type, arg) static_cast(ctx.output(arg)) + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_memory_t; + +struct cpu_primitive_t: public primitive_t { + cpu_primitive_t(const primitive_desc_t *pd, + bool use_global_scratchpad = false) + : primitive_t(pd) + , scratchpad_buffer_(nullptr) + , global_scratchpad_(nullptr) + { + const size_t scratchpad_size = + this->pd()->scratchpad_size(scratchpad_mode::library); + + if (scratchpad_size) { + if (use_global_scratchpad) + global_scratchpad_ = create_scratchpad(scratchpad_size); + else + scratchpad_buffer_ = malloc(scratchpad_size, 64); + } + } + + virtual ~cpu_primitive_t() { + delete global_scratchpad_; + free(scratchpad_buffer_); + } + +protected: + memory_tracking::grantor_t scratchpad(const exec_ctx_t &ctx) const { + void *ptr = nullptr; + if (pd()->attr()->scratchpad_mode_ == scratchpad_mode::user) { + ptr = CTX_OUT_MEM(void *, MKLDNN_ARG_SCRATCHPAD); + } else { + ptr = global_scratchpad_ + ? global_scratchpad_->get() : scratchpad_buffer_; + } + + return pd()->scratchpad_registry().grantor(ptr); + } + +private: + void *scratchpad_buffer_; + scratchpad_t *global_scratchpad_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp new file mode 100644 index 0000000000..1d41ac5cea --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp @@ -0,0 +1,544 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn_thread.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "utils.hpp" + +#include "cpu_reducer.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void reduce_balancer_t::balance() { + using namespace nstl; + using namespace utils; + + assert(nthr_ > 0 && job_size_ > 0 && njobs_ > 0 && reduction_size_ > 0); + + const int job_complexity = 1; + + const int min_njobs_per_group = max(1, njobs_ / nthr_); + const int max_njobs_per_group = max(1, + static_cast(max_buffer_size_ / (nthr_ * job_size_))); + + /* initial guess */ + int ngroups = min(njobs_ / min_njobs_per_group, nthr_); + int nthr_per_group = syncable_ ? min(nthr_ / ngroups, reduction_size_) : 1; + int njobs_per_group_ub = div_up(njobs_, ngroups); + + /* rough upper-bound estimation, will be fixed during brute force */ + size_t thread_complexity_ub = njobs_ * job_size_ * reduction_size_; + + /* brute force parameters for the best balance... */ + for (int c_njobs_per_group = min_njobs_per_group; + c_njobs_per_group < njobs_; ++c_njobs_per_group) { + /* current assumption */ + int c_ngroups = min(njobs_ / c_njobs_per_group, nthr_); + int c_nthr_per_group = syncable_ + ? min(nthr_ / c_ngroups, reduction_size_) : 1; + int c_njobs_per_group_ub = div_up(njobs_, c_ngroups); + + if (c_nthr_per_group > 1 && c_njobs_per_group_ub > max_njobs_per_group) + continue; + + int c_thread_reduction_ub = div_up(reduction_size_, c_nthr_per_group); + size_t c_group_size_ub = job_size_ * c_njobs_per_group_ub; + size_t c_thread_complexity_ub = c_group_size_ub * ( + job_complexity * c_thread_reduction_ub + + (c_nthr_per_group != 1)); + + if (c_thread_complexity_ub < thread_complexity_ub) { + ngroups = c_ngroups; + nthr_per_group = c_nthr_per_group; + njobs_per_group_ub = c_njobs_per_group_ub; + thread_complexity_ub = c_thread_complexity_ub; + } + } + + assert(njobs_per_group_ub <= max_njobs_per_group || nthr_per_group == 1); + assert(ngroups * nthr_per_group <= nthr_); + assert((size_t)njobs_per_group_ub * job_size_ * nthr_ <= max_buffer_size_ + || nthr_per_group == 1); /* no reduction buffer overflow */ + assert(IMPLICATION(!syncable_, nthr_per_group == 1)); + + ngroups_ = ngroups; + nthr_per_group_ = nthr_per_group; + njobs_per_group_ub_ = njobs_per_group_ub; +} + +/* reducer jit-ted driver */ + +using namespace Xbyak; + +template +struct reducer_2d_driver_t: public c_compatible { + typedef typename prec_traits::type data_t; + + reducer_2d_driver_t(int n_src, size_t src_ld, + size_t src_step, size_t dst_step, bool nullify_dst) + : n_src_(n_src), src_ld_(src_ld), src_step_(src_step) + , dst_step_(dst_step), nullify_dst_(nullify_dst), ker_(nullptr) {} + virtual ~reducer_2d_driver_t() {} + void operator()(data_t *dst, const data_t *srcs, size_t ny, size_t nx) + { assert(ker_); ker_(dst, srcs, ny, nx); } + +protected: + int n_src_; + size_t src_ld_, src_step_, dst_step_; + bool nullify_dst_; + void (*ker_)(data_t *dst, const data_t *srcs, size_t ny, size_t nx); +}; + +template +struct reducer_2d_driver_f_s_32_t: public reducer_2d_driver_t, + public jit_generator +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(reducer_2d_driver_f_s_32_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional::type; + const AddressFrame &vmmword = (isa == avx2) ? yword : zword; + void uni_vadd(const Xmm& x1, const Xmm& x2, const Operand& op) + { if (data_type == data_type::f32) vaddps(x1, x2, op); + else vpaddd(x1, x2, op); } + void uni_add(const Xmm& x1, const Operand& op) + { if (data_type == data_type::f32) addss(x1, op); else paddd(x1, op); } + + const int vlen = cpu_isa_traits::vlen; + const int typesize + = sizeof(typename mkldnn::impl::prec_traits::type); + Xbyak::Reg64 reg_dst = abi_param1; + Xbyak::Reg64 reg_src = abi_param2; + Xbyak::Reg64 reg_ny = abi_param3; + Xbyak::Reg64 reg_nx = abi_param4; + + Xbyak::Reg64 reg_x = rax; + Xbyak::Reg64 reg_src_id = r10; + + reducer_2d_driver_f_s_32_t(int n_src, size_t src_ld, size_t src_step, + size_t dst_step, bool nullify_dst) + : reducer_2d_driver_t(n_src, src_ld, src_step, + dst_step, nullify_dst) + { generate(); } + + void nullify_dst(int nloads, int load_len) { + UNUSED(load_len); + for (int i = 0; i < nloads; ++i) + uni_vpxor(Vmm(i), Vmm(i), Vmm(i)); + /* prefetches[dst] ? */ + } + + void load_dst(int nloads, int load_len) { + for (int i = 0; i < nloads; ++i) { + if (load_len == typesize) + movd(Xmm(i), ptr[reg_dst + i * load_len]); + else if (load_len == vlen) + vmovups(Vmm(i), ptr[reg_dst + i * load_len]); + else + assert(!"unsupported"); + } + } + + void store_dst(int nloads, int load_len) { + for (int i = 0; i < nloads; ++i) { + if (load_len == typesize) + movd(ptr[reg_dst + i * load_len], Xmm(i)); + else if (load_len == vlen) + vmovups(ptr[reg_dst + i * load_len], Vmm(i)); + else + assert(!"unsupported"); + } + } + + void accumulate(int nloads, int load_len, size_t base_off) { + for (int i = 0; i < nloads; ++i) { + size_t off = base_off + i * load_len; + + if (load_len == typesize) + uni_add(Xmm(i), ptr[reg_src + off]); + else if (load_len == vlen) + uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]); + else + assert(!"unsupported"); + } + } + + void loop_x() { + const int nloads[] = {cpu_isa_traits::n_vregs, 1, 1}; + const int nbranches = sizeof(nloads) / sizeof(nloads[0]); + + const int load_len[nbranches] = {vlen, vlen, typesize}; + Label loop_x_label[nbranches + 1]; + + mov(reg_x, reg_nx); + + for (int id = 0; id < nbranches; ++id) { + L(loop_x_label[id]); + + cmp(reg_x, nloads[id] * load_len[id]); + jl(loop_x_label[id + 1], T_NEAR); + + if (this->nullify_dst_) + nullify_dst(nloads[id], load_len[id]); + else + load_dst(nloads[id], load_len[id]); + + if (nloads[id] > 1) { + Label loop_srcs; + mov(reg_src_id, this->n_src_); + L(loop_srcs); + + accumulate(nloads[id], load_len[id], 0); + add(reg_src, this->src_ld_ * typesize); + + dec(reg_src_id); + jnz(loop_srcs, T_NEAR); + + sub(reg_src, this->n_src_ * this->src_ld_ * typesize); + } else { + for (int src_id = 0; src_id < this->n_src_; ++src_id) { + const size_t base_off = src_id * this->src_ld_ * typesize; + accumulate(nloads[id], load_len[id], base_off); + } + } + + store_dst(nloads[id], load_len[id]); + + add(reg_src, nloads[id] * load_len[id]); + add(reg_dst, nloads[id] * load_len[id]); + + sub(reg_x, nloads[id] * load_len[id]); + + jmp(loop_x_label[id], T_NEAR); + } + + L(loop_x_label[nbranches]); + + /* restore address registers */ + sub(reg_src, reg_nx); + sub(reg_dst, reg_nx); + } + + void generate() { + assert(isa == avx2 || isa == avx512_common || isa == avx512_mic); + + preamble(); + + shl(reg_nx, 2); + + Label ny_loop; + L(ny_loop); + + loop_x(); + + add(reg_dst, this->dst_step_ * typesize); + add(reg_src, this->src_step_ * typesize); + + dec(reg_ny); + jnz(ny_loop, T_NEAR); + + postamble(); + this->ker_ = reinterpret_castker_)>( + const_cast(this->getCode())); + } +}; + +template +inline reducer_2d_driver_t *create_reduce_2d_drv(int n_src, + size_t src_ld, size_t src_step, size_t dst_step, bool nullify_dst) { + if (mayiuse(avx512_common)) + return new reducer_2d_driver_f_s_32_t(n_src, + src_ld, src_step, dst_step, nullify_dst); + else if (mayiuse(avx2)) + return new reducer_2d_driver_f_s_32_t(n_src, src_ld, + src_step, dst_step, nullify_dst); + assert(!"unimplemented"); + return nullptr; +} + +/* cpu_reducer_t */ + +template +void cpu_reducer_t::conf_t::init_scratchpad( + memory_tracking::registrar_t &scratchpad) const { + if (balancer_.nthr_per_group_ == 1) return; + + const size_t space_size = balancer_.ngroups_ + * (balancer_.nthr_per_group_ - 1) + * cpu_reducer_t::space_per_thread(balancer_); + scratchpad.book(key_reducer_space, sizeof(data_t) * space_size, PAGE_4K); + scratchpad.book(key_reducer_space_bctx, + sizeof(simple_barrier::ctx_t) * balancer_.ngroups_); +} + +template +cpu_reducer_t::cpu_reducer_t(const conf_t &conf) + : conf_(conf), drv_(nullptr) +{ + if (balancer().nthr_per_group_ == 1) return; + + drv_ = create_reduce_2d_drv(balancer().nthr_per_group_ - 1, + space_per_thread(balancer()), 0, 0, false); +} + +template +cpu_reducer_t::~cpu_reducer_t() { delete drv_; } + +template +typename cpu_reducer_t::data_t * +cpu_reducer_t::get_local_ptr(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const int id_in_grp = balancer().id_in_group(ithr); + + /* threads 0 from each group writes directly to the destination */ + if (id_in_grp == 0) + return dst + balancer().ithr_job_off(ithr) * balancer().job_size_; + + const int grp_id = balancer().group_id(ithr); + const int offset_factor = grp_id * (balancer().nthr_per_group_ - 1) + + (id_in_grp - 1); + + auto space = scratchpad.template get(key_reducer_space); + return space + offset_factor * space_per_thread(balancer()); +} + +template +void cpu_reducer_t::reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + +#ifdef SIMPLE_IMPL + if (balancer().id_in_group(ithr) != 0) + return; /* only threads 0 do the reduction */ + + const int njobs_in_grp = balancer().ithr_njobs(ithr); + data_t *d = get_local_ptr(ithr, dst, scratchpad); + for (int id_in_grp = 1; id_in_grp < balancer_.nthr_per_group_; ++id_in_grp) + { + const data_t *space = get_local_ptr(ithr + id_in_grp, dst, scratchpad); + for (size_t i = 0; i < (size_t)njobs_in_grp * balancer().job_size_; ++i) + d[i] += space[i]; + } +#else + using namespace utils; + + const int id_in_grp = balancer().id_in_group(ithr); + const int njobs_in_grp = balancer().ithr_njobs(ithr); + const size_t cl = 64 / sizeof(data_t); + + const size_t reduction_size = njobs_in_grp * balancer().job_size_; + size_t start{0}, end{0}; + balance211(div_up(reduction_size, cl), balancer().nthr_per_group_, + id_in_grp, start, end); + + if (start == end) return; + + data_t *d = get_local_ptr(ithr - id_in_grp, dst, scratchpad) + start * cl; + const data_t *space = get_local_ptr(ithr - id_in_grp + 1, dst, scratchpad) + + start * cl; + const size_t len = nstl::min(end * cl, reduction_size) - start * cl; + + (*drv_)(d, space, 1, len); +#endif +} + +template struct cpu_reducer_t; +template struct cpu_reducer_t; + +/* cpu_reducer_2d_t */ + +template +void cpu_reducer_2d_t::conf_t::init_scratchpad( + memory_tracking::registrar_t &scratchpad) const { + if (balancer_.nthr_per_group_ == 1) return; + + const size_t space_size = balancer_.ngroups_ * balancer_.nthr_per_group_ + * cpu_reducer_2d_t::space_per_thread(balancer_); + scratchpad.book(key_reducer_space, sizeof(data_t) * space_size); + scratchpad.book(key_reducer_space_bctx, + sizeof(simple_barrier::ctx_t) * balancer_.ngroups_); +} + +template +cpu_reducer_2d_t::cpu_reducer_2d_t(const conf_t &conf) + : conf_(conf), drv_(nullptr) +{ + if (balancer().nthr_per_group_ == 1) return; + + drv_ = create_reduce_2d_drv(balancer().nthr_per_group_, + space_per_thread(balancer()), conf_.job_size_x_, conf_.dst_x_, + true); +} + +template +cpu_reducer_2d_t::~cpu_reducer_2d_t() { delete drv_; } + +template +typename cpu_reducer_2d_t::data_t *cpu_reducer_2d_t:: +get_local_ptr(int ithr, const memory_tracking::grantor_t &scratchpad) const { + const int id_in_grp = balancer().id_in_group(ithr); + const int grp_id = balancer().group_id(ithr); + const int offset_factor = grp_id * balancer().nthr_per_group_ + id_in_grp; + auto space = scratchpad.template get(key_reducer_space); + return space + offset_factor * space_per_thread(balancer()); +} + +template +int cpu_reducer_2d_t::choose_x_blocking(int nx, int ny, + int nthr_per_grp) const { + // find x_blocking for better balance reducing work between threads + assert(conf_.x_block_ > 0 && nx > conf_.x_block_ + && nx % conf_.x_block_ == 0); + int x_blocking = nx / conf_.x_block_; + int min_x_blocking = + utils::div_up(x_blocking, nstl::max(1, nthr_per_grp / ny)); + while (true) { + if (x_blocking % 2 == 0 && x_blocking >= min_x_blocking * 2) + x_blocking /= 2; + else if (x_blocking % 3 == 0 && x_blocking >= min_x_blocking * 3) + x_blocking /= 3; + else + break; + } + if (x_blocking >= min_x_blocking * 4) x_blocking = 1; + x_blocking *= conf_.x_block_; + return x_blocking; +} + +template +void cpu_reducer_2d_t::reduce_block(const data_t* space_base, + data_t *dst, int job, int start_y, int start_x, + int ny_start, int nx_start, int ny_step, int nx_step) const { + data_t *d = dst + (start_y + ny_start) * conf_.dst_x_ + + start_x + nx_start; + const data_t *space = space_base + job * balancer().job_size_ + + ny_start * conf_.job_size_x_ + nx_start; +#ifdef SIMPLE_IMPL + for (int idg = 0; idg < balancer().nthr_per_group_; ++idg) { + const data_t *w = &space[idg * space_per_thread(balancer())]; + for (int y = 0; y < ny_step; ++y) + for (int x = 0; x < nx_step; ++x) { + d[y * conf_.dst_x_ + x] + = (idg == 0 ? 0 : d[y * conf_.dst_x_ + x]) + + w[y * conf_.job_size_x_ + x]; + } + } +#else + (*drv_)(d, space, ny_step, nx_step); +#endif +} + +template +void cpu_reducer_2d_t::reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + const int id_in_grp = balancer().id_in_group(ithr); + const int njobs_in_grp = balancer().ithr_njobs(ithr); + const int njobs_x = utils::div_up(conf_.dst_x_, conf_.job_size_x_); + const int global_job_start = balancer().ithr_job_off(ithr); + + const data_t *space_base = get_local_ptr(ithr - id_in_grp, scratchpad); + + const int pr_grps = nstl::min(njobs_in_grp, balancer().nthr_per_group_); + const int pr_nthr_per_grp = balancer().nthr_per_group_ / pr_grps; + + if (id_in_grp >= pr_grps * pr_nthr_per_grp) + return; /* idle */ + + const int pr_my_grp = id_in_grp / pr_nthr_per_grp; + const int pr_my_id = id_in_grp % pr_nthr_per_grp; + + int pr_job_start{0}, pr_job_end{0}; + balance211(njobs_in_grp, pr_grps, pr_my_grp, pr_job_start, pr_job_end); + + for (int j = pr_job_start; j < pr_job_end; ++j) { + const int global_job = global_job_start + j; + const int j_y = global_job / njobs_x; + const int j_x = global_job % njobs_x; + const int start_y = j_y * conf_.job_size_y_; + const int start_x = j_x * conf_.job_size_x_; + const int ny = nstl::min(conf_.dst_y_ - start_y, conf_.job_size_y_); + const int nx = nstl::min(conf_.dst_x_ - start_x, conf_.job_size_x_); + int x_blocking = choose_x_blocking(nx, ny, pr_nthr_per_grp); + + int nxy_start{0}, nxy_end{0}; + balance211(ny * nx / x_blocking, pr_nthr_per_grp, pr_my_id, + nxy_start, nxy_end); + if (nxy_start == nxy_end) continue; + nxy_start *= x_blocking; + nxy_end *= x_blocking; + + int nxy = nxy_start; + if (nxy % nx != 0) { + int nx_step = nstl::min(nx - nxy % nx, nxy_end - nxy); + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, 1, nx_step); + nxy += nx_step; + } + if ((nxy_end - nxy) > nx) { + int ny_step = (nxy_end - nxy) / nx; + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, ny_step, nx); + nxy += nx * ny_step; + } + if ((nxy_end - nxy) > 0) { + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, 1, nxy_end - nxy); + } + } +} + +template struct cpu_reducer_2d_t; +template struct cpu_reducer_2d_t; + +/* accumulator section */ + +template +cpu_accumulator_1d_t::cpu_accumulator_1d_t(): drv_(nullptr) { + drv_ = create_reduce_2d_drv(1, 0, 0, 0, false); +} + +template +cpu_accumulator_1d_t::~cpu_accumulator_1d_t() { + delete drv_; +} + +template +void cpu_accumulator_1d_t::accumulate(data_t *dst, + const data_t *src, size_t size) { + (*drv_)(dst, src, 1, size); +} + +template struct cpu_accumulator_1d_t; +template struct cpu_accumulator_1d_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp new file mode 100644 index 0000000000..27f5939cd2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp @@ -0,0 +1,334 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REDUCER_HPP +#define CPU_REDUCER_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "cpu_barrier.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +/** class to perform balancing over 3D array + * + * Conceptually the reduction happens according to the picture below: + * + * <--job_size-> + * +-----------+ +-----------+ +-----------+ ^ + * | | | | | | | + * | | | | | | | + * | 1 | | 2 | . . . | njobs | | reduction_size + * | | | | | | | + * | | | | | | | + * +-----------+ +-----------+ +-----------+ v + * + * | | | | | | | | | + * v v v v v v v v v + * ===================================================== vertical reduction + * + * +-----------+ +-----------+ . . . +-----------+ result + * + * In a simple case the result must be contiguous in memory. + * @class cpu_reducer_t is an implementation. + * + * Threads are divided into groups. The groups are independent of each other. + * Each group may work on several jobs (the distribution is not uniform, since + * njobs might be not a multiple of groups). Threads within a group work on + * different parts of the reduction dimension. Thread 0 in each group is called + * master (@sa reduce_balancer_t::master()). + * + * If threading driver does not allow sync between sub-group of threads (e.g. + * Intel(R) TBB) the # of thread per group is enforced to be 1. + */ +struct reduce_balancer_t { + reduce_balancer_t() { init(1, 1, 1, 1, 0); } /* trivial balance */ + reduce_balancer_t(int nthr, int job_size, int njobs, int reduction_size, + size_t max_buffer_size) + { init(nthr, job_size, njobs, reduction_size, max_buffer_size); } + + reduce_balancer_t &init(int nthr, int job_size, int njobs, + int reduction_size, size_t max_buffer_size) + { + syncable_ = mkldnn_thr_syncable(); + nthr_ = nthr; + job_size_ = job_size; + njobs_ = njobs; + reduction_size_ = reduction_size; + max_buffer_size_ = max_buffer_size; + balance(); + return *this; + } + + bool syncable_; + int nthr_; + int job_size_, njobs_, reduction_size_; + + int ngroups_; /** number of independent work (thread) groups */ + int nthr_per_group_; /** number of threads within a single work group */ + int njobs_per_group_ub_; /** the max # of jobs within a work group */ + + bool master(int ithr) const { return id_in_group(ithr) == 0; } + bool idle(int ithr) const { return ithr >= nthr_per_group_ * ngroups_; } + + int group_id(int ithr) const { return ithr / nthr_per_group_; } + int id_in_group(int ithr) const { return ithr % nthr_per_group_; } + + int grp_njobs(int grp) const { + if (grp >= ngroups_) return 0; + return njobs_ / ngroups_ + (grp < njobs_ % ngroups_); + } + int grp_job_off(int grp) const { + if (grp >= ngroups_) return njobs_; + return njobs_ / ngroups_ * grp + nstl::min(grp, njobs_ % ngroups_); + } + + int ithr_njobs(int ithr) const { return grp_njobs(group_id(ithr)); } + int ithr_job_off(int ithr) const { return grp_job_off(group_id(ithr)); } + +private: + size_t max_buffer_size_; + void balance(); +}; + +/** forward declaration of reduce driver */ +template struct reducer_2d_driver_t; + +/** class to perform a reduction over 3D array + * + * Balancing is based on @class reduce_balancer_t. + * Restrictions: the result of the reduction must be contiguous in memory. * + * The reduction happens according to the picture below (once more): + * + * <--job_size-> + * +-----------+ +-----------+ +-----------+ ^ + * | | | | | | | + * | | | | | | | + * | 1 | | 2 | . . . | njobs | | reduction_size + * | | | | | | | + * | | | | | | | + * +-----------+ +-----------+ +-----------+ v + * + * | | | | | | | | | + * v v v v v v v v v + * ===================================================== vertical reduction + * + * +-----------+ +-----------+ . . . +-----------+ (contiguous) result + * + * An example how work might be shared is shown below. + * + * In this example group 0 owns 2 (independent) jobs -- 2 big squares. + * The number of threads per group is also 2 (thread 0 of group 0 and thread 1 + * of group 0). Master threads (i.e. threads with id 0 in corresponding group) + * from each group put the partial result directly into destination memory, + * while all the other threads with-in the group use workspace (on the picture + * the only thread 1). Once intermediate results obtained each group reduces + * corresponding part (own jobs) to the destination memory. + * + * <------- group 0 -------> + * + * +-----------+ +-----------+ ^ + * | | | | | thread 0 of reduces to the dest-memory + * | | | | | group 0 +-----------+ +-----------+ + * |- - - - - -| |- - - - - -| X + * | | | | | thread 1 of reduces to workspace[tid=1]: + * | | | | | group 0 +-----------+ +-----------+ + * +-----------+ +-----------+ v + * | | | | | | + * v v v v v v + * ((barrier)) ============================= + * + * dest-memory: +-----------+ +-----------+ + */ +template +struct cpu_reducer_t { + typedef typename prec_traits::type data_t; + + struct conf_t { + conf_t() = default; + conf_t &init(const reduce_balancer_t &balancer) + { balancer_ = balancer; return *this; } + + void init_scratchpad(memory_tracking::registrar_t &scratchpad) const; + + reduce_balancer_t balancer_; + }; + + cpu_reducer_t(const conf_t &conf); + ~cpu_reducer_t(); + + /** initializes reducer. + * Must be called from a single thread prior to actual usage */ + void init(const memory_tracking::grantor_t &scratchpad) const { + if (balancer().nthr_per_group_ == 1) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + for (int i = 0; i < balancer().ngroups_; ++i) + simple_barrier::ctx_init(&bctx[i]); + } + + /** for given thread returns the pointer where to put partial results. + * Reduction destination @p dst must be provided as well (master threads + * from each group will use it for partial result to reduce memory + * pressure). + * + * @note: job offset is already applied by get_local_ptr(), which means all + * threads should start writing from the very beginning of returned + * address. + */ + data_t *get_local_ptr(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + + /** performs the reduction with built-in synchronization. */ + void reduce(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + simple_barrier::barrier(&bctx[balancer().group_id(ithr)], + balancer().nthr_per_group_); + + reduce_nolock(ithr, dst, scratchpad); + } + + const reduce_balancer_t &balancer() const { return conf_.balancer_; } + +private: + static size_t space_per_thread(const reduce_balancer_t &balancer) + { return balancer.njobs_per_group_ub_ * balancer.job_size_; } + + /* The scratchpad is organized as follows: + * + * data_t space[nthr_][njobs_per_group_ub_][jobs_size_]; + * simple_barrier::ctx_t barriers[groups_]; */ + + const conf_t conf_; + reducer_2d_driver_t *drv_; + + void reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; +}; + +template +struct cpu_reducer_2d_t { + typedef typename prec_traits::type data_t; + + struct conf_t { + conf_t() = default; + conf_t &init(const reduce_balancer_t &balancer, int job_size_x, + int job_size_y, int x_block, int dst_x, int dst_y) { + balancer_ = balancer; + job_size_x_ = job_size_x; + job_size_y_ = job_size_y; + x_block_ = x_block; + dst_x_ = dst_x; + dst_y_ = dst_y; + return *this; + } + + void init_scratchpad(memory_tracking::registrar_t &scratchpad) const; + + reduce_balancer_t balancer_; + int job_size_x_, job_size_y_, x_block_, dst_x_, dst_y_; + }; + + cpu_reducer_2d_t(const conf_t &conf); + ~cpu_reducer_2d_t(); + + /** initializes reducer. + * Must be called from a single thread prior to actual usage */ + void init(const memory_tracking::grantor_t &scratchpad) const { + if (balancer().nthr_per_group_ == 1) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + for (int i = 0; i < balancer().ngroups_; ++i) + simple_barrier::ctx_init(&bctx[i]); + } + + /** for given thread returns the pointer where to put partial results */ + data_t *get_local_ptr(int ithr, + const memory_tracking::grantor_t &scratchpad) const; + + /** performs the reduction with built-in synchronization. */ + void reduce(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + simple_barrier::barrier(&bctx[balancer().group_id(ithr)], + balancer().nthr_per_group_); + + reduce_nolock(ithr, dst, scratchpad); + } + + const reduce_balancer_t &balancer() const { return conf_.balancer_; } + +private: + static size_t space_per_thread(const reduce_balancer_t &balancer) + { return balancer.njobs_per_group_ub_ * balancer.job_size_; } + + /* The scratchpad is organized as follows: + * + * data_t space[nthr_][njobs_per_group_ub_][jobs_size_]; + * simple_barrier::ctx_t barriers[groups_]; */ + + const conf_t conf_; + reducer_2d_driver_t *drv_; + + int choose_x_blocking(int nx, int ny, int nthr_per_grp) const; + void reduce_block(const data_t* space_base, data_t *dst, + int job, int start_y, int start_x, + int ny_start, int nx_start, int ny_step, int nx_step) const; + void reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; +}; + +/** simple 1d accumulator: y[:] += x[:] */ +template +struct cpu_accumulator_1d_t { + typedef typename prec_traits::type data_t; + + cpu_accumulator_1d_t(); + ~cpu_accumulator_1d_t(); + void accumulate(data_t *dst, const data_t *src, size_t size); + + reducer_2d_driver_t *drv_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp new file mode 100644 index 0000000000..82be70353d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp @@ -0,0 +1,262 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "cpu_engine.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" +#include "cpu_memory.hpp" +#include "type_helpers.hpp" + +#include "cpu/jit_uni_reorder.hpp" +#include "cpu/simple_reorder.hpp" +#include "cpu/wino_reorder.hpp" +#include "cpu/rnn/rnn_reorders.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f; + +namespace { +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::format_tag; + +#define REG_SR(idt, ifmt, odt, ofmt, ...) \ + simple_reorder_t::pd_t::create + +#define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \ + REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep), \ + REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse) + +#define REG_SR_DIRECT_COPY(idt, odt) \ + REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy), \ + REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0) + +static const rpd_create_f cpu_reorder_impl_list[] = { + /* winograd */ + wino_reorder_t::pd_t::create, + //wino_reorder_t::pd_t::create, + + /* rnn reorders */ + rnn_data_reorder_t::pd_t::create, + rnn_weights_reorder_t::pd_t::create, + rnn_weights_reorder_t::pd_t::create, + + /* conv reorders w/ compensation */ + REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), + + /* regular reorders */ + +#if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__)) + /* Direct copy for icc which is faster than jitted code; + * Direct copy for gcc which might or might not be faster than jitted + * code, but still worth it because doesn't require jitting, i.e. much + * faster creation time. This is tentative solution and should be removed + * later (when we will cache jitted code?...). */ + REG_SR_DIRECT_COPY(f32, f32), +#endif + +#ifdef __INTEL_COMPILER + /* direct copy for icc, which is faster than jitted code */ + /* + REG_SR_DIRECT_COPY(f32, s32), + REG_SR_DIRECT_COPY(f32, s8), + REG_SR_DIRECT_COPY(f32, u8), + REG_SR_DIRECT_COPY(s32, f32), + REG_SR_DIRECT_COPY(s32, s32), + REG_SR_DIRECT_COPY(s32, s8), + REG_SR_DIRECT_COPY(s32, u8), + REG_SR_DIRECT_COPY(s8, f32), + REG_SR_DIRECT_COPY(s8, s32), + REG_SR_DIRECT_COPY(s8, s8), + REG_SR_DIRECT_COPY(s8, u8), + REG_SR_DIRECT_COPY(u8, f32), + REG_SR_DIRECT_COPY(u8, s32), + REG_SR_DIRECT_COPY(u8, s8), + REG_SR_DIRECT_COPY(u8, u8), + */ +#endif + + /* jit */ + jit_uni_reorder_create, + + /* fp32: flat <-> blocked with tail */ + /* + REG_SR_BIDIR(f32, any, f32, nCw4c), + REG_SR_BIDIR(f32, any, f32, nCw8c), + REG_SR_BIDIR(f32, any, f32, OIw4i4o), + REG_SR_BIDIR(f32, any, f32, OIw8i8o), + REG_SR_BIDIR(f32, any, f32, OIw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIw4i4o), + REG_SR_BIDIR(f32, any, f32, gOIw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIw8o8i), + + REG_SR_BIDIR(f32, any, f32, nCw16c), + REG_SR_BIDIR(f32, any, f32, OIw16o16i), + REG_SR_BIDIR(f32, any, f32, OIw16i16o), + REG_SR_BIDIR(f32, any, f32, IOw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIw16i16o), + REG_SR_BIDIR(f32, any, f32, gIOw16o16i), + + REG_SR_BIDIR(f32, any, f32, nChw4c), + REG_SR_BIDIR(f32, any, f32, nChw8c), + REG_SR_BIDIR(f32, any, f32, OIhw4i4o), + REG_SR_BIDIR(f32, any, f32, Ohwi8o), + + REG_SR_BIDIR(f32, any, f32, OIhw8i8o), + REG_SR_BIDIR(f32, any, f32, OIhw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIhw4i4o), + REG_SR_BIDIR(f32, any, f32, gOIhw4o4i), + REG_SR_BIDIR(f32, any, f32, gOhwi8o), + REG_SR_BIDIR(f32, any, f32, gOIhw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIhw8o8i), + + REG_SR_BIDIR(f32, any, f32, nChw16c), + REG_SR_BIDIR(f32, any, f32, Oihw4o), + REG_SR_BIDIR(f32, any, f32, Oihw16o), + REG_SR_BIDIR(f32, any, f32, Ohwi4o), + REG_SR_BIDIR(f32, any, f32, Ohwi16o), + REG_SR_BIDIR(f32, any, f32, OIhw16o16i), + REG_SR_BIDIR(f32, any, f32, OIhw16i16o), + REG_SR_BIDIR(f32, any, f32, IOhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOihw4o), + REG_SR_BIDIR(f32, any, f32, gOihw16o), + REG_SR_BIDIR(f32, any, f32, gOhwi4o), + REG_SR_BIDIR(f32, any, f32, gOhwi16o), + REG_SR_BIDIR(f32, any, f32, gOIhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIhw16i16o), + REG_SR_BIDIR(f32, any, f32, gIOhw16o16i), + + REG_SR_BIDIR(f32, any, f32, nCdhw4c), + REG_SR_BIDIR(f32, any, f32, nCdhw8c), + REG_SR_BIDIR(f32, any, f32, OIdhw4i4o), + REG_SR_BIDIR(f32, any, f32, Odhwi8o), + REG_SR_BIDIR(f32, any, f32, OIdhw8i8o), + REG_SR_BIDIR(f32, any, f32, OIdhw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o), + REG_SR_BIDIR(f32, any, f32, gOdhwi8o), + REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i), + + REG_SR_BIDIR(f32, any, f32, nCdhw16c), + REG_SR_BIDIR(f32, any, f32, Oidhw4o), + REG_SR_BIDIR(f32, any, f32, Oidhw16o), + REG_SR_BIDIR(f32, any, f32, Odhwi16o), + REG_SR_BIDIR(f32, any, f32, OIdhw16o16i), + REG_SR_BIDIR(f32, any, f32, OIdhw16i16o), + REG_SR_BIDIR(f32, any, f32, gOidhw4o), + REG_SR_BIDIR(f32, any, f32, gOidhw16o), + REG_SR_BIDIR(f32, any, f32, gOdhwi16o), + REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o), + */ + + /* fp32: blocked <-> blocked with tail */ + REG_SR_BIDIR(f32, nCw8c, f32, nCw16c), + REG_SR_BIDIR(f32, nChw8c, f32, nChw16c), + REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c), + + /* int: flat <-> blocked with tail */ + /* + REG_SR_BIDIR(f32, any, s32, nChw16c), + REG_SR_BIDIR(f32, any, s8, nChw16c), + REG_SR_BIDIR(f32, any, u8, nChw16c), + REG_SR_BIDIR(s32, any, f32, nChw16c), + REG_SR_BIDIR(s32, any, s32, nChw16c), + REG_SR_BIDIR(s32, any, s8, nChw16c), + REG_SR_BIDIR(s32, any, u8, nChw16c), + REG_SR_BIDIR(s8, any, f32, nChw16c), + REG_SR_BIDIR(s8, any, s32, nChw16c), + REG_SR_BIDIR(s8, any, s8, nChw16c), + REG_SR_BIDIR(s8, any, u8, nChw16c), + REG_SR_BIDIR(u8, any, f32, nChw16c), + REG_SR_BIDIR(u8, any, s32, nChw16c), + REG_SR_BIDIR(u8, any, s8, nChw16c), + REG_SR_BIDIR(u8, any, u8, nChw16c), + + REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i), + REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i), + REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i), + REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i), + REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i), + REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i), + REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i), + REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i), + */ + + /* reference: the last line of defence */ + /* + REG_SR(f32, any, f32, any, fmt_order::any, spec::reference), + REG_SR(f32, any, s32, any, fmt_order::any, spec::reference), + REG_SR(f32, any, s8, any, fmt_order::any, spec::reference), + REG_SR(f32, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(s32, any, f32, any, fmt_order::any, spec::reference), + REG_SR(s32, any, s32, any, fmt_order::any, spec::reference), + REG_SR(s32, any, s8, any, fmt_order::any, spec::reference), + REG_SR(s32, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(s8, any, f32, any, fmt_order::any, spec::reference), + REG_SR(s8, any, s32, any, fmt_order::any, spec::reference), + REG_SR(s8, any, s8, any, fmt_order::any, spec::reference), + REG_SR(s8, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(u8, any, f32, any, fmt_order::any, spec::reference), + REG_SR(u8, any, s32, any, fmt_order::any, spec::reference), + REG_SR(u8, any, u8, any, fmt_order::any, spec::reference), + REG_SR(u8, any, s8, any, fmt_order::any, spec::reference), + */ + + /* eol */ + nullptr, +}; +} + +const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const { + return cpu_reorder_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp new file mode 100644 index 0000000000..1622eb6849 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp @@ -0,0 +1,48 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REORDER_PD_HPP +#define CPU_REORDER_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "reorder_pd.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_reorder_pd_t: public reorder_pd_t { + using reorder_pd_t::reorder_pd_t; + + status_t init() { + const auto &post_ops = attr()->post_ops_; + bool args_ok = IMPLICATION(post_ops.len_ != 0, post_ops.len_ == 1 + && post_ops.entry_[0].kind == primitive_kind::sum); + scratchpad_engine_ = src_engine_; + return args_ok ? status::success : status::unimplemented; + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp new file mode 100644 index 0000000000..f16587b99f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp @@ -0,0 +1,41 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SHUFFLE_PD_HPP +#define CPU_SHUFFLE_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "shuffle_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_shuffle_pd_t: public shuffle_pd_t { + using shuffle_pd_t::shuffle_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp new file mode 100644 index 0000000000..3a39eab974 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp @@ -0,0 +1,45 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SOFTMAX_PD_HPP +#define CPU_SOFTMAX_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "softmax_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_softmax_fwd_pd_t: public softmax_fwd_pd_t { + using softmax_fwd_pd_t::softmax_fwd_pd_t; +}; + +struct cpu_softmax_bwd_pd_t: public softmax_bwd_pd_t { + using softmax_bwd_pd_t::softmax_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp new file mode 100644 index 0000000000..1ab5d9f174 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp @@ -0,0 +1,48 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu_engine.hpp" + +/* +#include "cpu/ref_sum.hpp" +#include "cpu/simple_sum.hpp" +*/ + +namespace mkldnn { +namespace impl { +namespace cpu { + +using spd_create_f = mkldnn::impl::engine_t::sum_primitive_desc_create_f; + +namespace { +#define INSTANCE(...) __VA_ARGS__::pd_t::create +static const spd_create_f cpu_sum_impl_list[] = { + /* + INSTANCE(simple_sum_t), + INSTANCE(ref_sum_t), + */ + nullptr, +}; +#undef INSTANCE +} + +const spd_create_f *cpu_engine_t::get_sum_implementation_list() const { + return cpu_sum_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp new file mode 100644 index 0000000000..0965129f9b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp @@ -0,0 +1,39 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SUM_PD_HPP +#define CPU_SUM_PD_HPP + +#include "c_types_map.hpp" +#include "sum_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_sum_pd_t: public sum_pd_t { + using sum_pd_t::sum_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp new file mode 100644 index 0000000000..a9810dec28 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp @@ -0,0 +1,372 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "gemm_utils_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace gemm_utils { +#define BM_NOCOPY_AVX 64 +#define BN_NOCOPY_AVX 48 +#define BK_NOCOPY_AVX 384 +#define BN_LARGE_NOCOPY_AVX 192 +#define BM_SMALL_NOCOPY_AVX 16 +#define BN_SMALL_NOCOPY_AVX 1 +#define BK_SMALL_NOCOPY_AVX 4 +// Determine number of threads for each dimension of a 3-D partitioning +// algorithm based on input parameters +// m/n/k - First/second/third parameter for GEMM +// nthrs - total available number of threads +// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension +// BM/BN/BK - blocking values +void calc_nthr_nocopy_avx(int m, int n, int k, + int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN, + int *BK) +{ + int nthr, nthr_m, nthr_n, nthr_k; + int MB, NB, KB; + + nthr = nthrs; + nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX; + nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX; + nthr_k = 1; + + // Partition along K dimension + // - if threading allows having barriers (e.g. OMP) + // - if there is not enough parallelism along M or N + if (mkldnn_thr_syncable()) { + int nthr_other = nthr_k = 1; + while ((nthr_m * nthr_n * nthr_other < nthr) + && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) { + nthr_other++; + if ((nthr / nthr_other) * nthr_other > 0.9 * nthr) + nthr_k = nthr_other; + } + } + nthr /= nthr_k; + + if (nthr_m == 1) + nthr_n = nthr; + if (nthr_n == 1) + nthr_m = nthr; + + // Simple partition reduction + while (nthr_m * nthr_n > nthr) + if (nthr_m > nthr_n) + nthr_m--; + else + nthr_n--; + while (nthr_m * nthr_n < nthr) + if (nthr_m < nthr_n) + nthr_m++; + else + nthr_n++; + + if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) { + + if (nthr_m <= nthr_n) { + nthr_m = (int)sqrt((double)nthr); + if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX) + nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX; + nthr_n = nthr / nthr_m; + + while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { + nthr_m--; + nthr_n = nthr / nthr_m; + } + } else { + nthr_n = (int)sqrt((double)nthr); + if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX) + nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX; + nthr_m = nthr / nthr_n; + + while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { + nthr_n--; + nthr_m = nthr / nthr_n; + } + } + } + + MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1; + MB -= MB % BM_SMALL_NOCOPY_AVX; + NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1; + NB -= NB % BN_SMALL_NOCOPY_AVX; + KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1; + KB -= KB % BK_SMALL_NOCOPY_AVX; + + if (MB * nthr_m > m) + nthr_m = (m + MB - 1) / MB; + if (NB * nthr_n > n) + nthr_n = (n + NB - 1) / NB; + if (KB * nthr_k > k) + nthr_k = (k + KB - 1) / KB; + + *nthrs_m = nthr_m; + *nthrs_n = nthr_n; + *nthrs_k = nthr_k; + + *BM = MB; + *BN = NB; + *BK = KB; +} +#undef BM_NOCOPY_AVX +#undef BN_NOCOPY_AVX +#undef BK_NOCOPY_AVX +#undef BN_LARGE_NOCOPY_AVX +#undef BM_SMALL_NOCOPY_AVX +#undef BN_SMALL_NOCOPY_AVX +#undef BK_SMALL_NOCOPY_AVX + +#define BM_NOCOPY_AVX512_COMMON 32 +#define BN_NOCOPY_AVX512_COMMON 64 +#define BK_NOCOPY_AVX512_COMMON 192 +#define BN_LARGE_NOCOPY_AVX512_COMMON 192 +#define BM_SMALL_NOCOPY_AVX512_COMMON 16 +#define BN_SMALL_NOCOPY_AVX512_COMMON 1 +#define BK_SMALL_NOCOPY_AVX512_COMMON 4 +// Determine number of threads for each dimension of a 3-D partitioning +// algorithm based on input parameters +// m/n/k - First/second/third parameter for GEMM +// nthrs - total available number of threads +// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension +// BM/BN/BK - blocking values +void calc_nthr_nocopy_avx512_common(int m, + int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, + int *BM, int *BN, int *BK) +{ + int nthr, nthr_m, nthr_n, nthr_k = 1; + int MB, NB, KB; + nthr = nthrs; + + int counter = 0; + float ratio_float = 1.; + int ratio = 1; + nthr = nthrs; + int nthr_m_gt_n; + + // Partition along K dimension + // - if threading allows having barriers (e.g. OMP) + // - if there is not enough parallelism along M or N + if (mkldnn_thr_syncable()) { + if (n <= 2 * BN_NOCOPY_AVX512_COMMON && + m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) { + nthr_k = k / BK_NOCOPY_AVX512_COMMON; + if (nthr_k > nthr / 4) + nthr_k = nthr / 4; + if (nthr_k < 1) + nthr_k = 1; + + while ((nthr_k > 1) && (nthr % nthr_k)) { + nthr_k--; + } + nthr /= nthr_k; + } else { + nthr_k = 1; + } + } + nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON; + nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON; + + if (nthr_m < 1) + nthr_m = 1; + if (nthr_n < 1) + nthr_n = 1; + + nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0; + ratio_float = (float)nthr_m / nthr_n; + + if (nthr_m_gt_n) + ratio = (int)ratio_float; + else + ratio = (int)(1. / ratio_float); + + // scale down nthr_m and nthr_n if they are too large + while (nthr_m * nthr_n > 4 * nthr) { + nthr_m /= 2; + nthr_n /= 2; + } + + if (nthr_m < 1) + nthr_m = 1; + if (nthr_n < 1) + nthr_n = 1; + + // Simple partition reduction + counter = 0; + while (nthr_m * nthr_n > nthr) { + if (nthr_m > nthr_n) { + if (counter < ratio) + nthr_m--; + else { + nthr_n--; + counter = -1; + } + } else { + if (counter < ratio) + nthr_n--; + else { + nthr_m--; + counter = -1; + } + } + counter++; + } + + // Simple partition increment + counter = 0; + while (nthr_m * nthr_n < 0.95 * nthr) { + if (nthr_m > nthr_n) { + if (counter < ratio) + nthr_m++; + else { + nthr_n++; + counter = -1; + } + } else { + if (counter < ratio) + nthr_n++; + else { + nthr_m++; + counter = -1; + } + } + counter++; + } + + // if nothing works out, then this should work + if ((nthr_m * nthr_n > nthr)) { + + if (nthr_m <= nthr_n) { + nthr_m = (int)sqrt((double)nthr); + if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) + / BM_SMALL_NOCOPY_AVX512_COMMON) + nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) + / BM_SMALL_NOCOPY_AVX512_COMMON; + nthr_n = nthr / nthr_m; + + while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { + nthr_m--; + nthr_n = nthr / nthr_m; + } + } else { + nthr_n = (int)sqrt((double)nthr); + if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) + / BN_SMALL_NOCOPY_AVX512_COMMON) + nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) + / BN_SMALL_NOCOPY_AVX512_COMMON; + nthr_m = nthr / nthr_n; + + while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { + nthr_n--; + nthr_m = nthr / nthr_n; + } + } + } + + MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1; + MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON; + NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1; + NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON; + KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1; + KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON; + + if (MB * nthr_m > m) + nthr_m = (m + MB - 1) / MB; + if (NB * nthr_n > n) + nthr_n = (n + NB - 1) / NB; + if (KB * nthr_k > k) + nthr_k = (k + KB - 1) / KB; + + *nthrs_m = nthr_m; + *nthrs_n = nthr_n; + *nthrs_k = nthr_k; + + *BM = MB; + *BN = NB; + *BK = KB; +} +#undef BM_NOCOPY_AVX512_COMMON +#undef BN_NOCOPY_AVX512_COMMON +#undef BK_NOCOPY_AVX512_COMMON +#undef BN_LARGE_NOCOPY_AVX512_COMMON +#undef BM_SMALL_NOCOPY_AVX512_COMMON +#undef BN_SMALL_NOCOPY_AVX512_COMMON +#undef BK_SMALL_NOCOPY_AVX512_COMMON + +// Partition n values as equally as possible among nthr threads +// and set the offset (t_offset) and number of values (t_block) for ithr +// Assumption: 0 <= ithr < nthr +void partition_unit_diff( + int ithr, int nthr, int n, int *t_offset, int *t_block) +{ + int band = n / nthr; + if (band == 0) + band = 1; + int tail = n - band * nthr; + if (tail < 0) + tail = 0; + + if (ithr < tail) { + band++; + *t_offset = band * ithr; + *t_block = band; + } else { + *t_offset = band * ithr + tail; + *t_block = band; + } + + if (*t_offset >= n) { + *t_offset = 0; + *t_block = 0; + } + + if (*t_offset + *t_block > n) { + *t_block = n - *t_offset; + } +} + +// Sum the m*n values from p_src into p_dst, assuming the two-dimensional +// arrays have leading dimensions ld_src and ld_dst, respectively +template +void sum_two_matrices(int m, int n, + data_t * __restrict p_src, dim_t ld_src, + data_t * __restrict p_dst, dim_t ld_dst) +{ + int i, j; + for (j = 0; j < n; j++) { + for (i = 0; i < m; i++) { + p_dst[i + j * ld_dst] += p_src[i + j * ld_src]; + } + } +} + +template +void sum_two_matrices(int m, int n, + float * __restrict p_src, dim_t ld_src, + float * __restrict p_dst, dim_t ld_dst); + +template +void sum_two_matrices(int m, int n, + double * __restrict p_src, dim_t ld_src, + double * __restrict p_dst, dim_t ld_dst); +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp new file mode 100644 index 0000000000..3352298b4a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp @@ -0,0 +1,72 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_UTILS_HPP +#define GEMM_UTILS_HPP + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace gemm_utils { +// Alias for any dimension related variable. +typedef ptrdiff_t dim_t; + +template +struct gemm_traits {}; + +template +struct gemm_traits { + static constexpr int m = 8; + static constexpr int n = 6; + static constexpr int BM = 4032; + static constexpr int BN = isTransA ? 96 : 192; + static constexpr int BK = isTransB ? 96 : 512; +}; + +template +struct gemm_traits { + static constexpr int m = 16; + static constexpr int n = 6; + static constexpr int BM = 4032; + static constexpr int BN = isTransA ? 96 : 48; + static constexpr int BK = isTransB ? 96 : 256; +}; + +template +using unroll_factor = gemm_traits; + +template +void sum_two_matrices(int m, int n, + data_t * __restrict p_src, dim_t ld_src, + data_t * __restrict p_dst, dim_t ld_dst); + +void calc_nthr_nocopy_avx512_common(int m, + int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, + int *BM, int *BN, int *BK); + +void calc_nthr_nocopy_avx(int m, int n, int k, + int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN, + int *BK); + +void partition_unit_diff( + int ithr, int nthr, int n, int *t_offset, int *t_block); +}; + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp new file mode 100644 index 0000000000..d7be43e392 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp @@ -0,0 +1,2131 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "ref_gemm_f32.hpp" +#include "gemm_utils_f32.hpp" +#include "jit_avx512_common_gemm_f32.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define CACHE_LINE_SIZE 64 + +#define STACKSIZE get_size_of_abi_save_regs() +#ifdef _WIN32 +#define STACK_K_CAPACITY 32 +#else +#define STACK_K_CAPACITY 2048 +#endif +#define SIZE 4 +#define OFFSET 128 +#define BASE_SHIFT 2 +#define SECOND_FETCH unroll_n +#define UNROLL_M 48 +#define UNROLL_N 8 + +namespace avx512_common_gemm_f32 { +using namespace gemm_utils; + +struct xbyak_gemm : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm) + + xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false, + void *code_ptr = nullptr, + size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + { + using namespace Xbyak; + + enum { ver_avx512_core, ver_avx512_mic } ver = + mayiuse(avx512_core) ? ver_avx512_core : ver_avx512_mic; + + bool isBeta0 = (beta == 0.0); + bool isBetaN = (!isBeta0 && beta != 1.0); + + // various definitions for convenience + auto ARG_M = abi_param1; + auto ARG_N = abi_param2; + auto K = abi_param3; + auto ARG_ALPHA = abi_param4; +#ifdef _WIN32 + auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE]; + auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE]; + const auto stackOffset = OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE; + auto A = rsi; + auto LDA = rdi; +#else + auto ARG_A = r8; + auto ARG_LDA = r9; + const auto stackOffset = STACKSIZE; + auto A = ARG_A; + auto LDA = ARG_LDA; +#endif + auto ARG_B = ptr[rsp + 8 + stackOffset]; + auto ARG_LDB = ptr[rsp + 16 + stackOffset]; + auto ARG_BETA = ptr[rsp + 24 + stackOffset]; + auto ARG_C = ptr[rsp + 32 + stackOffset]; + auto ARG_LDC = ptr[rsp + 40 + stackOffset]; + auto ARG_BIAS = ptr[rsp + 48 + stackOffset]; + auto ARG_WS = ptr[rsp + 56 + stackOffset]; + + auto B = r11; + auto LDB = rbx; + auto LDC = r13; + auto LL = rax; + auto AO1 = abi_param2; + auto BO1 = abi_param4; + auto BO2 = rbp; + auto CO1 = r14; + auto CO2 = r15; + auto LDB3 = r10; + auto LDA4 = abi_param1; + auto AA = r12; + auto BIAS1 = abi_param1; + + auto M = qword[rsp + 0]; + auto N = qword[rsp + 8]; + auto FLAG = qword[rsp + 16]; + auto I = qword[rsp + 24]; + auto C = qword[rsp + 32]; + auto BIAS = qword[rsp + 40]; + auto ALPHA = qword[rsp + 48]; + auto BETA = qword[rsp + 64]; + auto ORIG_A = qword[rsp + 80]; + auto ORIG_SP = qword[rsp + 120]; + + auto ZSTRIDE = zmm4; + auto VALPHA = zmm6; + auto VBETA = zmm7; + auto VBIAS1 = zmm1; + auto VBIAS2 = zmm2; + auto VBIAS3 = zmm3; + + auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80; + auto PREFETCHSIZEB = 16; + + Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15, + zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24, + zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 }; + + // Function for packing if needed + auto do_pack = [&](int unroll_m) { + Label pack2, pack3, pack4, pack10; + + mov(BO1, A); + lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); + mov(LL, K); + sar(LL, 2); + jle(pack3, T_NEAR); + align(16); + + L(pack2); + if (!isTransA) { + for (int i = 0; i < 4; i++) { + vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); + if (unroll_m > 16) + vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]); + if (unroll_m > 32) + vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]); + add(BO1, LDA); + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE] + | k1, + zmm0); + if (unroll_m > 16) + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE] + | k2, + zmm1); + if (unroll_m > 32) + vmovups(ptr[AO1 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE] + | k3, + zmm2); + } + } else { + for (int i = 0; i < 4; i++) { + kmovw(k4, k1); + vgatherqps(ymm5 | k4, + ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO1 + LDA * 8]); + kshiftrw(k4, k1, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + vshuff64x2(zmm0, zmm5, zmm6, 0x44); + + if (unroll_m > 16) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k2); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k2, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + vshuff64x2(zmm1, zmm5, zmm6, 0x44); + } + + if (unroll_m > 32) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k3); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k3, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + vshuff64x2(zmm2, zmm5, zmm6, 0x44); + } + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + if (unroll_m > 16) + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + if (unroll_m > 32) + vmovups(ptr[AO1 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + add(BO1, 4 * SIZE); + } + add(AO1, unroll_m * 4 * SIZE); + + sub(LL, 1); + jg(pack2, T_NEAR); + align(16); + + L(pack3); + mov(LL, K); + and_(LL, 3); + jle(pack10, T_NEAR); + align(16); + + L(pack4); + if (!isTransA) { + vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); + if (unroll_m > 16) + vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]); + if (unroll_m > 32) + vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]); + add(BO1, LDA); + } else { + kmovw(k4, k1); + vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO1 + LDA * 8]); + kshiftrw(k4, k1, 8); + vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + vshuff64x2(zmm0, zmm5, zmm6, 0x44); + + if (unroll_m > 16) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k2); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k2, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + vshuff64x2(zmm1, zmm5, zmm6, 0x44); + } + + if (unroll_m > 32) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k3); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k3, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + vshuff64x2(zmm2, zmm5, zmm6, 0x44); + } + add(BO1, SIZE); + } + + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + if (unroll_m > 16) + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE], + zmm1 | k2); + if (unroll_m > 32) + vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE], + zmm2 | k3); + + add(AO1, unroll_m * SIZE); + sub(LL, 1); + jg(pack4, T_NEAR); + align(16); + + L(pack10); + }; + + // Function to update C, covering masking and other considerations + auto update = [&](Zmm reg, bool useCO1, int offset, int mask, + bool useScale = false) { + vmulps(reg, reg, VALPHA); + if (!isBeta0) { + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(zmm0, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0, ptr[CO2 + offset * SIZE]); + break; + case 1: + if (useCO1) + vmovups(zmm0 | k1 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k1 | T_z, ptr[CO2 + offset * SIZE]); + break; + case 2: + if (useCO1) + vmovups(zmm0 | k2 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k2 | T_z, ptr[CO2 + offset * SIZE]); + break; + case 3: + if (useCO1) + vmovups(zmm0 | k3 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k3 | T_z, ptr[CO2 + offset * SIZE]); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(zmm0, ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0, ptr[CO2 + LDC + offset * SIZE]); + break; + case 1: + if (useCO1) + vmovups(zmm0 | k1 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k1 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + case 2: + if (useCO1) + vmovups(zmm0 | k2 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k2 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + case 3: + if (useCO1) + vmovups(zmm0 | k3 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k3 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + } + } + if (!isBetaN) { + vaddps(zmm0, reg, zmm0); + } else { + vfmadd132ps(zmm0, reg, VBETA); + } + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k1); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k2); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k3); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k3); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k1); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k2); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k3); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k3); + break; + } + } + } else { + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg); + else + vmovups(ptr[CO2 + offset * SIZE], reg); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k1); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k2); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k3); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k3); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k1); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k2); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k3); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k3); + break; + } + } + } + vpxorq(reg, reg, reg); + }; + + // Loop with unroll_n - 2 FMAs; called by innerkernel + auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) { + for (int i = 2; i < unroll_n; i++) { + if (ver == ver_avx512_core) { + if (!isTransB) { + switch (i) { + case 2: + vbroadcastss( + zmm3, + ptr[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 3: + vbroadcastss( + zmm3, + ptr[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + case 4: + vbroadcastss(zmm3, + ptr[BO2 + (iteration - OFFSET) * SIZE]); + break; + case 5: + vbroadcastss( + zmm3, + ptr[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + break; + case 6: + vbroadcastss( + zmm3, + ptr[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 7: + vbroadcastss( + zmm3, + ptr[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + } + } else { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } + vfmadd231ps(regs[i], zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm3, zmm2); + } else { + if (!isTransB) { + switch (i) { + case 2: + vfmadd231ps(regs[i], zmm0, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 3: + vfmadd231ps(regs[i], zmm0, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + case 4: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + break; + case 5: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + break; + case 6: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 7: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + } + } else { + vfmadd231ps( + regs[i], zmm0, zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + (i - OFFSET) * SIZE]); + } + } + } + }; + + // Innerkernel; called by kernel + auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect, + bool isCopy, bool doCPrefetch, bool isUnmasked = true) { + for (int i = 0; i < 8; i++) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET) + * SIZE]); + if (unroll_m >= 32) + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET) + * SIZE]); + if (unroll_m >= 48) + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]); + if (unroll_m >= 32) + prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]); + if (unroll_m >= 48) + prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]); + } + + if (!isDirect) { + if (i != 0) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * i + 1 * 16 + - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * i + 1 * 16 + - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * i + 2 * 16 + - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * i + 2 * 16 + - OFFSET) + * SIZE]); + } + } + } + } else { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (ver == ver_avx512_core) { + if (!isTransB) { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + vfmadd231ps(regs[0], zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm3, zmm2); + } else { + if (!isTransB) { + vfmadd231ps(regs[0], zmm0, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm1, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm2, + zword_b[BO1 + (i - OFFSET) * SIZE]); + } else { + vfmadd231ps(regs[0], zmm0, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm1, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm2, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + } + } + + if (unroll_n >= i + 1) { + if (!isTransB) { + switch (i) { + case 0: + prefetcht0( + ptr[BO1 + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 1: + prefetcht0(ptr[BO1 + LDB + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 2: + prefetcht0(ptr[BO1 + LDB * 2 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 3: + prefetcht0(ptr[BO1 + LDB3 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 4: + prefetcht0( + ptr[BO2 + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 5: + prefetcht0(ptr[BO2 + LDB + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 6: + prefetcht0(ptr[BO2 + LDB * 2 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 7: + prefetcht0(ptr[BO2 + LDB3 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + } + } + } + + if (unroll_n >= 2) { + if (ver == ver_avx512_core) { + if (!isTransB) { + vbroadcastss(zmm3, + ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + vfmadd231ps(regs[1], zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm3, zmm2); + } else { + if (!isTransB) { + vfmadd231ps(regs[1], zmm0, + zword_b[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm1, + zword_b[BO1 + LDB * 1 + + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm2, + zword_b[BO1 + LDB * 1 + + (i - OFFSET) * SIZE]); + } else { + vfmadd231ps(regs[1], zmm0, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm1, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm2, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + } + } + } + + if (isCopy) { + if (isUnmasked || unroll_m > 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE], + zmm0); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE], + zmm0 | k1); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(ptr[LDA4 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + } + if (i == 7) + sub(LDA4, -unroll_m * 8 * SIZE); + } + fmaloop(unroll_m, unroll_n, i); + + if (i == 1) { + if (doCPrefetch) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 0 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 0 * 16 * SIZE]); + } + } + if (i == 3) { + if (doCPrefetch && unroll_m >= 32) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 1 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 1 * 16 * SIZE]); + } + if (!isTransA) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 0 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 0 * SIZE]); + } + } + if (i == 5) { + if (doCPrefetch) { + if (unroll_m >= 48) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 2 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 2 * 16 * SIZE]); + } + add(CO2, LDC); + } + if (!isTransA) { + if (unroll_m >= 32) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 1 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 1 * SIZE]); + } + } + } + + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + } // end of for loop + + if (!isTransB) { + sub(BO1, -8 * SIZE); + if (unroll_n >= 4) + sub(BO2, -8 * SIZE); + } + if (!isTransA) { + if (unroll_m >= 48) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 2 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 2 * SIZE]); + } + lea(AA, ptr[AA + LDA]); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 8 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 8 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 8 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 8 + 2 * 16 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * 8 * SIZE); + } + + sub(LL, 1); + }; + + // Main kernel; does prefetching and calls innerkernel + // After calculating results in registers, writes back to C matrix by + // calling update + auto kernel = [&](int unroll_m, int unroll_n, bool isDirect, + bool isCopy, bool isUnmasked = true) { + if (!isDirect) { + lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); + } else { + mov(AO1, A); + } + + if (isCopy) { + lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]); + } else { + auto step = ver == ver_avx512_core ? 2 : 4; + lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]); + } + + if (isTransB) { + lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE]); + } + } + } + + Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18; + + mov(LL, K); + sar(LL, 3); + sub(LL, SECOND_FETCH); + jle(kernel13, T_NEAR); + align(16); + + L(kernel12); + innerkernel( + unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked); + jg(kernel12, T_NEAR); + align(16); + + L(kernel13); + lea(CO2, ptr[CO1 + (16 - 1) * SIZE]); + add(LL, unroll_n); + jle(kernel15, T_NEAR); + align(16); + + L(kernel14); + innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked); + jg(kernel14, T_NEAR); + align(16); + + L(kernel15); + mov(LL, K); + and_(LL, 7); + jle(kernel18, T_NEAR); + align(16); + + L(kernel16); + if (isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + for (int i = 0; i < unroll_n; i++) { + if (!isTransB) { + switch (i) { + case 0: + vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); + break; + case 1: + vbroadcastss( + zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + break; + case 2: + vbroadcastss( + zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + break; + case 3: + vbroadcastss( + zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]); + break; + case 4: + vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]); + break; + case 5: + vbroadcastss( + zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + break; + case 6: + vbroadcastss( + zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + break; + case 7: + vbroadcastss( + zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]); + break; + } + } else { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } + vfmadd231ps(regs[i], zmm3, zmm0); + if (unroll_m >= 32) { + vfmadd231ps(regs[i + 8], zmm3, zmm1); + } + if (unroll_m >= 48) { + vfmadd231ps(regs[i + 16], zmm3, zmm2); + } + } + + if (isCopy) { + if (isUnmasked || unroll_m > 16) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0); + } else { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE], + zmm1); + } else { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE], + zmm2); + } else { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 1 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 1 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 1 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 1 + 2 * 16 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + + sub(LL, 1); + jg(kernel16, T_NEAR); + align(16); + + L(kernel18); + vbroadcastss(VALPHA, ALPHA); + + if (isBetaN) { + vbroadcastss(VBETA, BETA); + } + + // Write back the results; all beta cases need to be handled + if (hasBias) { + mov(BIAS1, BIAS); + if (isUnmasked || unroll_m > 16) + vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]); + else + vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]); + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) + vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]); + else + vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]); + } + if (unroll_m >= 48) { + if (isUnmasked) + vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]); + else + vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]); + } + } + + for (int i = 0; i < unroll_n; i++) { + bool useScale = i % 2 != 0; + bool useCO1 = i < 2; + if (i == 2) + lea(CO2, ptr[CO1 + LDC * 2]); + if (i == 4 || i == 6) + lea(CO2, ptr[CO2 + LDC * 2]); + if (hasBias) + vaddps(regs[i], VBIAS1, regs[i]); + if (isUnmasked || unroll_m > 16) { + update(regs[i], useCO1, 0, 0, useScale); + } else { + update(regs[i], useCO1, 0, 1, useScale); + } + if (unroll_m >= 32) { + if (hasBias) + vaddps(regs[i + 8], VBIAS2, regs[i + 8]); + if (isUnmasked || unroll_m > 32) { + update(regs[i + 8], useCO1, 16, 0, useScale); + } else { + update(regs[i + 8], useCO1, 16, 2, useScale); + } + } + if (unroll_m >= 48) { + if (hasBias) + vaddps(regs[i + 16], VBIAS3, regs[i + 16]); + if (isUnmasked) { + update(regs[i + 16], useCO1, 32, 0, useScale); + } else { + update(regs[i + 16], useCO1, 32, 3, useScale); + } + } + } + + switch (unroll_n) { + case 1: add(CO1, LDC); break; + case 2: lea(CO1, ptr[CO1 + LDC * 2]); break; + case 3: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 4: lea(CO1, ptr[CO2 + LDC * 2]); break; + case 5: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 6: lea(CO1, ptr[CO2 + LDC * 2]); break; + case 7: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 8: lea(CO1, ptr[CO2 + LDC * 2]); break; + } + + // Compute next address of B + if (!isTransB) { + lea(rax, ptr[K * SIZE]); + switch (unroll_n) { + case 1: + add(BO1, LDB); + add(BO2, LDB); + break; + case 2: + lea(BO1, ptr[BO1 + LDB * 2]); + lea(BO2, ptr[BO2 + LDB * 2]); + break; + case 3: + lea(BO1, ptr[BO1 + LDB3]); + lea(BO2, ptr[BO2 + LDB3]); + break; + case 4: + lea(BO1, ptr[BO1 + LDB * 4]); + lea(BO2, ptr[BO2 + LDB * 4]); + break; + case 5: + lea(BO1, ptr[BO1 + LDB * 4]); + add(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 4]); + add(BO2, LDB); + break; + case 6: + lea(BO1, ptr[BO1 + LDB3 * 2]); + lea(BO2, ptr[BO2 + LDB3 * 2]); + break; + case 7: + lea(BO1, ptr[BO1 + LDB * 8]); + sub(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 8]); + sub(BO2, LDB); + break; + case 8: + lea(BO1, ptr[BO1 + LDB * 8]); + lea(BO2, ptr[BO2 + LDB * 8]); + break; + } + sub(BO1, rax); + sub(BO2, rax); + } else { + mov(rax, LDB); + imul(rax, K); + sub(BO1, rax); + add(BO1, unroll_n * SIZE); + } + }; + + // High-level subroutine; does packing if needed, then splits C matrix. + // Operates on chunks of 48 rows, 8 columns at a time (handling tail + // cases appropriately by doing 32 or 16 rows, and/or with masking, + // and/or fewer columns). + auto subloop = [&](int unroll_m) { + Label l_subloop_20x[8], l_subloop_mask_20x[8]; + Label l_subloop_30x[8], l_subloop_mask_30x[8]; + + Label subloop11, subloop11mask; + Label subloop30, subloop30mask; + Label subloop31, subloop31mask; + Label subloop96; + Label subloop98, subloop98mask; + Label subloop99; + + // Create mask + mov(BO1, rcx); + mov(rcx, M); + sub(rcx, unroll_m - 16); + mov(CO1, 16); + cmp(rcx, 16); + + cmovg(rcx, CO1); + mov(rax, 1); + sal(rax, cl); + sub(rax, 1); + mov(rcx, 0xffff); + + if (unroll_m == 16) { + kmovw(k1, eax); + } else if (unroll_m == 32) { + kmovw(k1, ecx); + kmovw(k2, eax); + } else { + kmovw(k1, ecx); + kmovw(k2, ecx); + kmovw(k3, eax); + } + mov(rcx, BO1); + + and_(rax, 0xffff); + cmp(rax, 0xffff); + jne(subloop96, T_NEAR); + + if (isTransA) { + do_pack(unroll_m); + } + + mov(CO1, C); + add(C, unroll_m * SIZE); + + mov(BO1, B); + if (!isTransB) { + lea(BO2, ptr[B + LDB * 4]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98, T_NEAR); + + mov(AA, ORIG_A); + lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); + L(subloop98); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(l_subloop_20x[1], T_NEAR); + } + align(16); + + if (!isTransA) { + kernel(unroll_m, UNROLL_N, true, true); + } else { + kernel(unroll_m, UNROLL_N, false, false); + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(l_subloop_20x[1], T_NEAR); + align(16); + + L(subloop11); + kernel(unroll_m, UNROLL_N, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_20x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_20x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, false, false); + jmp(subloop99, T_NEAR); + align(16); + } + + if (!isTransA) { + L(subloop30); + cmp(I, UNROLL_N); + jl(l_subloop_30x[1], T_NEAR); + align(16); + + L(subloop31); + kernel(unroll_m, UNROLL_N, true, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_30x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_30x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, true, false); + if (i < 7) + jmp(subloop99, T_NEAR); + align(16); + } + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop96); + if (isTransA) { + do_pack(unroll_m); + } + + mov(CO1, C); + add(C, unroll_m * SIZE); + mov(BO1, B); + if (!isTransB) { + lea(BO2, ptr[B + LDB * 4]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98mask, T_NEAR); + mov(AA, ORIG_A); + lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); + L(subloop98mask); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30mask, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30mask, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(l_subloop_mask_20x[1], T_NEAR); + } + align(16); + + if (!isTransA) { + kernel(unroll_m, UNROLL_N, true, true, false); + } else { + kernel(unroll_m, UNROLL_N, false, false, false); + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(l_subloop_mask_20x[1], T_NEAR); + align(16); + + L(subloop11mask); + kernel(unroll_m, UNROLL_N, false, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11mask, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_mask_20x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_mask_20x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, false, false, false); + jmp(subloop99, T_NEAR); + align(16); + } + + if (!isTransA) { + L(subloop30mask); + cmp(I, UNROLL_N); + jl(l_subloop_mask_30x[1], T_NEAR); + align(16); + + L(subloop31mask); + kernel(unroll_m, UNROLL_N, true, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31mask, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_mask_30x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_mask_30x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, true, false, false); + if (i < 7) + jmp(subloop99, T_NEAR); + align(16); + } + } + + L(subloop99); + // Compute address for A + if (!isTransA) { + add(A, unroll_m * SIZE); + } else { + mov(rax, LDA); + imul(rax, rax, unroll_m); + add(A, rax); + } + + // Compute next address of BIAS + if (hasBias) { + add(BIAS, unroll_m * SIZE); + } + }; + + preamble(); + + Label buffer_in_ws, buffer_allocated; + + // Get the registers + mov(B, ARG_B); + mov(LDB, ARG_LDB); + mov(r15, ARG_BETA); + mov(r12, ARG_C); + if (hasBias) + mov(r10, ARG_BIAS); + mov(LDC, ARG_LDC); + mov(rbp, rsp); + + vmovss(xmm0, ptr[ARG_ALPHA]); + vmovss(xmm1, ptr[r15]); + +#if _WIN32 + mov(A, ARG_A); + mov(LDA, ARG_LDA); +#endif + + cmp(K, STACK_K_CAPACITY); + jg(buffer_in_ws, T_NEAR); + + // Create buffer and align to 4kB page + lea(rax, ptr[K * SIZE]); + imul(rax, rax, 0x30); + add(rax, 256); + sub(rsp, rax); + and_(rsp, -PAGE_4K); + jmp(buffer_allocated, T_NEAR); + + L(buffer_in_ws); + mov(rsp, ARG_WS); + + L(buffer_allocated); + + mov(ORIG_SP, rbp); + mov(M, ARG_M); + mov(N, ARG_N); + mov(C, r12); + if (hasBias) + mov(BIAS, r10); + vmovss(ALPHA, xmm0); + vmovss(BETA, xmm1); + sub(A, -OFFSET * SIZE); + sub(B, -OFFSET * SIZE); + mov(ORIG_A, A); + sal(LDA, BASE_SHIFT); + sal(LDB, BASE_SHIFT); + sal(LDC, BASE_SHIFT); + lea(LDB3, ptr[LDB + LDB * 2]); + + if (isTransA) { + vpbroadcastq(zmm2, LDA); + vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE); + mov(rax, -2); + kmovw(k4, eax); + + for (int i = 0; i < 6; i++) { + vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2); + kshiftlw(k4, k4, 1); + } + vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2); + } + + // Check A alignment and leading dimension; take copy-based path as + // needed + mov(rax, LDA); + or_(rax, A); + and_(rax, ver == ver_avx512_core ? 0x07 : 0x3f); + mov(FLAG, rax); + + for (int i = 8; i < 16; i++) { + for (int j = 0; j < 3; j++) { + vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j)); + } + } + + Label main0, main1, main2, main999; + + cmp(M, 32); + jle(main0, T_NEAR); + align(16); + + L(main1); + subloop(48); + sub(M, UNROLL_M); + cmp(M, 32); + jg(main1, T_NEAR); + align(16); + + L(main0); + cmp(M, 16); + jle(main2, T_NEAR); + + subloop(32); + jmp(main999, T_NEAR); + align(16); + + L(main2); + cmp(M, 0); + jle(main999, T_NEAR); + subloop(16); + align(16); + + L(main999); + // Restore original stack + mov(rsp, ORIG_SP); + + vzeroupper(); + postamble(); + + ker_ = this->getCode(); + } + + typedef void (*ker_t)(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws); + + void operator()(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws) const + { + ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws); + } + +private: + ker_t ker_; +}; + +const xbyak_gemm *get_xbyak_gemm( + bool isTransA, bool isTransB, float beta, bool hasBias) { + auto beta_idx = [](float beta) { + return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2); + }; + + // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)] + static xbyak_gemm *kernel_table[2][2][2][3]; + static std::once_flag initialized; + std::call_once(initialized, [=]{ + for (bool isTransA: {false, true}) + for (bool isTransB: {false, true}) + for (bool hasBias: {false, true}) + for (float beta: {0.0f, 1.0f, 2.0f}) { + // nocopy sgemm with bias for beta != 0.0 is not supported + if (hasBias && beta != 0.0) + continue; + kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] = + new xbyak_gemm(isTransA, isTransB, beta, hasBias); + } + }); + + return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)]; +} + +void sgemm_nocopy_driver(const char *transa, + const char *transb, int m, int n, int k, const float *alpha, + const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta, + float *c, dim_t ldc, const float *bias, float *ws) +{ + bool isTransA = (*transa == 'T' || *transa == 't'); + bool isTransB = (*transb == 'T' || *transb == 't'); + + int Bm, sizeM, Bn, sizeN, Bk, sizeK; + + int i, j; + + if ((m <= 0) || (n <= 0)) + return; + + if ((k <= 0) || (alpha[0] == 0.)) { + + if (beta[0] == 0.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] = 0.0; + } else if (beta[0] != 1.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] *= beta[0]; + } + + return; + } + + assert(IMPLICATION(bias != nullptr, *beta == 0.0)); + + // XXX: this happens on every thread... + bool hasBias = (bias != nullptr); + auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias); + auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false); + auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false); + assert(ker_bn && ker_b1 && ker_b0); + + int BM = 4032, BN, BK; + if (mayiuse(avx512_core)) { + BN = isTransA ? 384 : 64; + BK = 384; + } else { + BN = isTransA ? 96 : 64; + BK = isTransB ? 96 : 192; + if (!isTransA && !isTransB) + BK = 128; + } + const float *curA, *curB, *curBias = nullptr; + float *curC; + + for (Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK >= BK * 2) + sizeK = BK; + else { + if (sizeK > BK) + sizeK = (sizeK + 1) / 2; + } + + for (Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM >= BM * 2) + sizeM = BM; + else { + if (sizeM > BM + BM / 2) + sizeM = (sizeM + 1) / 2; + } + + for (Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN >= BN * 2) + sizeN = BN; + else { + if (sizeN > BN + BN / 2) + sizeN = (sizeN + 1) / 2; + } + + if (!isTransA) { + curA = a + Bm + Bk * lda; + } else { + curA = a + Bk + Bm * lda; + } + if (!isTransB) { + curB = b + Bk + Bn * ldb; + } else { + curB = b + Bn + Bk * ldb; + } + curC = c + Bm + (size_t)Bn * ldc; + if (bias != nullptr) { + if (Bk == 0) { + curBias = bias + Bm; + } else { + curBias = nullptr; + } + } + if (Bk == 0) { + if (*beta == 0.0 && bias == nullptr) + (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + else + (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } else { + (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } + } + } + } +} + +} + +mkldnn_status_t jit_avx512_common_gemm_f32( + const char *transa, const char *transb, + const int *p_m, const int *p_n, const int *p_k, const float *p_alpha, + const float *A, const int *p_lda, const float *B, const int *p_ldb, + const float *p_beta, float *C, const int *p_ldc, const float *bias) +{ + using namespace mkldnn::impl::utils; + using namespace avx512_common_gemm_f32; + using namespace gemm_utils; + + if (*p_beta != 0 && bias) + return ref_gemm(transa, transb, p_m, p_n, p_k, + p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias); + + int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + + int m = *p_m; + int n = *p_n; + int k = *p_k; + dim_t lda = *p_lda; + dim_t ldb = *p_ldb; + dim_t ldc = *p_ldc; + float beta = *p_beta; + int MB, NB, KB; + + int nthr_m, nthr_n, nthr_k, nthr_mn; + + // Determine threading partitioning + calc_nthr_nocopy_avx512_common( + m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + // May not happen, but just in case + if (nthr < nthr_m * nthr_n * nthr_k) + nthr = nthr_m * nthr_n * nthr_k; + + nthr_mn = nthr_m * nthr_n; + + unsigned char * ompstatus_ = nullptr; + unsigned char volatile *ompstatus = nullptr; + + float *c_buffers = nullptr; + float *ws_buffers = nullptr; + + if (nthr_k > 1) { + ompstatus_ = (unsigned char *) malloc( + nthr * CACHE_LINE_SIZE, + CACHE_LINE_SIZE); + ompstatus = (unsigned char volatile *) ompstatus_; + assert(ompstatus); + + for (int i = 0; i < nthr; i++) + ompstatus[i * CACHE_LINE_SIZE] = 0; + + c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(float), PAGE_4K); + } + + const size_t ws_elems_per_thr = (size_t)k * 48 + 64; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); + if (k > STACK_K_CAPACITY) { + ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K); + } + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int k_from, k_to, myK; + int cbase, ibase; + const float *myA, *myB, *myBias = nullptr; + float *myC = C, myBeta; + float *ws = ws_buffers ? + ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; + dim_t ld = ldc; + + int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k); + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + k_from = KB * (ithr_k); + k_to = KB * (ithr_k + 1); + if (k_to > k) + k_to = k; + myK = k_to - k_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + ibase = (ithr_m + nthr_m * ithr_n) * nthr_k; + + if ((myM > 0) && (myN > 0)) { + + if (*transa == 'N' || *transa == 'n') { + myA = &(A[m_from + k_from * lda]); + } else { + myA = &(A[k_from + m_from * lda]); + } + if (*transb == 'N' || *transb == 'n') { + myB = &(B[k_from + n_from * ldb]); + } else { + myB = &(B[n_from + k_from * ldb]); + } + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + if (bias) + myBias = &(bias[m_from]); + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0; + ld = MB; + myBias = nullptr; + } + + sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, + lda, myB, ldb, &myBeta, myC, ld, myBias, ws); + + if (nthr_k > 1 && !sum_later) + ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1; + } + + if (nthr_k > 1 && !sum_later) { + + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + /* need to wait until main thread finishes */ + while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) { + }; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) { + }; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + + + // handle C summation later + if (nthr_k > 1 && ompstatus[0] == 0) { + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int cbase; + float *myC = C; + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + if (nthr_k > 1) { + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + } + + free(c_buffers); + free(ompstatus_); + free(ws_buffers); + + return mkldnn_success; +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp new file mode 100644 index 0000000000..d581b7fd71 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_GEMM_F32_HPP +#define JIT_AVX512_COMMON_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx512_common_gemm_f32( + const char *transa, const char *transb, const int *M, + const int *N, const int *K, const float *alpha, const float *A, + const int *lda, const float *B, const int *ldb, const float *beta, + float *C, const int *ldc, const float *bias = nullptr); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp new file mode 100644 index 0000000000..60d4220837 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp @@ -0,0 +1,2705 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "ref_gemm_f32.hpp" +#include "gemm_utils_f32.hpp" +#include "jit_avx_gemm_f32.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define CACHE_LINE_SIZE 64 + +#define STACKSIZE get_size_of_abi_save_regs() +#if _WIN32 +#define STACK_K_CAPACITY 128 +#else +#define STACK_K_CAPACITY 8192 +#endif +#define SIZE 4 +#define OFFSET 32 +#define BASE_SHIFT 2 +#define SECOND_FETCH 14 + +namespace avx_gemm_f32 { +using namespace gemm_utils; + +struct xbyak_gemm : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm) + + xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false, + void *code_ptr = nullptr, + size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + { + using namespace Xbyak; + + const bool is_avx2 = mayiuse(avx2); + assert(IMPLICATION(!is_avx2, mayiuse(avx))); + + const int UNROLL_M = is_avx2 ? 16 : 8; + const int UNROLL_N = 6; + + bool isBeta0 = (beta == 0.0); + bool isBetaN = (!isBeta0 && beta != 1.0); + + // various definitions for convenience + auto ARG_M = abi_param1; + auto ARG_N = abi_param2; + auto K = abi_param3; + auto ARG_ALPHA = abi_param4; +#ifdef _WIN32 + auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE]; + auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE]; + const auto stackOffset = OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE; + auto A = rsi; + auto LDA = rdi; +#else + auto ARG_A = r8; + auto ARG_LDA = r9; + const auto stackOffset = STACKSIZE; + auto A = ARG_A; + auto LDA = ARG_LDA; +#endif + auto ARG_B = ptr[rsp + 8 + stackOffset]; + auto ARG_LDB = ptr[rsp + 16 + stackOffset]; + auto ARG_BETA = ptr[rsp + 24 + stackOffset]; + auto ARG_C = ptr[rsp + 32 + stackOffset]; + auto ARG_LDC = ptr[rsp + 40 + stackOffset]; + auto ARG_BIAS = ptr[rsp + 48 + stackOffset]; + auto ARG_WS = ptr[rsp + 56 + stackOffset]; + + auto B = r11; + auto LDB = rbx; + auto LDC = r13; + auto LL = rax; + auto AO1 = abi_param2; + auto BO1 = abi_param4; + auto BO2 = rbp; + auto CO1 = r14; + auto CO2 = r15; + auto LDB3 = r10; + auto LDA4 = abi_param1; + auto AA = r12; + auto BIAS1 = abi_param1; + + auto M = qword[rsp + 0]; + auto N = qword[rsp + 8]; + auto FLAG = qword[rsp + 16]; + auto I = qword[rsp + 24]; + auto C = qword[rsp + 32]; + auto BIAS = qword[rsp + 40]; + auto ALPHA = qword[rsp + 48]; + auto BETA = qword[rsp + 64]; + auto ORIG_A = qword[rsp + 80]; + auto MASK = dword[rsp + 88]; + auto STRIDE = qword[rsp + 120]; + auto ORIG_SP = qword[rsp + 152]; + + auto VALPHA = ymm1; + auto VBETA = ymm2; + auto VMASK = ymm3; + auto VBIAS1 = ymm2; + auto VBIAS2 = ymm4; + + auto PREFETCHSIZEA = 128; + auto PREFETCHSIZEB = (!isTransB) ? -16 : 0; + + // Function for packing if needed + auto do_pack = [&]( + int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { + Label pack2, pack3, pack4, pack10; + + int regIdx; + Reg64 reg; + + mov(BO1, A); + lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); + + if (isTransA) { + lea(BO2, ptr[BO1 + LDA * 4]); + lea(CO1, ptr[LDA + LDA * 2]); + vmovupd(ymm7, STRIDE); + } + + mov(LL, K); + sar(LL, 2); + jle(pack3, T_NEAR); + align(16); + + L(pack2); + if (!isTransA) { + for (int i = 0; i < 4; i++) { + regIdx = (i % 2 == 0) ? 4 : 6; + if (isLoad1Unmasked) { + vmovups(Ymm(regIdx), + ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(Ymm(regIdx), VMASK, + ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m > 8) { + if (isLoad2Unmasked) { + vmovups(Ymm(regIdx + 1), + ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(Ymm(regIdx + 1), VMASK, + ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(BO1, LDA); + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + Ymm(regIdx)); + if (unroll_m > 8) { + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + Ymm(regIdx + 1)); + } + } + + } else { + if (isLoad1Unmasked) { + for (int i = 0; i < 2; i++) { + reg = (i % 2 == 0) ? BO1 : BO2; + vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, + ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[reg + LDA * 2]); + vunpcklps(xmm4, xmm0, xmm1); + vunpckhps(xmm5, xmm0, xmm1); + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, + ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm6, xmm0, xmm1); + vunpckhps(xmm2, xmm0, xmm1); + + vunpcklpd(xmm0, xmm4, xmm6); + vunpckhpd(xmm1, xmm4, xmm6); + vmovups(ptr[AO1 + + (unroll_m * 0 + i * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 1 + i * 4 - OFFSET) + * SIZE], + xmm1); + vunpcklpd(xmm0, xmm5, xmm2); + vunpckhpd(xmm1, xmm5, xmm2); + vmovups(ptr[AO1 + + (unroll_m * 2 + i * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 3 + i * 4 - OFFSET) + * SIZE], + xmm1); + } + } else if (is_avx2) { + for (int i = 0; i < 2; i++) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm0, + ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 0 * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 0 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO1 + LDA * 4]); + + for (int i = 0; i < 2; i++) { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 1 * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 1 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + } else { + vxorps(xmm4, xmm4, xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + + auto el_cp = [&](int section, int ld_step) { + RegExp src_addr = section == 0 ? BO1 : BO2; + if (ld_step == 1 || ld_step == 2) + src_addr = src_addr + LDA * ld_step; + else if (ld_step == 3) + src_addr = src_addr + CO1; + src_addr = src_addr - OFFSET * SIZE; + + vmovups(Xmm(ld_step % 2), ptr[src_addr]); + RegExp dst_addr = AO1 + + (ld_step + section * 4 - OFFSET) * SIZE; + for (int off = 0; off < 4; ++off) + pextrd(ptr[dst_addr + unroll_m * off * SIZE], + Xmm(ld_step % 2), off); + }; + + Label l_end; + el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR); + el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR); + el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR); + el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR); + el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR); + el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR); + el_cp(1, 2); + L(l_end); + + lea(BO2, ptr[BO2 + LDA * 4]); + } + + if (unroll_m >= 16) { + assert(is_avx2); + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm4, xmm0, xmm1); + vunpckhps(xmm5, xmm0, xmm1); + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + if (i == 0) + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm6, xmm0, xmm1); + vunpckhps(xmm2, xmm0, xmm1); + + vunpcklpd(xmm0, xmm4, xmm6); + vunpckhpd(xmm1, xmm4, xmm6); + vmovups(ptr[AO1 + + (unroll_m * 0 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 1 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm1); + vunpcklpd(xmm0, xmm5, xmm2); + vunpckhpd(xmm1, xmm5, xmm2); + vmovups(ptr[AO1 + + (unroll_m * 2 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 3 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm1); + } + } else { + for (int i = 0; i < 2; i++) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 2 * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 2 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + + for (int i = 0; i < 2; i++) { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 3 * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 3 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + } + } + add(BO1, (4 * SIZE)); + } + + add(AO1, unroll_m * 4 * SIZE); + sub(LL, 1); + jg(pack2, T_NEAR); + align(16); + + L(pack3); + mov(LL, K); + and_(LL, 3); + jle(pack10, T_NEAR); + align(16); + + L(pack4); + if (!isTransA) { + if (isLoad1Unmasked) { + vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m > 8) { + if (isLoad2Unmasked) { + vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm5, VMASK, + ptr[BO1 + (1 + 8 - OFFSET) * SIZE]); + } + } + add(BO1, LDA); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm4); + if (unroll_m > 8) { + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE], + ymm5); + } + } else { + if (isLoad1Unmasked) { + for (int i = 0; i < 2; i++) { + reg = (i % 2 == 0) ? BO1 : BO2; + vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, + ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[reg + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE], + xmm1); + + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, + ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE], + xmm1); + } else if (is_avx2) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE], + xmm1); + + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO2 + LDA * 4]); + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE], + xmm1); + } else { + vxorps(xmm4, xmm4, xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + + auto el_cp = [&](int section, int ld_step) { + RegExp src_addr = section == 0 ? BO1 : BO2; + if (ld_step == 1 || ld_step == 2) + src_addr = src_addr + LDA * ld_step; + else if (ld_step == 3) + src_addr = src_addr + CO1; + src_addr = src_addr - OFFSET * SIZE; + + vmovss(xmm1, ptr[src_addr]); + RegExp dst_addr = AO1 + + (ld_step + section * 4 - OFFSET) * SIZE; + movss(ptr[dst_addr], xmm1); + }; + + Label l_end; + el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR); + el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR); + el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR); + el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR); + el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR); + el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR); + el_cp(1, 2); + L(l_end); + + lea(BO2, ptr[BO2 + LDA * 4]); + } + + if (unroll_m >= 16) { + assert(is_avx2); + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), + ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + } else { + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO2 + LDA * 4]); + } + vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE], + xmm1); + + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), + ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + } else { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + } + vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE], + xmm1); + } + add(BO1, SIZE); + } + + add(AO1, unroll_m * SIZE); + sub(LL, 1); + jg(pack4, T_NEAR); + align(16); + + L(pack10); + }; + + // Fused multiply add; may become one or two instructions + auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2, + bool overWrite = false) { + if (useFma) { + if (is_avx2) { + vfmadd231ps(reg2, reg1, reg0); + } else { + assert(UNROLL_M == 8); + auto tent_vreg = overWrite ? reg1 : ymm1; + vmulps(tent_vreg, reg1, reg0); + vaddps(reg2, reg2, tent_vreg); + } + } else { + if (!overWrite) { + vmulps(ymm15, reg1, reg0); + vaddps(reg2, reg2, ymm15); + } else { + vmulps(reg1, reg1, reg0); + vaddps(reg2, reg2, reg1); + } + } + }; + + // Inner kernel with k=8 + auto innerkernel8 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + if (!isDirect) { + prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + + for (int i = 0; i < 8; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (i == 0) { + if (!isTransB) { + prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]); + } + } + if (unroll_n >= 2) { + if (!isTransB) { + if (i == 1) { + prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + if (i == 7) { + sub(LDA4, -unroll_m * 8 * SIZE); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (i == 7) { + if (!isTransB) { + sub(BO1, -8 * SIZE); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + if (i == 3) { + prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + if (i == 4) { + prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + if (i == 5) { + prefetcht0( + ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + + if (i == 0) { + if (unroll_m >= 4) { + if (!isDirect) { + prefetcht0( + ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 1 || i == 2) { + if (unroll_m >= 8) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 3 || i == 4 || i == 5 || i == 6) { + if (unroll_m >= 16) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 7) { + if (!isTransB) { + if (unroll_n >= 4) { + sub(BO2, -8 * SIZE); + } + } + if (!isTransA) { + prefetcht2(ptr[AA]); + lea(AA, ptr[AA + LDA]); + } + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps( + ymm0, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } + } + } + } + + if (!isDirect) { + sub(AO1, -unroll_m * 8 * SIZE); + } + sub(LL, 1); + + }; + + // Inner kernel with k=4 + auto innerkernel4 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + if (!isDirect) { + prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + + for (int i = 0; i < 4; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (i == 0) { + if (!isTransB) { + prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]); + } + } + if (unroll_n >= 2) { + if (!isTransB) { + if (i == 1) { + prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + if (i == 3) { + sub(LDA4, -unroll_m * 4 * SIZE); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (i == 7) { + if (!isTransB) { + sub(BO1, -8 * SIZE); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + if (i == 3) { + prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + if (i == 4) { + prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + if (i == 5) { + prefetcht0( + ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + + if (i == 0) { + if (unroll_m >= 4) { + if (!isDirect) { + prefetcht0( + ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 1 || i == 2) { + if (unroll_m >= 8) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 3) { + if (!isTransB) { + sub(BO1, -4 * SIZE); + if (unroll_n >= 4) { + sub(BO2, -4 * SIZE); + } + } + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps( + ymm0, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } + } + } + } + + if (!isDirect) { + sub(AO1, -unroll_m * 4 * SIZE); + } + + }; + + // Inner kernel with k=2 + auto innerkernel2 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + for (int i = 0; i < 2; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (unroll_n >= 2) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + + (unroll_m * 1 + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + + (unroll_m * 1 + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + } + + }; + + // Inner kernel with k=1 + auto innerkernel1 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) { + + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg00); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg06); + } + + if (unroll_n >= 2) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg01); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg07); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg02); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg08); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg03); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg09); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg04); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg10); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg05); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg11); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE], + ymm1); + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + + }; + + // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as + // appropriate + // After calculating results in registers, writes back to C matrix + auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma, + Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6), + Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9), + Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12), + Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15), + Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6), + Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9), + Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12), + Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) { + if (!isDirect) { + lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); + } else { + mov(AO1, A); + } + + if (isCopy) { + lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]); + } else { + lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]); + } + + if (isTransB) { + lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDB * 2]); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE]); + } + } + } + + for (int i = 4; i < 10; i++) { + vxorps(Ymm(i), Ymm(i), Ymm(i)); + vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6)); + } + + mov(LL, K); + sar(LL, 3); + + Label kernel12, kernel13, kernel14, kernel15; + Label kernel16, kernel17, kernel18; + + sub(LL, SECOND_FETCH); + jle(kernel13, T_NEAR); + align(16); + + L(kernel12); + innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + jg(kernel12, T_NEAR); + align(16); + + L(kernel13); + prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 2) + prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]); + if (unroll_n >= 3) + prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 4) + prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 5) + prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]); + if (unroll_n >= 6) + prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]); + + add(LL, SECOND_FETCH); + jle(kernel15, T_NEAR); + align(16); + + L(kernel14); + innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + jg(kernel14, T_NEAR); + align(16); + + L(kernel15); + test(K, 4); + jle(kernel16, T_NEAR); + innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + + L(kernel16); + test(K, 2); + jle(kernel17, T_NEAR); + innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + align(16); + + L(kernel17); + if (unroll_m == 16) { + if (unroll_n <= 3) { + vaddps(reg00, reg00, reg12); + vaddps(reg01, reg01, reg13); + vaddps(reg02, reg02, reg14); + vaddps(reg06, reg06, reg18); + vaddps(reg07, reg07, reg19); + vaddps(reg08, reg08, reg20); + } + } + + if (unroll_m <= 8) { + vaddps(reg00, reg00, reg12); + vaddps(reg01, reg01, reg13); + vaddps(reg02, reg02, reg14); + vaddps(reg03, reg03, reg15); + vaddps(reg04, reg04, reg16); + vaddps(reg05, reg05, reg17); + } + + test(K, 1); + jle(kernel18, T_NEAR); + innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11); + align(16); + + L(kernel18); + vbroadcastss(VALPHA, ALPHA); + + if (isBetaN) { + vbroadcastss(VBETA, BETA); + } + + // Write back the results; all beta and bias cases need to be + // handled + switch (unroll_n) { + case 1: mov(rax, LDC); break; + case 2: lea(rax, ptr[LDC * 2]); break; + case 3: lea(rax, ptr[LDC + LDC * 2]); break; + case 4: lea(rax, ptr[LDC + LDC * 4]); break; + case 5: + lea(rax, ptr[LDC * 4]); + add(rax, LDC); + break; + case 6: + lea(rax, ptr[LDC + LDC * 2]); + add(rax, rax); + break; + } + + if (hasBias) { + mov(BIAS1, BIAS); + if (isLoad1Unmasked) { + vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]); + } else { + vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]); + } + } + + for (int i = 0; i < unroll_n; i++) { + vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA); + if (!isBeta0) { + if (isLoad1Unmasked) { + switch (i) { + case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break; + case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break; + case 2: + vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]); + break; + case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break; + case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break; + case 5: + vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]); + break; + case 1: + vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]); + break; + case 2: + vmaskmovps( + ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]); + break; + case 3: + vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]); + break; + case 4: + vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]); + break; + case 5: + vmaskmovps( + ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]); + break; + } + } + + if (!isBetaN) { + vaddps(Ymm(i + 4), ymm0, Ymm(i + 4)); + } else { + fma(useFma, VBETA, ymm0, Ymm(i + 4), true); + } + } + if (hasBias) { + vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4)); + } + if (isLoad1Unmasked) { + switch (i) { + case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break; + case 1: + vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4)); + break; + case 2: + vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4)); + break; + case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break; + case 4: + vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4)); + break; + case 5: + vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4)); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 1: + vmaskmovps( + ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 2: + vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK, + Ymm(i + 4)); + break; + case 3: + vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 4: + vmaskmovps( + ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 5: + vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK, + Ymm(i + 4)); + break; + } + } + + if (unroll_m >= 16) { + // Re-use ymm4 (VBIAS2) + if (i == 0) { + if (hasBias) { + if (isLoad1Unmasked) { + vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]); + } else { + vmaskmovps( + VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]); + } + } + } + vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA); + if (!isBeta0) { + if (isLoad2Unmasked) { + switch (i) { + case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break; + case 1: + vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]); + break; + case 2: + vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]); + break; + case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break; + case 4: + vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]); + break; + case 5: + vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]); + break; + case 1: + vmaskmovps( + ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]); + break; + case 2: + vmaskmovps(ymm0, VMASK, + ptr[CO1 + LDC * 2 + 8 * SIZE]); + break; + case 3: + vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]); + break; + case 4: + vmaskmovps( + ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]); + break; + case 5: + vmaskmovps(ymm0, VMASK, + ptr[CO2 + LDC * 2 + 8 * SIZE]); + break; + } + } + if (!isBetaN) { + vaddps(Ymm(i + 10), ymm0, Ymm(i + 10)); + } else { + fma(useFma, VBETA, ymm0, Ymm(i + 10), true); + } + } + if (hasBias) { + vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10)); + } + if (isLoad2Unmasked) { + switch (i) { + case 0: + vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10)); + break; + case 1: + vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10)); + break; + case 2: + vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10)); + break; + case 3: + vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10)); + break; + case 4: + vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10)); + break; + case 5: + vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10)); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10)); + break; + case 1: + vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + case 2: + vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + case 3: + vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10)); + break; + case 4: + vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + case 5: + vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + } + } + } + if (i == 2) + add(CO1, rax); + } + if (unroll_n >= 4) { + add(CO2, rax); + } + + // Compute next address of B + if (!isTransB) { + lea(rax, ptr[K * SIZE]); + switch (unroll_n) { + case 1: + add(BO1, LDB); + add(BO2, LDB); + break; + case 2: + lea(BO1, ptr[BO1 + LDB * 2]); + lea(BO2, ptr[BO2 + LDB * 2]); + break; + case 3: + lea(BO1, ptr[BO1 + LDB3]); + lea(BO2, ptr[BO2 + LDB3]); + break; + case 4: + lea(BO1, ptr[BO1 + LDB * 4]); + lea(BO2, ptr[BO2 + LDB * 4]); + break; + case 5: + lea(BO1, ptr[BO1 + LDB * 4]); + add(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 4]); + add(BO2, LDB); + break; + case 6: + lea(BO1, ptr[BO1 + LDB3 * 2]); + lea(BO2, ptr[BO2 + LDB3 * 2]); + break; + } + sub(BO1, rax); + sub(BO2, rax); + } else { + mov(rax, LDB); + imul(rax, K); + sub(BO1, rax); + add(BO1, unroll_n * SIZE); + } + }; + + auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9), + Ymm(13), Ymm(14), Ymm(15)); + }; + + auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15)); + }; + + auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy); + }; + + auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy); + }; + + auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9), + Ymm(13), Ymm(14), Ymm(15)); + }; + + auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + // High-level subroutine; does packing if needed, then splits C matrix. + // Operates on chunks of 16 rows, 6 columns at a time (handling tail + // cases appropriately). + // Masking is used for tail cases where M is not divisible by 8. + auto subloop = [&]( + int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { + if (isTransA) { + do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked); + } + + Label subloop11, subloop11mask; + Label subloop20, subloop21, subloop22, subloop23; + Label subloop24, subloop25; + Label subloop30, subloop31, subloop32, subloop33; + Label subloop34, subloop35; + Label subloop98, subloop98mask; + Label subloop99, subloop99mask; + + mov(CO1, C); + lea(CO2, ptr[CO1 + LDC * 2]); + add(CO2, LDC); + add(C, unroll_m * SIZE); + mov(BO1, B); + if (!isTransB) { + lea(BO2, qword[B + LDB3]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98, T_NEAR); + + mov(AA, ORIG_A); + lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]); + L(subloop98); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(subloop20, T_NEAR); + } + align(16); + + if (!isTransA) { + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, true); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, true); + } + } else { + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(subloop20, T_NEAR); + align(16); + + L(subloop11); + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11, T_NEAR); + align(16); + + L(subloop20); + cmp(I, 1); + jne(subloop21, T_NEAR); + if (unroll_m == 16) { + kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop21); + cmp(I, 2); + jne(subloop22, T_NEAR); + if (unroll_m == 16) { + kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop22); + cmp(I, 3); + jne(subloop23, T_NEAR); + if (unroll_m == 16) { + kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop23); + cmp(I, 4); + jne(subloop24, T_NEAR); + if (unroll_m == 16) { + kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop24); + cmp(I, 5); + jne(subloop99, T_NEAR); + if (unroll_m == 16) { + kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + if (!isTransA) { + L(subloop30); + cmp(I, UNROLL_N); + jl(subloop25, T_NEAR); + align(16); + + L(subloop31); + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, false); + } + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31, T_NEAR); + align(16); + + L(subloop25); + cmp(I, 1); + jne(subloop32, T_NEAR); + if (unroll_m == 16) { + kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop32); + cmp(I, 2); + jne(subloop33, T_NEAR); + if (unroll_m == 16) { + kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop33); + cmp(I, 3); + jne(subloop34, T_NEAR); + if (unroll_m == 16) { + kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop34); + cmp(I, 4); + jne(subloop35, T_NEAR); + if (unroll_m == 16) { + kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop35); + cmp(I, 5); + jne(subloop99, T_NEAR); + if (unroll_m == 16) { + kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + align(16); + } + + L(subloop99); + // Compute address for A + if (!isTransA) { + add(A, unroll_m * SIZE); + } else { + mov(rax, LDA); + imul(rax, rax, unroll_m); + add(A, rax); + } + + // Compute next address of BIAS + if (hasBias) { + add(BIAS, unroll_m * SIZE); + } + }; + + preamble(); + + Label buffer_in_ws, buffer_allocated; + + // Get the registers + mov(B, ARG_B); + mov(LDB, ARG_LDB); + mov(r15, ARG_BETA); + mov(r12, ARG_C); + if (hasBias) + mov(r10, ARG_BIAS); + mov(LDC, ARG_LDC); + mov(rbp, rsp); + + vmovss(xmm0, ptr[ARG_ALPHA]); + vmovss(xmm1, ptr[r15]); + +#if _WIN32 + mov(A, ARG_A); + mov(LDA, ARG_LDA); +#endif + + cmp(K, STACK_K_CAPACITY); + jg(buffer_in_ws, T_NEAR); + + // Create buffer and align to 4kB page + lea(rax, ptr[K * SIZE]); + sal(rax, 4); + add(rax, 256); + sub(rsp, rax); + and_(rsp, -PAGE_4K); + jmp(buffer_allocated, T_NEAR); + + L(buffer_in_ws); + mov(rsp, ARG_WS); + + L(buffer_allocated); + + mov(ORIG_SP, rbp); + mov(M, ARG_M); + mov(N, ARG_N); + mov(C, r12); + if (hasBias) + mov(BIAS, r10); + vmovss(ALPHA, xmm0); + vmovss(BETA, xmm1); + sub(A, -OFFSET * SIZE); + sub(B, -OFFSET * SIZE); + mov(ORIG_A, A); + sal(LDA, BASE_SHIFT); + sal(LDB, BASE_SHIFT); + sal(LDC, BASE_SHIFT); + lea(LDB3, ptr[LDB + LDB * 2]); + + for (int i = 0; i < 8; i++) { + mov(dword[rsp + 88 + i * 4], i); + } + + if (isTransA && is_avx2) { + movq(xmm0, LDA); + vpbroadcastq(ymm1, xmm0); + vinsertf128(ymm0, ymm0, xmm0, 1); + vpermilpd(ymm0, ymm0, 5); + vpaddq(ymm1, ymm1, ymm1); + vperm2f128(ymm1, ymm1, ymm1, 8); + vpaddq(ymm0, ymm0, ymm1); + vmovups(STRIDE, ymm0); + } + + // Check A alignment and leading dimension; take copy-based path as + // needed + mov(rax, LDA); + or_(rax, A); + and_(rax, 0x1f); + mov(FLAG, rax); + + Label main0, main1, main2, main3, main999; + + cmp(M, UNROLL_M); + jl(main0, T_NEAR); + align(16); + + L(main1); + subloop(UNROLL_M, true, true); + sub(M, UNROLL_M); + cmp(M, UNROLL_M); + jge(main1, T_NEAR); + align(16); + + L(main0); + cmp(M, 0); + jle(main999, T_NEAR); + + if (UNROLL_M > 8) { + cmp(M, 8); + jle(main2, T_NEAR); + + sub(M, 8); + vbroadcastss(VMASK, M); + vpcmpgtd(VMASK, VMASK, MASK); + + subloop(16, true, false); + jmp(main999, T_NEAR); + align(16); + + L(main2); + cmp(M, 8); + jne(main3, T_NEAR); + subloop(8, true, true); + jmp(main999, T_NEAR); + } + + align(16); + + L(main3); + vbroadcastss(VMASK, M); + if (is_avx2) { + vpcmpgtd(VMASK, VMASK, MASK); + } else { + auto xmask = Xmm(VMASK.getIdx()); + auto xmm_tmp = xmm4; + + vextractf128(xmm_tmp, VMASK, 1); + vpcmpgtd(xmask, xmask, MASK); + vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4 + vinsertf128(VMASK, VMASK, xmm_tmp, 1); + } + subloop(8, false, false); + align(16); + + L(main999); + // Restore original stack + mov(rsp, ORIG_SP); + + vzeroupper(); + postamble(); + + ker_ = this->getCode(); + } + + typedef void (*ker_t)(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws); + + void operator()(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws) const + { + ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws); + } + +private: + ker_t ker_; +}; + +const xbyak_gemm *get_xbyak_gemm( + bool isTransA, bool isTransB, float beta, bool hasBias) { + auto beta_idx = [](float beta) { + return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2); + }; + + // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)] + static xbyak_gemm *kernel_table[2][2][2][3]; + static std::once_flag initialized; + std::call_once(initialized, [=]{ + for (bool isTransA: {false, true}) + for (bool isTransB: {false, true}) + for (bool hasBias: {false, true}) + for (float beta: {0.0f, 1.0f, 2.0f}) { + // nocopy sgemm with bias for beta != 0.0 is not supported + if (hasBias && beta != 0.0) + continue; + kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] = + new xbyak_gemm(isTransA, isTransB, beta, hasBias); + } + }); + + return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)]; +} + +void sgemm_nocopy_driver(const char *transa, + const char *transb, int m, int n, int k, const float *alpha, + const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta, + float *c, dim_t ldc, const float *bias, float *ws) +{ + bool isTransA = (*transa == 'T' || *transa == 't'); + bool isTransB = (*transb == 'T' || *transb == 't'); + + int Bm, sizeM, Bn, sizeN, Bk, sizeK; + + int i, j; + + if ((m <= 0) || (n <= 0)) + return; + + if ((k <= 0) || (alpha[0] == 0.)) { + + if (beta[0] == 0.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] = 0.0; + } else if (beta[0] != 1.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] *= beta[0]; + } + + return; + } + + assert(IMPLICATION(bias != nullptr, *beta == 0.0)); + + // XXX: this happens on every thread... + bool hasBias = (bias != nullptr); + auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias); + auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false); + auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false); + assert(ker_bn && ker_b1 && ker_b0); + + int BM = 4032; + int BN = isTransA ? 96 : 48; + int BK = isTransB ? 96 : 256; + const float *curA, *curB, *curBias = nullptr; + float *curC; + + for (Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK >= BK * 2) + sizeK = BK; + else { + if (sizeK > BK) + sizeK = (sizeK + 1) / 2; + } + + for (Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM >= BM * 2) + sizeM = BM; + else { + if (sizeM > BM + BM / 2) + sizeM = (sizeM + 1) / 2; + } + + for (Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN >= BN * 2) + sizeN = BN; + else { + if (sizeN > BN + BN / 2) + sizeN = (sizeN + 1) / 2; + } + + if (!isTransA) { + curA = a + Bm + Bk * lda; + } else { + curA = a + Bk + Bm * lda; + } + if (!isTransB) { + curB = b + Bk + Bn * ldb; + } else { + curB = b + Bn + Bk * ldb; + } + curC = c + Bm + (size_t)Bn * ldc; + if (bias != nullptr) { + if (Bk == 0) { + curBias = bias + Bm; + } else { + curBias = nullptr; + } + } + if (Bk == 0) { + if (*beta == 0.0 && bias == nullptr) + (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + else + (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } else { + (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } + } + } + } +} + +} + +mkldnn_status_t jit_avx_gemm_f32( + const char *transa, const char *transb, + const int *p_m, const int *p_n, const int *p_k, const float *p_alpha, + const float *A, const int *p_lda, const float *B, const int *p_ldb, + const float *p_beta, float *C, const int *p_ldc, const float *bias) +{ + using namespace mkldnn::impl::utils; + using namespace avx_gemm_f32; + using namespace gemm_utils; + + if (*p_beta != 0 && bias) + return ref_gemm(transa, transb, p_m, p_n, p_k, + p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias); + + int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + + int m = *p_m; + int n = *p_n; + int k = *p_k; + dim_t lda = *p_lda; + dim_t ldb = *p_ldb; + dim_t ldc = *p_ldc; + float beta = *p_beta; + int MB, NB, KB; + + int nthr_m, nthr_n, nthr_k, nthr_mn; + + // Determine threading partitioning + calc_nthr_nocopy_avx( + m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + // May not happen, but just in case + if (nthr < nthr_m * nthr_n * nthr_k) + nthr = nthr_m * nthr_n * nthr_k; + + nthr_mn = nthr_m * nthr_n; + + unsigned char * ompstatus_ = nullptr; + unsigned char volatile *ompstatus = nullptr; + + float *c_buffers = nullptr; + float *ws_buffers = nullptr; + + if (nthr_k > 1) { + ompstatus_ = (unsigned char *) malloc( + nthr * CACHE_LINE_SIZE, + CACHE_LINE_SIZE); + ompstatus = (unsigned char volatile *) ompstatus_; + assert(ompstatus); + + for (int i = 0; i < nthr; i++) + ompstatus[i * CACHE_LINE_SIZE] = 0; + + c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(float), PAGE_4K); + } + + const size_t ws_elems_per_thr = (size_t)k * 16 + 64; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); + if (k > STACK_K_CAPACITY) { + ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K); + } + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int k_from, k_to, myK; + int cbase, ibase; + const float *myA, *myB, *myBias = nullptr; + float *myC = C, myBeta; + float *ws = ws_buffers ? + ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; + dim_t ld = ldc; + + int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k); + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + k_from = KB * (ithr_k); + k_to = KB * (ithr_k + 1); + if (k_to > k) + k_to = k; + myK = k_to - k_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + ibase = (ithr_m + nthr_m * ithr_n) * nthr_k; + + if ((myM > 0) && (myN > 0)) { + + if (*transa == 'N' || *transa == 'n') { + myA = &(A[m_from + k_from * lda]); + } else { + myA = &(A[k_from + m_from * lda]); + } + if (*transb == 'N' || *transb == 'n') { + myB = &(B[k_from + n_from * ldb]); + } else { + myB = &(B[n_from + k_from * ldb]); + } + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + if (bias) + myBias = &(bias[m_from]); + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0; + ld = MB; + myBias = nullptr; + } + + sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, + lda, myB, ldb, &myBeta, myC, ld, myBias, ws); + + if (nthr_k > 1 && !sum_later) + ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1; + } + + if (nthr_k > 1 && !sum_later) { + + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + /* need to wait until main thread finishes */ + while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) { + }; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) { + }; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + + // handle C summation later + if (nthr_k > 1 && ompstatus[0] == 0) { + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int cbase; + float *myC = C; + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + if (nthr_k > 1) { + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + } + + + free(c_buffers); + free(ompstatus_); + free(ws_buffers); + + return mkldnn_success; +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp new file mode 100644 index 0000000000..aabf520a3c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp @@ -0,0 +1,37 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX_GEMM_F32_HPP +#define JIT_AVX_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx_gemm_f32( + const char *transa, const char *transb, const int *M, + const int *N, const int *K, const float *alpha, const float *A, + const int *lda, const float *B, const int *ldb, const float *beta, + float *C, const int *ldc, const float *bias = nullptr); + + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp new file mode 100644 index 0000000000..5147885a89 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp @@ -0,0 +1,346 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "gemm_utils_f32.hpp" +#include "ref_gemm_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace gemm_utils; + +namespace { + +template +void copy_A( + bool isTransA, int K, const data_t *A, const dim_t lda, data_t *ws) { + for (int k = 0; k < K; k++) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor::m; i++) { + ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda]; + } + ws += unroll_factor::m; + } +} + +template +void kernel_mxn(int K, const data_t *A, const dim_t lda, + const data_t *B, const dim_t ldb, data_t *C, const dim_t ldc, + const data_t alpha, const data_t beta) { + data_t c[unroll_factor::m * unroll_factor::n] = + { static_cast(0.) }; + for (int k = 0; k < K; k++) { + for (int j = 0; j < unroll_factor::n; j++) { + data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb]; + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor::m; i++) { + data_t a = isTransA ? A[i * lda + k] : A[i + lda * k]; + c[i + unroll_factor::m * j] += a * b; + } + } + } + for (int j = 0; j < unroll_factor::n; j++) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor::m; i++) { + C[i + j * ldc] = (beta == static_cast(0.)) + ? alpha * c[i + unroll_factor::m * j] + : alpha * c[i + unroll_factor::m * j] + + beta * C[i + j * ldc]; + } + } +} + +template +void block_ker(const int M, const int N, const int K, + const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb, + data_t *C, const dim_t ldc, const data_t alpha, const data_t beta, + data_t *ws, bool do_copy) { + int Nu = rnd_dn(N, unroll_factor::n); + int Mu = rnd_dn(M, unroll_factor::m); + for (int i = 0; i < Mu; i += unroll_factor::m) { + for (int j = 0; j < Nu; j += unroll_factor::n) { + const data_t *b = isTransB ? &B[j] : &B[j * ldb]; + const data_t *a = isTransA ? &A[i * lda] : &A[i]; + if (do_copy) { + if (j == 0) { + copy_A(isTransA, K, a, lda, ws); + } + kernel_mxn( + K, ws, unroll_factor::m, b, ldb, + &C[i + j * ldc], ldc, alpha, beta); + } else { + kernel_mxn( + K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta); + } + } + } + // tail processing + for (int i = 0; i < M; i++) { + for (int j = Nu; j < N; j++) { + data_t c = beta == static_cast(0.) + ? static_cast(0.) + : beta * C[i + j * ldc]; + for (int p = 0; p < K; p++) { + data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; + c += alpha * a * b; + } + C[i + j * ldc] = c; + } + } + for (int i = Mu; i < M; i++) { + for (int j = 0; j < Nu; j++) { + data_t c = beta == static_cast(0.) + ? static_cast(0.) + : beta * C[i + j * ldc]; + for (int p = 0; p < K; p++) { + data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; + c += alpha * a * b; + } + C[i + j * ldc] = c; + } + } +} + +template +void gemm_ithr(const int M, const int N, const int K, const data_t alpha, + const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb, + const data_t beta, data_t *C, const dim_t ldc, bool do_copy, + data_t *ws) { + constexpr int BM = gemm_traits::BM; + constexpr int BN = gemm_traits::BN; + constexpr int BK = gemm_traits::BK; + + const data_t *curA; + const data_t *curB; + data_t *curC; + + if ((M <= 0) || (N <= 0)) + return; + + if ((K <= 0) || (alpha == static_cast(0))) { + dim_t MN = N * M; + if (beta == static_cast(0.)) { + for (dim_t j = 0; j < MN; j++) + C[j] = static_cast(0.); + } else if (beta != static_cast(1.)) { + for (dim_t j = 0; j < MN; j++) + C[j] *= beta; + } + return; + } + + for (int Bk = 0; Bk < K; Bk += BK) { + int kb = nstl::min(K - Bk, BK); + for (int Bm = 0; Bm < M; Bm += BM) { + int mb = nstl::min(M - Bm, BM); + for (int Bn = 0; Bn < N; Bn += BN) { + int nb = nstl::min(N - Bn, BN); + curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda; + curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb; + curC = C + Bm + Bn * ldc; + if (Bk == 0) { + block_ker(mb, nb, kb, curA, lda, + curB, ldb, curC, ldc, alpha, beta, ws, do_copy); + } else { + block_ker(mb, nb, kb, curA, lda, + curB, ldb, curC, ldc, alpha, static_cast(1.0), + ws, do_copy); + } + } + } + } +} + +} + +template +mkldnn_status_t ref_gemm( + const char *transa_, const char *transb_, const int *M_, + const int *N_, const int *K_, const data_t *alpha_, const data_t *A, + const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_, + data_t *C, const int *ldc_, const data_t *bias) { + + bool isTransA = (*transa_ == 'T' || *transa_ == 't'); + bool isTransB = (*transb_ == 'T' || *transb_ == 't'); + const int M = *M_, N = *N_, K = *K_; + const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_; + const data_t alpha = *alpha_, beta = *beta_; + + int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads(); + int nthr_m, nthr_n, nthr_k; + int MB, NB, KB; + // thread balancing over M, N, K & size of blocking dimensions + calc_nthr_nocopy_avx( + M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + data_t *c_buffers = nullptr; + data_t *ws_buffers = nullptr; + if (nthr_k > 1) { + c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(data_t), PAGE_4K); + if (!c_buffers) { + nthr_k = 1; + KB = K; + } + } + + bool do_copy = (NB / unroll_factor::n > 3); + const int nthr_mn = nthr_m * nthr_n; + const int nthr = nthr_mn * nthr_k; + const size_t ws_elems_per_thr = K * unroll_factor::m; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K); + if (do_copy) { + ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K); + if (!ws_buffers) + do_copy = false; + } + + auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N, + int ithr) { + from = NB * (ithr); + to = NB * (ithr + 1); + if (to > N) + to = N; + myN = to - from; + }; + + parallel_nd(nthr, [&](const int ithr) { + int ithr_mn = ithr % nthr_mn; + int ithr_m = ithr_mn % nthr_m; + int ithr_n = ithr_mn / nthr_m; + int ithr_k = ithr / nthr_mn; + + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + data_t *ws = do_copy + ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t) + : nullptr; + + int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0, + k_from = 0, k_to = 0, myK = 0; + + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); + get_thr_block(k_from, k_to, myK, KB, K, ithr_k); + + if (myM > 0 && myN > 0) { + data_t myBeta, *myC; + dim_t ld; + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0f; + ld = MB; + } + const data_t *myA = isTransA + ? &(A[k_from + m_from * lda]) + : &(A[m_from + k_from * lda]); + const data_t *myB = isTransB + ? &(B[n_from + k_from * ldb]) + : &(B[k_from + n_from * ldb]); + + if (!isTransA) { + if (!isTransB) { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } else { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } + } else { + if (!isTransB) { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } else { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } + } + } + }); + + if (nthr_k > 1) { + parallel_nd(nthr, [&](const int ithr) { + int ithr_mn = ithr % nthr_mn; + int ithr_m = ithr_mn % nthr_m; + int ithr_k = ithr / nthr_mn; + int ithr_n = ithr_mn / nthr_m; + + int n_from = 0, n_to = 0, myN = 0; + int m_from = 0, m_to = 0, myM = 0; + + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); + + // sum matrices partitioned along K dimension + int offset = 0, block = 0; + gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset, + &block); + for (int ik = 1; ik < nthr_k; ++ik) { + data_t *myC = c_buffers + + MB * ((dim_t)NB * (cbase + ik - 1) + offset); + + gemm_utils::sum_two_matrices(myM, block, myC, MB, + &C[m_from + (n_from + offset) * ldc], ldc); + } + }); + } + + if (bias) { + parallel_nd(N, M, [&](int i, int j) { + C[i*ldc + j] += bias[j]; + }); + } + + free(ws_buffers); + free(c_buffers); + + return mkldnn_success; +} + +template mkldnn_status_t ref_gemm( + const char *transa_, const char *transb_, + const int *M_, const int *N_, const int *K_, const float *alpha_, + const float *A, const int *lda_, const float *B, const int *ldb_, + const float *beta_, float *C, const int *ldc_, const float *bias); + +template mkldnn_status_t ref_gemm( + const char *transa_, const char *transb_, + const int *M_, const int *N_, const int *K_, const double *alpha_, + const double *A, const int *lda_, const double *B, const int *ldb_, + const double *beta_, double *C, const int *ldc_, const double *bias); +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp new file mode 100644 index 0000000000..7c90ba6277 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_GEMM_F32_HPP +#define REF_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +mkldnn_status_t ref_gemm(const char *transa, const char *transb, const int *M, + const int *N, const int *K, const data_t *alpha, const data_t *A, + const int *lda, const data_t *B, const int *ldb, const data_t *beta, + data_t *C, const int *ldc, const data_t *bias); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp new file mode 100644 index 0000000000..3dbe07d743 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "mkldnn_traits.hpp" +#include "nstl.hpp" + +#include "jit_generator.hpp" + +#include "gemm.hpp" + +#include "f32/jit_avx512_common_gemm_f32.hpp" +#include "f32/jit_avx_gemm_f32.hpp" +#include "f32/ref_gemm_f32.hpp" + +#include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp" +#include "s8x8s32/simple_gemm_s8s8s32.hpp" +#include "s8x8s32/ref_gemm_s8x8s32.hpp" + +#include "os_blas.hpp" + +/* USE_MKL USE_CBLAS effect + * ------- --------- ------ + * yes yes use Intel(R) MKL CBLAS + * yes no use jit + * no yes system-dependent CBLAS + * no no use jit + */ + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t check_gemm_input(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const int *lda, + const int *ldb, const int *ldc, const float *alpha, const float *beta, + const bool with_bias) { + if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta)) + return mkldnn_invalid_arguments; + if (with_bias && *beta != 0) + return mkldnn_unimplemented; + bool consistency = true + && utils::one_of(*transa, 'T', 't', 'N', 'n') + && utils::one_of(*transb, 'T', 't', 'N', 'n') + && *M >= 0 + && *N >= 0 + && *K >= 0; + + if (!consistency) + return mkldnn_invalid_arguments; + bool isTransA = utils::one_of(*transa, 'T', 't'); + bool isTransB = utils::one_of(*transb, 'T', 't'); + int nrowA = isTransA ? *K : *M; + int nrowB = isTransB ? *N : *K; + consistency = true + && *lda >= nstl::max(1, nrowA) + && *ldb >= nstl::max(1, nrowB) + && *ldc >= nstl::max(1, *M); + if (!consistency) + return mkldnn_invalid_arguments; + + return mkldnn_success; +} + +mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc, + const char *transa, const char *transb, const int *M, const int *N, + const int *K, const int *lda, const int *ldb, const int *ldc, + const float *alpha, const float *beta, const bool with_bias) { + if (offsetc == nullptr) + return mkldnn_invalid_arguments; + if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r')) + return mkldnn_invalid_arguments; + + return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha, + beta, with_bias); +} + +mkldnn_status_t extended_sgemm(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const float *alpha, + const float *A, const int *lda, const float *B, const int *ldb, + const float *beta, float *C, const int *ldc, + const float *bias, const bool force_jit_gemm) { + mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K, + lda, ldb, ldc, alpha, beta, bias != nullptr); + if (status != mkldnn_success) + return status; + +#ifdef USE_CBLAS + if (!force_jit_gemm) { + bool trA = *transa == 't' || *transa == 'T'; + bool trB = *transb == 't' || *transb == 'T'; + CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans; + cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB, + *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc); + + if (bias) { + // Add bias if necessary (bias is applied to columns of C) + cblas_int incx = 1, incy = 1; + parallel_nd(*N, [&](int n) { + ptrdiff_t offset = (ptrdiff_t)n * (*ldc); + cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy); + }); + } + return mkldnn_success; + } +#endif + + if (mayiuse(avx512_common)) + return jit_avx512_common_gemm_f32(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); + else if (mayiuse(avx)) + return jit_avx_gemm_f32(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); + else + return ref_gemm(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); +} + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co) { + mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb, + M, N, K, LDA, LDB, LDC, alpha, beta, false); + if (status != mkldnn_success) + return status; + + if (*M == 0 || *N == 0 || *K == 0) + return mkldnn_success; + +#if USE_MKL_IGEMM + bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); + bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); + bool AisN = (*transa == 'N' || *transa == 'n'); + bool BisN = (*transb == 'N' || *transb == 'n'); + + if (data_traits::data_type == data_type::u8) { + CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans; + CBLAS_OFFSET Cblas_offsetc = + OCisR + ? CblasRowOffset + : OCisC + ? CblasColOffset + : CblasFixOffset; + cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc, + *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo, + *beta, C, *LDC, co); + return mkldnn_success; + } else { + assert(data_traits::data_type == data_type::s8); + // TODO CBLAS implementation of gemm_s8s8s32 goes here. + // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo + if (utils::everyone_is(0, *ao, *bo)) { + return simple_gemm_s8s8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, + C, LDC, co); + } else { + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } +#else + cpu_isa_t isa = isa_any; + if (mayiuse(avx512_core_vnni)) { + isa = avx512_core_vnni; + } else if (mayiuse(avx512_core)) { + isa = avx512_core; + } + + if (data_traits::data_type == data_type::u8) { + switch (isa) { + case avx512_core: + case avx512_core_vnni: + return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta, + C, LDC, co); + default: + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } else { + assert(data_traits::data_type == data_type::s8); + // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo + if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni)) + && *ao == 0 && *bo == 0) { + return simple_gemm_s8s8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, + C, LDC, co); + } else { + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } +#endif +} + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const int8_t *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const uint8_t *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +} +} +} + +using namespace mkldnn::impl; +using namespace mkldnn::impl::cpu; + +mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb, + const int64_t *M, const int64_t *N, const int64_t *K, const float *alpha, + const float *A, const int64_t *lda, const float *B, const int64_t *ldb, + const float *beta, float *C, const int64_t *ldc) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + + return extended_sgemm(transa, transb, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, B, &ldb_s32, beta, C, &ldc_s32); +} + +mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb, + const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, + const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, + const uint8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, + int32_t *C, const int64_t *ldc, const int32_t *co) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + return gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); +} + +mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb, + const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, + const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, + const int8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, + int32_t *C, const int64_t *ldc, const int32_t *co) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + + return gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp new file mode 100644 index 0000000000..dc15ff7130 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp @@ -0,0 +1,58 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_HPP +#define GEMM_HPP + +#include "mkldnn_types.h" +#include "os_blas.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t extended_sgemm(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const float *alpha, + const float *A, const int *lda, const float *B, const int *ldb, + const float *beta, float *C, const int *ldc, + const float *bias = nullptr, bool force_jit_gemm = false); + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *lda, const int8_t *ao, + const b_dt *B, const int *ldb, const int8_t *bo, const float *beta, + int32_t *c, const int *ldc, const int32_t *co); + +#ifdef USE_CBLAS +#define GEMM_IMPL_STR "gemm:blas" +#else +#define GEMM_IMPL_STR "gemm:jit" +#endif + +#if USE_MKL_IGEMM +#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:blas" +#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:blas" +#else +#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:jit" +#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:jit" +#endif + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp new file mode 100644 index 0000000000..4d34ede0bd --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp @@ -0,0 +1,86 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef OS_BLAS_HPP +#define OS_BLAS_HPP + +/** \file + * Common stuff respecting USE_MKL and USE_CBLAS compile flags + * + * USE_MKL USE_CBLAS effect + * ------- --------- ------ + * yes yes normal compile: jit *may* be preferred over Intel(R) MKL CBLAS + * yes no jit calls OK; assert if cblas is ever called + * no yes system-dependent CBLAS + * no no gemm convolution (or other blas) N/A; create stubs + */ + +#if defined(USE_MKL) + +#include "mkl_version.h" + +#define USE_MKL_PACKED_GEMM (INTEL_MKL_VERSION >= 20190001) +#define USE_MKL_IGEMM \ + (INTEL_MKL_VERSION >= 20180000 && __INTEL_MKL_BUILD_DATE >= 20170628) + +#include "mkl_cblas.h" +#if !defined(USE_CBLAS) +#define cblas_sgemm(...) assert(!"CBLAS is unavailable") +#endif + +#else /* defined(USE_MKL) */ + +#define USE_MKL_PACKED_GEMM 0 +#define USE_MKL_IGEMM 0 + +#if defined(_SX) +/* TODO: _SX should also define USE_CBLAS in case the later is available */ +extern "C" { +#include "cblas.h" // CHECK: does SX also have a fortran API sgemm? +} + +#elif defined(USE_CBLAS) +#include "cblas.h" // Maybe a system/cmake cblas works for you? +#else +/* put the stubs to make a code compilable but not workable */ +#define cblas_sgemm(...) assert(!"CBLAS is unavailable") +#endif /* defined(_SX) */ + +#endif /* defined(USE_MKL) */ + +namespace mkldnn { +namespace impl { +namespace cpu { + +#if defined(USE_MKL) && defined(USE_CBLAS) +typedef MKL_INT cblas_int; + +#elif defined(USE_CBLAS) +typedef int cblas_int; + +#if defined(_SX) +/* this cblas.h is peculiar... */ +typedef CBLAS_ORDER CBLAS_LAYOUT; +#endif +#endif + +} +} +} + +#endif /* OS_BLAS_HPP */ + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp new file mode 100644 index 0000000000..dde72f4a17 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp @@ -0,0 +1,206 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_H +#define COMMON_H + +#define GEMM_CODE_SIZE (4096L * 32) + +#define AVX512_UNROLL_M 48 +#define AVX512_UNROLL_N 8 +#define AVX512_UNROLL_K 1 +#define AVX512_BM 9984 +#define AVX512_BN 384 +#define AVX512_BK 768 +#define AVX512_BK_VNNI 1536 +#define AVX512_BK_TRADITIONAL 384 +#define AVX512_BLOCKING_SMALL_K 48 +#define AVX512_BN_SMALL_K 24 + + +#define PAGESIZE 4096 + +#define PADD_BYTESIZE_ONPAGE(x, size) (((x) * (size) + PAGESIZE - 1) / PAGESIZE) * PAGESIZE +#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size)) / size + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +enum { + PARTITION_1D_ROW, + PARTITION_1D_COL, + PARTITION_2D_COL_MAJOR, + PARTITION_2D = PARTITION_2D_COL_MAJOR, +}; + +enum { + COPY_NONE, + COPY_A, +}; + +enum { + NO_OFFSET, + FIX_OFFSET, + COL_OFFSET, + ROW_OFFSET, +}; + +// Alias for any dimension related variable. +typedef long long int dim_t; + +typedef struct { + // Interface arguments. + int transa, transb, offsetc; + dim_t m, n, k; + dim_t lda, ldb, ldc; + const int8_t *a; + const uint8_t *b; + int32_t *c; + const float *alpha, *beta; + + int8_t ao, bo; + const int32_t *co; + + // Kernel parameters. + dim_t um, un, uk, bm, bn, bk; + dim_t bn_small_k, bk_traditional, blocking_small_k; + + int (*copyA)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + int (*copyB)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + int (*kernel)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + // Gemv kernels + void (*gemv_s8u8s32_kernel)(const dim_t, const dim_t, const float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + + void (*gemv_u8s8s32_kernel)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + // Gemv parameters + int swap; + +} blas_t; + + +class jit_avx512_core_u8_copy_an_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern); + + public: + jit_avx512_core_u8_copy_an_kern(); +}; + +class jit_avx512_core_u8_copy_at_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern); + + public: + jit_avx512_core_u8_copy_at_kern(); +}; + +class jit_avx512_core_u8_copy_bn_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern); + + public: + jit_avx512_core_u8_copy_bn_kern(); +}; + +class jit_avx512_core_u8_copy_bt_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern); + + public: + jit_avx512_core_u8_copy_bt_kern(); +}; + +class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern); + + public: + jit_avx512_core_u8_copy_sum_an_kern(); +}; + +class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern); + + public: + jit_avx512_core_u8_copy_sum_at_kern(); +}; + +class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern); + + public: + jit_avx512_core_u8_copy_sum_bn_kern(); +}; + +class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern); + + public: + jit_avx512_core_u8_copy_sum_bt_kern(); +}; + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp new file mode 100644 index 0000000000..db9dd9ef97 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp @@ -0,0 +1,28 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg); +int gemv_threading_driver(blas_t *arg); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp new file mode 100644 index 0000000000..e4b8e1cde2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp @@ -0,0 +1,1409 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "common.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_gemm_s8u8s32.hpp" +#include "jit_avx512_core_gemm_s8u8s32_kern.hpp" +#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp" +#include "gemv.hpp" + +#if defined(_MSC_VER) +#include +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef struct { + int nthrs_m, nthrs_n; + int partition; + int copy_type; +} blas_thread_t; + +static inline void round_to_nearest(int32_t *rounded_val, double fp_val) { + if (fp_val >= 0.) { + fp_val += 0.5; + if (fp_val > INT32_MAX) { + fp_val = INT32_MAX; + } + } else { + fp_val -= 0.5; + if (fp_val < INT32_MIN) { + fp_val = INT32_MIN; + } + } + *rounded_val = (int32_t) fp_val; +} + +static inline void add_results(const dim_t m, const dim_t n, const dim_t k, + const float alpha, const float beta, const int32_t *c_partial_sum, + const dim_t ldcp, int32_t *c_data, const dim_t ldc, + const int32_t *a_row_sum, const int32_t *b_col_sum, const int8_t ao, + const int8_t bo, const int32_t *co, const int offsetc) +{ + for (dim_t j = 0; j < n; ++j) { + for (dim_t i = 0; i < m; ++i) { + int32_t ctemp = c_partial_sum[i + j * ldcp]; + + if (alpha == 1.0f) { + if (beta == 0.0f) { + c_data[i + j * ldc] = ctemp; + } else { + double c_float = (double) beta + * (double) c_data[i + j * ldc]; + c_float += (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } else if (alpha == -1.0f) { + if (beta == 0.0f) { + c_data[i + j * ldc] = -ctemp; + } else { + double c_float = (double) beta + * (double) c_data[i + j * ldc]; + c_float -= (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } else { + if (beta == 0.0f) { + double c_float = alpha * (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } else { + double c_float = alpha * (double) ctemp + + beta * (double) c_data[i + j * ldc]; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } + + if (offsetc == FIX_OFFSET) { + c_data[i + j * ldc] += co[0]; + } else if (offsetc == ROW_OFFSET) { + c_data[i + j * ldc] += co[j]; + } else if (offsetc == COL_OFFSET) { + c_data[i + j * ldc] += co[i]; + } + } + } +} + +// TODO Find a better place for those functions. +static inline dim_t ld_padd(const dim_t x) +{ + return ((x + ((2048 / sizeof(int32_t)) - 1)) / (2048 / sizeof(int32_t))) + * (2048 / sizeof(int32_t)) + (64 / sizeof(int32_t)); +} + +void igemm_inner_kernel(const dim_t m, const dim_t n, const dim_t k, + const int8_t *a, const uint8_t *b, float beta, int32_t *c, + const dim_t ldc, const int32_t *a_row_sum, const int32_t *b_col_sum, + const int32_t *co, const int offsetc, const blas_t *arg) +{ + int8_t ao = arg->ao; + int8_t bo = arg->bo; + int32_t co_0 = (offsetc == NO_OFFSET)? 0 : co[0]; + + // Since m and n are limited by blocking, stack overflow may not happen; + // it's up to 32kB +#if !defined(_MSC_VER) + int32_t col_offset[m]; + int32_t row_offset[n]; +#else + int32_t *col_offset = (int32_t *) _alloca(sizeof(*col_offset) * m); + int32_t *row_offset = (int32_t *) _alloca(sizeof(*row_offset) * n); +#endif + + int col_req = 0; + int row_req = 0; + + if ((bo != 0) || (offsetc == COL_OFFSET)) + col_req = 1; + if ((ao != 0) || (offsetc == ROW_OFFSET)) + row_req = 1; + + // It needs one of colum or row offsets, but it doesn't need both + if (((ao != 0) && (bo != 0)) || ((offsetc == FIX_OFFSET) && (co_0 != 0))) { + if ((col_req == 0) && (row_req == 0)) { + if (m <= n) { + col_req = 1; + } else { + row_req = 1; + } + } + } + + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] = 0; + + if (offsetc == COL_OFFSET) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += co[i]; + } + + if (bo != 0) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += bo * a_row_sum[i]; + } + } + + if (row_req) { + for (dim_t i = 0; i < n; i++) + row_offset[i] = 0; + + if (offsetc == ROW_OFFSET) { + for (dim_t i = 0; i < n; i++) + row_offset[i] += co[i]; + } + + if (ao != 0) { + for (dim_t i = 0; i < n; i++) + row_offset[i] += ao * b_col_sum[i]; + } + } + + if ((offsetc == FIX_OFFSET) && (co_0 != 0)) { + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += co_0; + } else { + for (dim_t i = 0; i < n; i++) + row_offset[i] += co_0; + } + } + + if ((ao != 0) && (bo != 0)) { + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += (int32_t) k * ao * bo; + } else { + for (dim_t i = 0; i < n; i++) + row_offset[i] += (int32_t) k * ao * bo; + } + } + + if (col_req == 0) { + if (row_req == 0) { + if (beta == 0.0) { + arg->kernel_b0(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } else { + if (beta == 0.0) { + arg->kernel_b0_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } + } else { + if (row_req == 0) { + if (beta == 0.0) { + arg->kernel_b0_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } else { + if (beta == 0.0) { + arg->kernel_b0_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } + } +} + +static inline void *align(void *ptr, size_t alignment) +{ + return (void *) utils::rnd_up((uintptr_t) ptr, alignment); +} + +static int gemm_kernel_driver(const dim_t m, const dim_t n, const dim_t k, + const int8_t *a, const uint8_t *b, int32_t *c, const int32_t *co, + const blas_t *arg) +{ + dim_t lda = arg->lda; + dim_t ldb = arg->ldb; + dim_t ldc = arg->ldc; + int8_t ao = arg->ao; + int8_t bo = arg->bo; + float alpha = *arg->alpha; + float beta = *arg->beta; + + if (m <= 0 || n <= 0) { + return 0; + } + + // Padding along K dimension. + dim_t k_padd = 0; + if (k <= arg->bk_traditional) { + k_padd = utils::rnd_up(k, arg->uk); + k_padd = nstl::max(128LL, k_padd); + } else if (k < 2 * arg->bk) { + k_padd = utils::rnd_up(k / 2, arg->uk); + } else { + k_padd = arg->bk; + } + + // Padding along M dimension. + dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm), + arg->um); + + // Padding along N dimension. + dim_t n_padd = 0; + if (k < arg->blocking_small_k) { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), + arg->bn_small_k), arg->un); + } else { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn), + arg->un); + } + + // Padding for temporary buffer for C + dim_t ldc_buf = ld_padd(m_padd); + + dim_t strideAm = (arg->transa == 0)? 1 : lda; + dim_t strideAn = (arg->transa != 0)? 1 : lda; + dim_t strideBm = (arg->transb == 0)? 1 : ldb; + dim_t strideBn = (arg->transb != 0)? 1 : ldb; + + size_t a_buf_nelems = m_padd * k_padd; + size_t b_buf_nelems = k_padd * n_padd; + size_t a_row_sum_nelems = m_padd; + size_t b_col_sum_nelems = n_padd; + + size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K + + b_buf_nelems * sizeof(*b) + PAGE_4K + + a_row_sum_nelems * sizeof(*c) + PAGE_4K + + b_col_sum_nelems * sizeof(*c) + PAGE_4K; + + bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0); + if (need_c_buffer) { + size_t c_buf_nelems = ldc_buf * n_padd; + mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; + } + + char *mem = (char *) malloc(mem_size, 128); + + if (!mem) { + return -1; + } + + int8_t *bufferA = (int8_t *) align(mem, PAGE_4K); + uint8_t *bufferB = (uint8_t *) align(bufferA + a_buf_nelems, PAGE_4K); + int32_t *a_row_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K); + int32_t *b_col_sum = (int32_t *) align(a_row_sum + a_row_sum_nelems, + PAGE_4K); + + int32_t *bufferC = NULL; + if (need_c_buffer) { + bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K); + } + + float beta_saved = beta; + + int a_block_copied = 0; + dim_t sizeM = 0; + for (dim_t Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM > m_padd) + sizeM = m_padd; + + dim_t sizeK = 0; + for (dim_t Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK > k_padd) + sizeK = k_padd; + + // Scale C blocks by beta only for the first time + if (Bk == 0) + beta = beta_saved; + else + beta = 1.0f; + + // Apply C offset when to the last k-block of the partial sum. + int offsetc = NO_OFFSET; + if (Bk + sizeK == k) + offsetc = arg->offsetc; + + dim_t sizeN = 0; + for (dim_t Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN > n_padd) + sizeN = n_padd; + + const uint8_t *b_block = b + Bk * strideBm + Bn * strideBn; + arg->copyB(&sizeK, &sizeN, b_block, &ldb, NULL, bufferB, NULL, + NULL, b_col_sum); + + dim_t sizeUM = 0; + for (dim_t Um = 0; Um < sizeM; Um += sizeUM) { + sizeUM = sizeM - Um; + if (sizeUM > arg->um) + sizeUM = arg->um; + + /* + * Use the whole A buffer only if we have multiple B blocks + * for k-dimension, otherwise we are wasting cache to store + * B and C blocks. + */ + dim_t Um_forA = 0; + if (sizeN < n) + Um_forA = Um; + + const int8_t *a_block = a + (Bm + Um) * strideAm + + Bk * strideAn; + if (!a_block_copied) { + arg->copyA(&sizeK, &sizeUM, a_block, &lda, NULL, + bufferA + Um_forA * sizeK, NULL, NULL, + a_row_sum + Um_forA); + } + + int32_t *c_block = c + (Bm + Um) + Bn * ldc; + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = Bn; + } else if (offsetc == COL_OFFSET) { + co_stride = Bm + Um; + } + if (need_c_buffer) { + igemm_inner_kernel(sizeUM, sizeN, sizeK, + bufferA + Um_forA * sizeK, bufferB, 0.0f, + bufferC + Um, ldc_buf, a_row_sum + Um_forA, + b_col_sum, NULL, NO_OFFSET, arg); + + // Finish the block adding the necessary alpha, beta + // and offsets. + add_results(sizeUM, sizeN, sizeK, alpha, beta, + bufferC + Um, ldc_buf, c_block, ldc, + a_row_sum + Um_forA, b_col_sum, ao, bo, + co + co_stride, offsetc); + } else { + igemm_inner_kernel(sizeUM, sizeN, sizeK, + bufferA + Um_forA * sizeK, bufferB, beta, + c_block, ldc, a_row_sum + Um_forA, b_col_sum, + co + co_stride, offsetc, arg); + } + } + a_block_copied = 1; + } + a_block_copied = 0; + } + } + + free(mem); + + return 0; +} + +static int kernel_driver_parallel_acopiedbcopy(const dim_t m, const dim_t n, + const dim_t k, const int8_t *bufferA, const uint8_t *b, + const float beta, int32_t *c, const int offsetc, const int32_t *co, + const int32_t *a_row_sum, const blas_t *arg) +{ + dim_t ldb = arg->ldb; + dim_t ldc = arg->ldc; + int8_t ao = arg->ao; + int8_t bo = arg->bo; + float alpha = *arg->alpha; + + if (m <= 0 || n <= 0) { + return 0; + } + + // Padding along N dimension. + dim_t n_padd = 0; + if (k < arg->blocking_small_k) { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), + arg->bn_small_k), arg->un); + } else { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn), + arg->un); + } + + // Padding for temporary buffer for C + dim_t ldc_buf = ld_padd(m); + + dim_t strideBn = (arg->transb != 0)? 1 : ldb; + + size_t b_buf_nelems = k * n_padd; + size_t b_col_sum_nelems = n_padd; + + size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K + + b_col_sum_nelems * sizeof(*c) + PAGE_4K; + + bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0); + if (need_c_buffer) { + size_t c_buf_nelems = ldc_buf * n_padd; + mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; + } + + char *mem = (char *) malloc(mem_size, 128); + + if (!mem) { + return -1; + } + + uint8_t *bufferB = (uint8_t *) align(mem, PAGE_4K); + int32_t *b_col_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K); + + int32_t *bufferC = NULL; + if (need_c_buffer) { + bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K); + } + + dim_t sizeN = 0; + for (dim_t Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN > n_padd) + sizeN = n_padd; + + // Implement the kernel here. + const uint8_t *b_block = b + Bn * strideBn; + arg->copyB(&k, &sizeN, b_block, &ldb, NULL, bufferB, NULL, NULL, + b_col_sum); + + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = Bn; + } else if (offsetc == COL_OFFSET) { + co_stride = 0; + } + int32_t *c_block = c + Bn * ldc; + if (need_c_buffer) { + igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, 0.0f, bufferC, + ldc_buf, a_row_sum, b_col_sum, NULL, NO_OFFSET, arg); + + // Finish the block adding the necessary alpha, beta and offsets. + add_results(m, sizeN, k, alpha, beta, bufferC, ldc_buf, c_block, + ldc, a_row_sum, b_col_sum, ao, bo, co + co_stride, + offsetc); + } else { + igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, beta, c_block, + ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg); + } + } + + free(mem); + + return 0; + +} + +#define N2D_MAX_AVX512 384 +#define M2D_MIN_AVX512 384 +#define VECLEN 16 +#define NCONS 1 +static inline void set_thread_opts_avx512(int *p_nthrs, + blas_thread_t *thread_info, const blas_t *arg) +{ + int nthrs = *p_nthrs; + dim_t m = arg->m; + dim_t n = arg->n; + + thread_info->nthrs_m = 0; + thread_info->nthrs_n = 0; + thread_info->copy_type = COPY_NONE; // By default don't do parallel copy. + + int condition_2D_bsrc = -1; + if ((256 * m > nthrs * n) && (nthrs * m < 256 * n)) { + condition_2D_bsrc = 1; + } else { + condition_2D_bsrc = 0; + } + + int condition_1D_copya = 0; + if ((m >= 1000) && (n >= nthrs * N2D_MAX_AVX512 / 4)) { + condition_2D_bsrc = 0; + condition_1D_copya = 1; + } + + // If offset is non-zero, we need to keep 1D_copya to reduce update overhead + if (arg->ao != 0 || arg->bo != 0 || arg->co[0] != 0 + || arg->offsetc != FIX_OFFSET) { + condition_2D_bsrc = 0; + condition_1D_copya = 1; + } + + if (condition_2D_bsrc == 1) { + int nthrs_m = 1; + int nthrs_n = nthrs; + + while ((nthrs_n % 2 == 0) && + (n / nthrs > N2D_MAX_AVX512 || + n / nthrs_n <= N2D_MAX_AVX512 / 2) && + (m / nthrs_m >= 2 * M2D_MIN_AVX512) && + (nthrs_m < 4)) { + nthrs_m *= 2; + nthrs_n /= 2; + } + + thread_info->nthrs_m = nthrs_m; + thread_info->nthrs_n = nthrs_n; + thread_info->partition = PARTITION_2D; + + // Reset the total number of threads that will be used. + *p_nthrs = nthrs_m * nthrs_n; + + } else if (condition_1D_copya && mkldnn_thr_syncable()) { + // Use parallel copy A algorithm + thread_info->copy_type = COPY_A; + thread_info->partition = PARTITION_1D_COL; + } else { + if ((m > n) && (m / nthrs >= VECLEN || n < NCONS * nthrs)) { + thread_info->partition = PARTITION_1D_ROW; + } else { + thread_info->partition = PARTITION_1D_COL; + } + } +} +#undef N2D_MAX_AVX512 +#undef M2D_MIN_AVX512 +#undef VECLEN +#undef NCONS + +static inline void partition_1d(const int ithr, const int nthrs, const dim_t n, + dim_t *t_offset, dim_t *t_block) +{ + dim_t band = n / nthrs; + + dim_t tail = n - (nthrs - 1) * band; + if (tail > (band + 1)) + band++; + tail = n - (nthrs - 1) * band; + + if (ithr < (nthrs - 1)) + *t_block = band; + else + *t_block = tail; + + *t_offset = ithr * band; + + if (*t_offset >= n) { + *t_block = 0; + *t_offset = 0; + } else if ((*t_offset + *t_block) > n) { + *t_block = n - *t_offset; + } +} + +static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i, + const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m, + const dim_t n, dim_t *p_m_disp, dim_t *p_m_band, dim_t *p_n_disp, + dim_t *p_n_band) +{ + dim_t m_disp = 0, n_disp = 0; + dim_t m_band = 0, n_band = 0; + + int mdiv = nthrs_m; + int ndiv = nthrs_n; + + dim_t m_bandt = m / mdiv; /* size per thread */ + dim_t n_bandt = n / ndiv; /* size per thread */ + int firstmgroup = mdiv - 1; + int firstngroup = ndiv - 1; + dim_t firstmval = m_bandt; + dim_t firstnval = n_bandt; + + int mthr_used = mdiv; + if (m - (mdiv - 1) * m_bandt > m_bandt + 1) { + if (m - (mdiv - 1) * m_bandt > mdiv) + ++m_bandt; + + firstmval = m_bandt + 1; + mthr_used = (int) (m / firstmval); + + if (mthr_used * firstmval < m) + ++mthr_used; + + firstmgroup = mthr_used - 1; + } + + int nthr_used = ndiv; + if (n - (ndiv - 1) * n_bandt > n_bandt + 1) { + firstnval = n_bandt + 1; + nthr_used = (int) (n / firstnval); + + if (nthr_used * firstnval < n) + ++nthr_used; + + firstngroup = nthr_used - 1; + } + + *nthrs = mthr_used * nthr_used; + + if (ithr < *nthrs) { + if (ithr_i < firstmgroup) { + m_band = firstmval; + m_disp = ithr_i * firstmval; + } else if (ithr_i <= mthr_used - 2) { + m_band = m_bandt; + m_disp = firstmgroup * firstmval + (ithr_i - firstmgroup) * m_bandt; + } else { + m_disp = firstmgroup * firstmval + + (mthr_used - 1 - firstmgroup) * m_bandt; + m_band = nstl::max(0LL, m - m_disp); + } + + if (ithr_j < firstngroup) { + n_band = firstnval; + n_disp = ithr_j * firstnval; + } else if (ithr_j <= nthr_used - 2) { + n_band = n_bandt; + n_disp = firstngroup * firstnval + (ithr_j - firstngroup) * n_bandt; + } else { + n_disp = firstngroup * firstnval + + (nthr_used - 1 - firstngroup) * n_bandt; + n_band = nstl::max(0LL, n - n_disp); + } + m_disp = nstl::max(nstl::min(m_disp, m - 1), 0LL); + n_disp = nstl::max(nstl::min(n_disp, n - 1), 0LL); + } + + if (ithr < *nthrs) { + *p_m_disp = m_disp; + *p_n_disp = n_disp; + *p_m_band = m_band; + *p_n_band = n_band; + } else { + *p_m_disp = 0; + *p_n_disp = 0; + *p_m_band = 0; + *p_n_band = 0; + } + + return; +} + +static inline void decompose_matrices(const int ithr, int *nthrs, dim_t *m, + dim_t *n, dim_t *k, const int8_t **a, const uint8_t **b, int32_t **c, + const int32_t **co, const blas_thread_t *thread_info, const blas_t *arg) +{ + dim_t strideAm = (arg->transa == 0)? 1 : arg->lda; + dim_t strideBn = (arg->transb != 0)? 1 : arg->ldb; + int offsetc = arg->offsetc; + + switch (thread_info->partition) { + case PARTITION_1D_ROW: + { + dim_t offset = 0; + dim_t block = 0; + partition_1d(ithr, *nthrs, arg->m, &offset, &block); + + *m = block; + *n = arg->n; + *k = arg->k; + + // Set matrix A. + *a = arg->a + offset * strideAm; + + // Set matrix B. + *b = arg->b; + + // Set matrix C. + *c = arg->c + offset; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = 0; + } else if (offsetc == COL_OFFSET) { + co_stride = offset; + } + *co = arg->co + co_stride; + break; + } + + case PARTITION_1D_COL: + { + dim_t offset = 0; + dim_t block = 0; + partition_1d(ithr, *nthrs, arg->n, &offset, &block); + + *m = arg->m; + *n = block; + *k = arg->k; + + // Set matrix A. + *a = arg->a; + + // Set matrix B. + *b = arg->b + offset * strideBn; + + // Set matrix C. + *c = arg->c + offset * arg->ldc; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = offset; + } else if (offsetc == COL_OFFSET) { + co_stride = 0; + } + *co = arg->co + co_stride; + break; + } + + case PARTITION_2D_COL_MAJOR: + { + int nthrs_m = thread_info->nthrs_m; + int nthrs_n = thread_info->nthrs_n; + int ithr_i = ithr % nthrs_m; + int ithr_j = ithr / nthrs_m; + + dim_t m_disp = 0; + dim_t m_band = 0; + dim_t n_disp = 0; + dim_t n_band = 0; + + partition_2d(ithr, nthrs, ithr_i, ithr_j, nthrs_m, nthrs_n, + arg->m, arg->n, &m_disp, &m_band, &n_disp, &n_band); + + *m = m_band; + *n = n_band; + *k = arg->k; + + // Set matrix A. + *a = arg->a + m_disp * strideAm; + + // Set matrix B. + *b = arg->b + n_disp * strideBn; + + // Set matrix C. + *c = arg->c + m_disp + n_disp * arg->ldc; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = n_disp; + } else if (offsetc == COL_OFFSET) { + co_stride = m_disp; + } + *co = arg->co + co_stride; + break; + } + } +} + +#define MULTIPLIER 10 +static int parallel_a_copy(const int ithr, const int nthrs, const dim_t m, + const dim_t n, const dim_t k, const int8_t *a, const uint8_t *b, + int32_t *c, const int32_t *co, const blas_t *arg, + char **p_shared_mem) +{ + const dim_t lda = arg->lda; + const dim_t ldb = arg->ldb; + const dim_t strideAm = (arg->transa == 0)? 1 : lda; + const dim_t strideAn = (arg->transa != 0)? 1 : lda; + const dim_t strideBm = (arg->transb == 0)? 1 : ldb; + + // Padding along M dimension. + dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm), + arg->um); + + // Padding along K dimension. + dim_t k_padd = 0; + if (k <= arg->bk_traditional) { + k_padd = utils::rnd_up(k, arg->uk); + k_padd = nstl::max(128LL, k_padd); + } else if (k < 2 * arg->bk) { + k_padd = utils::rnd_up(k / 2, arg->uk); + } else { + k_padd = arg->bk; + } + + m_padd *= nthrs > MULTIPLIER ? MULTIPLIER : nthrs; + if (m_padd > m) { + m_padd = utils::rnd_up(m, arg->um); + } + + size_t a_buf_nelems = m_padd * k_padd; + + // Allocate shared memory for A and its row sum buffers in master thread. + if (ithr == 0) { // If thread master + size_t a_row_sum_nelems = m_padd; + + size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K) + + a_row_sum_nelems * sizeof(*c) + PAGE_4K; + + *p_shared_mem = (char *) malloc(mem_size, 128); + + } + mkldnn_thr_barrier(); + + char *mem = *p_shared_mem; + int8_t *bufferA = (int8_t *) align(mem, PAGE_4K); + int32_t *a_row_sum = (int32_t *) align(bufferA + a_buf_nelems, PAGE_4K); + + if (!mem) { + return -1; + } + + int result = 0; // Return status + + dim_t sizeK = 0; + for (dim_t Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK > k_padd) + sizeK = k_padd; + + // Scale C blocks by beta only for the first term of partial sum. + float beta = 1.0f; + if (Bk == 0) + beta = *(arg->beta); + + // Apply C offset for the last k-block of the partial sum. + int offsetc = NO_OFFSET; + if (Bk + sizeK == k) + offsetc = arg->offsetc; + + dim_t sizeM = 0; + for (dim_t Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM > m_padd) + sizeM = m_padd; + + if (ithr < nthrs) { + dim_t band = (sizeM + nthrs - 1) / nthrs; + band = utils::rnd_up(band, arg->um); + + dim_t offset = band * ithr; + + // If offset is too large don't use that thread for copying. + if (offset >= sizeM) { + offset = 0; + band = 0; + } + + // Handle the tail of the copy. + if (offset + band > sizeM) { + band = sizeM - offset; + } + + if (band > 0) { + const int8_t *a_block = a + (Bm + offset) * strideAm + + Bk * strideAn; + arg->copyA(&sizeK, &band, a_block, &lda, NULL, + bufferA + offset * sizeK, NULL, NULL, + a_row_sum + offset); + } + } + mkldnn_thr_barrier(); // Wait for finishing parallel copy. + + const uint8_t *b_block = b + Bk * strideBm; + int32_t *c_block = c + Bm; + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = 0; + } else if (offsetc == COL_OFFSET) { + co_stride = Bm; + } + + result = kernel_driver_parallel_acopiedbcopy(sizeM, n, sizeK, + bufferA, b_block, beta, c_block, offsetc, co + co_stride, + a_row_sum, arg); + + mkldnn_thr_barrier(); // Wait for kernel computations to finish. + } + } + + // Free memory allocated in master thread + if (ithr == 0) { + free(mem); + } + + return result; +} +#undef MULTIPLIER + +static inline void get_omp_thread_count(dim_t m, dim_t n, dim_t k, + double fp_per_cycle, int *nthrs) +{ + double omp_overhead_small_core = 3.0e+3; + double omp_intercept_big_core = 4.0e+3; + double omp_slope_big_core = 5.0e+2; + + double gemm_cycles = 8.0 * m * n * k / fp_per_cycle; + + int i = *nthrs; + + // Use a different model for omp overheads if nthrs is <= 4 + if (*nthrs <= 4 && omp_overhead_small_core > 0) { + double omp_cycles = omp_overhead_small_core; + if (gemm_cycles < omp_cycles) { + *nthrs = 1; + return; + } else { + while (i > 1) { + if (omp_cycles * i < gemm_cycles * (i - 1)) break; + --i; + } + } + } else { + if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) { + *nthrs = 1; + return; + } + + // adaptive decrement to march faster· + while (i > 1) { + double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core; + if (omp_cycles * i < gemm_cycles * (i - 1)) + break; + + if (i < 10) + i -= 2; + else if (i < 30) + i -= 4; + else + i -= 8; + } + } + + if (i < 1) + i = 1; + + *nthrs = i; +} + +#define CACHE_LINE_SIZE 64 +static int gemm_threading_driver(blas_t *arg) +{ + if ((arg->m <= 0) || (arg->n <= 0)) + return mkldnn_success; + + if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) { + return mkldnn_success; + } + + int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + get_omp_thread_count(arg->m, arg->n, arg->k, 64.0, &nthr); + + if (nthr == 1) { + return gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, arg->b, + arg->c, arg->co, arg); + } + + int *results = (int *) malloc(sizeof(*results) * nthr * CACHE_LINE_SIZE, + PAGE_4K); + + if (!results) { + return -1; + } + + for (int i = 0; i < nthr; i++) { + results[i * CACHE_LINE_SIZE] = 0; // Initialize to success + } + + char *shared_mem = NULL; + + parallel(nthr, [&](const int ithr, const int nthr) { + int nthrs = nthr; + if (nthrs == 1) { + results[0] = gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, + arg->b, arg->c, arg->co, arg); + } else { + blas_thread_t thread_info; + set_thread_opts_avx512(&nthrs, &thread_info, arg); + + const int8_t *a = NULL; + const uint8_t *b = NULL; + int32_t *c = NULL; + const int32_t *co = NULL; + dim_t m = -1; + dim_t n = -1; + dim_t k = -1; + decompose_matrices(ithr, &nthrs, &m, &n, &k, &a, &b, &c, &co, + &thread_info, arg); + + if (ithr < nthrs) { + switch (thread_info.copy_type) { + case COPY_A: + results[ithr * CACHE_LINE_SIZE] = + parallel_a_copy(ithr, nthrs, m, n, k, a, b, c, co, arg, + &shared_mem); + break; + + default: + case COPY_NONE: + results[ithr * CACHE_LINE_SIZE] = + gemm_kernel_driver(m, n, k, a, b, c, co, arg); + break; + } + } + } + }); + + int result = 0; // Initialize to success + for (int i = 0; i < nthr; i++) { + if (results[i] != 0) { + result = results[i * CACHE_LINE_SIZE]; + break; + } + } + + free(results); + + return result; +} +#undef CACHE_LINE_SIZE + +static jit_avx512_core_u8_copy_an_kern *copy_an; +static jit_avx512_core_u8_copy_at_kern *copy_at; +static jit_avx512_core_u8_copy_bn_kern *copy_bn; +static jit_avx512_core_u8_copy_bt_kern *copy_bt; +static jit_avx512_core_u8_copy_sum_an_kern *copy_sum_an; +static jit_avx512_core_u8_copy_sum_at_kern *copy_sum_at; +static jit_avx512_core_u8_copy_sum_bn_kern *copy_sum_bn; +static jit_avx512_core_u8_copy_sum_bt_kern *copy_sum_bt; +static jit_avx512_core_gemm_s8u8s32_kern *kernel; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_r; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_c; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_b; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_r; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_c; +static jit_avx512_core_gemv_s8u8s32_kern *gemv_s8u8s32_kernel; +static jit_avx512_core_gemv_s8u8s32_kern *gemv_u8s8s32_kernel; + +static void jit_init(blas_t *arg) +{ + static int (*copyAn)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyAt)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyBn)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyBt)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumAn)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumAt)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumBn)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumBt)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*kern)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static void (*gemv_s8u8s32_kern)(const dim_t, const dim_t, const float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + + static void (*gemv_u8s8s32_kern)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + if (mayiuse(avx512_core_vnni)) { + arg->um = AVX512_UNROLL_M; + arg->un = AVX512_UNROLL_N; + arg->uk = AVX512_UNROLL_K; + arg->bm = AVX512_BM; + arg->bn = AVX512_BN; + arg->bk = AVX512_BK_VNNI; + + arg->bk_traditional = AVX512_BK_TRADITIONAL; + arg->bn_small_k = AVX512_BN_SMALL_K; + arg->blocking_small_k = AVX512_BLOCKING_SMALL_K; + } else { + arg->um = AVX512_UNROLL_M; + arg->un = AVX512_UNROLL_N; + arg->uk = AVX512_UNROLL_K; + arg->bm = AVX512_BM; + arg->bn = AVX512_BN; + arg->bk = AVX512_BK; + + arg->bk_traditional = AVX512_BK_TRADITIONAL; + arg->bn_small_k = AVX512_BN_SMALL_K; + arg->blocking_small_k = AVX512_BLOCKING_SMALL_K; + } + + static std::once_flag initialized; + std::call_once(initialized, []{ + + copy_an = new jit_avx512_core_u8_copy_an_kern(); + copy_at = new jit_avx512_core_u8_copy_at_kern(); + copy_bn = new jit_avx512_core_u8_copy_bn_kern(); + copy_bt = new jit_avx512_core_u8_copy_bt_kern(); + + copy_sum_an = new jit_avx512_core_u8_copy_sum_an_kern(); + copy_sum_at = new jit_avx512_core_u8_copy_sum_at_kern(); + copy_sum_bn = new jit_avx512_core_u8_copy_sum_bn_kern(); + copy_sum_bt = new jit_avx512_core_u8_copy_sum_bt_kern(); + + kernel = new jit_avx512_core_gemm_s8u8s32_kern(false, false, false); + kernel_b = new jit_avx512_core_gemm_s8u8s32_kern(false, true, true); + kernel_r = new jit_avx512_core_gemm_s8u8s32_kern(false, false, true); + kernel_c = new jit_avx512_core_gemm_s8u8s32_kern(false, true, false); + kernel_b0 = new jit_avx512_core_gemm_s8u8s32_kern(true, false, false); + kernel_b0_b = new jit_avx512_core_gemm_s8u8s32_kern(true, true, true); + kernel_b0_r = new jit_avx512_core_gemm_s8u8s32_kern(true, false, true); + kernel_b0_c = new jit_avx512_core_gemm_s8u8s32_kern(true, true, false); + + gemv_s8u8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern(); + gemv_u8s8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern(); + + + copyAn = copy_an->getCode(); + + copyAt = copy_at->getCode(); + + copyBn = copy_bn->getCode(); + + copyBt = copy_bt->getCode(); + + copySumAn = copy_sum_an->getCode(); + + copySumAt = copy_sum_at->getCode(); + + copySumBn = copy_sum_bn->getCode(); + + copySumBt = copy_sum_bt->getCode(); + + kern = kernel->getCode(); + + kern_b = kernel_b->getCode(); + + kern_r = kernel_r->getCode(); + + kern_c = kernel_c->getCode(); + + kern_b0 = kernel_b0->getCode(); + + kern_b0_b = kernel_b0_b->getCode(); + + kern_b0_r = kernel_b0_r->getCode(); + + kern_b0_c = kernel_b0_c->getCode(); + + gemv_s8u8s32_kern = + gemv_s8u8s32_kernel -> generate + (mayiuse(avx512_core_vnni)); + gemv_u8s8s32_kern = + gemv_u8s8s32_kernel -> generate + (mayiuse(avx512_core_vnni)); + }); + + if (arg->bo == 0) { // No need to compute A row sum if bo is zero + if (arg->transa == 0) { + arg->copyA = copyAn; + } else { + arg->copyA = copyAt; + } + } else { + if (arg->transa == 0) { + arg->copyA = copySumAn; + } else { + arg->copyA = copySumAt; + } + } + + if (arg->ao == 0) { // No need to compute B column sum if ao is zero + if (arg->transb == 0) { + arg->copyB = copyBn; + } else { + arg->copyB = copyBt; + } + } else { + if (arg->transb == 0) { + arg->copyB = copySumBn; + } else { + arg->copyB = copySumBt; + } + } + + arg->kernel = kern; + arg->kernel_b = kern_b; + arg->kernel_r = kern_r; + arg->kernel_c = kern_c; + arg->kernel_b0 = kern_b0; + arg->kernel_b0_b = kern_b0_b; + arg->kernel_b0_r = kern_b0_r; + arg->kernel_b0_c = kern_b0_c; + arg -> gemv_s8u8s32_kernel = gemv_s8u8s32_kern; + arg -> gemv_u8s8s32_kernel = gemv_u8s8s32_kern; +} + +mkldnn_status_t jit_avx512_core_gemm_s8u8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const uint8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc) +{ + char transa = *transA; + char transb = *transB; + char offsetc = *offsetC; + + blas_t args; + + // Initialize blas structure + args.m = *m; + args.n = *n; + args.k = *k; + args.alpha = alpha; + args.a = a; + args.lda = *lda; + args.b = b; + args.ldb = *ldb; + args.beta = beta; + args.c = c; + args.ldc = *ldc; + args.transa = (transa == 'N' || transa == 'n') ? 0 : 1; + args.transb = (transb == 'N' || transb == 'n') ? 0 : 1; + args.um = 0; + args.un = 0; + args.bm = 0; + args.bn = 0; + args.bk = 0; + args.copyA = NULL; + args.copyB = NULL; + args.kernel = NULL; + args.kernel_b0 = NULL; + args.ao = *oa; + args.bo = *ob; + args.co = oc; + + if (offsetc == 'F' || offsetc == 'f') { + args.offsetc = FIX_OFFSET; + } else if (offsetc == 'R' || offsetc == 'r') { + args.offsetc = ROW_OFFSET; + } else { // offsetc == 'C' || offsetc == 'c' + args.offsetc = COL_OFFSET; + } + + jit_init(&args); + int result = gemm_threading_driver(&args); + + return (result < 0) ? mkldnn_out_of_memory : mkldnn_success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp new file mode 100644 index 0000000000..b2e2902a12 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp @@ -0,0 +1,38 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_GEMM_S8U8S32_HPP +#define JIT_AVX512_CORE_GEMM_S8U8S32_HPP + +#include +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx512_core_gemm_s8u8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const uint8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc); + +} +} +} + +#endif // JIT_AVX512_CORE_GEMM_S8U8S32_HPP diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp new file mode 100644 index 0000000000..57554a1852 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp @@ -0,0 +1,539 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_avx512_core_gemm_s8u8s32_kern.hpp" + + +#ifdef _WIN32 +static const bool is_windows = 1; +#else +static const bool is_windows = 0; +#endif + + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + + + + +// Convert between vector register lengths. +static inline Xmm make_xmm(const Xmm &v) { return Xmm(v.getIdx()); } +static inline Ymm make_ymm(const Xmm &v) { return Ymm(v.getIdx()); } + +// Load from or store to C. +void jit_avx512_core_gemm_s8u8s32_kern::c_load(const Xbyak::Xmm &dst, + const Xbyak::Address &src, int nelems) +{ + switch (nelems) { + default: vmovups(dst, src); break; + case 8: vmovups(make_ymm(dst), src); break; + case 4: vmovups(make_xmm(dst), src); break; + case 2: vmovlps(make_xmm(dst), src); break; + case 1: vmovss(make_xmm(dst), src); break; + } +} +void jit_avx512_core_gemm_s8u8s32_kern::c_store(const Xbyak::Address &dst, + const Xbyak::Xmm &src, int nelems) +{ + switch (nelems) { + default: vmovups(dst, src); break; + case 8: vmovups(dst, make_ymm(src)); break; + case 4: vmovups(dst, make_xmm(src)); break; + case 2: vmovsd(dst, make_xmm(src)); break; + case 1: vmovss(dst, make_xmm(src)); break; + } +} + +// Perform length-4 dot product accumulations of unsigned and signed bytes +// in parallel. +// Use vpdpbusd if VNNI available, otherwise emulate. +void jit_avx512_core_gemm_s8u8s32_kern::dot_product(const Xmm &dst, + const Xmm &src1, const Xmm &src2) +{ + if (vnni) + vpdpbusd(dst, src1, src2); + else { + vpmaddubsw(dp_scratch, src1, src2); + vpmaddwd(dp_scratch, ones, dp_scratch); + vpaddd(dst, dst, dp_scratch); + } +} + +// Inner kernel. +void jit_avx512_core_gemm_s8u8s32_kern::kernel_loop(int unroll_m, int unroll_n, + bool cfetch) +{ + int um_vecs = (unroll_m + 15) >> 4; + Label label_kernel_loop; + + L_aligned(label_kernel_loop); { + for (int h = 0; h < 4; h++) { + for (int j = 0; j < unroll_n; j++) { + const Zmm b = b_regs[j & 1]; + + vpbroadcastd(b, ptr[BO + isize * + (2 * j + 2 * h * unroll_n - offset_b)]); + dot_product(c_regs[0][j], b, a_regs[0]); + + if (j == 1 && !(h & 1)) + prefetch_b(ptr[BO + isize * (prefetch_size_b + + 2 * h * unroll_n - offset_b)]); + else if (j % 3 == 0) + prefetch_a(ptr[AO + isize * (prefetch_size_a + + 32 * (j / 3) + 2 * h * unroll_m - offset_a)]); + + for (int i = 1; i < um_vecs; i++) + dot_product(c_regs[i][j], b, a_regs[i]); + + if (cfetch && (j == std::min(1, unroll_n - 1))) { + if (h == 3) + lea(CO2, ptr[CO2 + LDC]); + else if (h < um_vecs) + prefetch_c(ptr[CO2 + (16 * h * size)]); + } + + if (h == 3 && j == std::min(3, unroll_n - 1)) + lea(AA, ptr[AA + (32 * isize)]); + } + + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * + (32 * i + 2 * (h + 1) * unroll_m - offset_a)]); + + if (h == 2) + prefetch_x(ptr[AA - (offset_a * isize)]); + } + + add(AO, 8 * isize * unroll_m); + add(BO, 8 * isize * unroll_n); + sub(LoopCount, 1); + jg(label_kernel_loop, T_NEAR); + } +} + +// k remainder loop for kernel. +void jit_avx512_core_gemm_s8u8s32_kern::remainder_kernel(int unroll_m, + int unroll_n, int unroll_k, int bwidth) +{ + if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N) + || (unroll_m < 0) || (unroll_n < 0)) + return; + + int um_vecs = (unroll_m + 15) >> 4; + + for (int h = 0; h < unroll_k; h++) { + for (int j = 0; j < unroll_n; j++) { + Zmm b = b_regs[j & 1]; + auto b_src = ptr[BO + (-isize * offset_b + + bwidth * (j + h * unroll_n))]; + + switch (bwidth) { + case 4: + vpbroadcastd(b, b_src); + break; + case 2: + vpbroadcastw(b, b_src); + break; + case 1: + vpbroadcastb(b, b_src); + break; + } + for (int i = 0; i < um_vecs; i++) + dot_product(c_regs[i][j], b, a_regs[i]); + } + + if (unroll_k > 1) { + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * (32 * i + + (h + 1) * 2 * unroll_m - offset_a)]); + } + } + + add(AO, unroll_k * unroll_m * bwidth); + add(BO, unroll_k * unroll_n * bwidth); +} + +// Inner loop. +void jit_avx512_core_gemm_s8u8s32_kern::innerloop(int unroll_m, int unroll_n) +{ + if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N) + || (unroll_m < 0) || (unroll_n < 0)) + return; + + int um_vecs = (unroll_m + 15) >> 4; + int stage1 = unroll_n, stage2 = unroll_n; + + Label label_kernel_loop_1, label_k_main_loop_2, label_kernel_loop_2; + Label label_k_main_loop_3, label_kernel_loop_3; + Label label_k_remainder_loop_begin, label_k_rem_4, label_k_rem_2; + Label label_k_rem_1, label_update_begin; + + mov(AO, A); + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * (32 * i - offset_a)]); + + mov(LoopCount, K); + sar(LoopCount, 4); + jle(label_k_remainder_loop_begin, T_NEAR); + + // Main k loops, broken into three parts to time C prefetching. + sub(LoopCount, stage1 + stage2); + jle(label_k_main_loop_2, T_NEAR); + + kernel_loop(unroll_m, unroll_n, false); + + L_aligned(label_k_main_loop_2); + lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]); + add(LoopCount, stage1); + jle(label_k_main_loop_3, T_NEAR); + + kernel_loop(unroll_m, unroll_n, true); + + L_aligned(label_k_main_loop_3); + lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]); + add(LoopCount, stage2); + jle(label_k_remainder_loop_begin, T_NEAR); + + kernel_loop(unroll_m, unroll_n, true); + + // k remainder handling + L_aligned(label_k_remainder_loop_begin); + mov(LoopCount, K); + test(LoopCount, 8); + je(label_k_rem_4, T_NEAR); + + remainder_kernel(unroll_m, unroll_n, 2, 4); + + L_aligned(label_k_rem_4); + mov(LoopCount, K); + test(LoopCount, 4); + je(label_k_rem_2, T_NEAR); + + remainder_kernel(unroll_m, unroll_n, 1, 4); + + L_aligned(label_k_rem_2); + mov(LoopCount, K); + test(LoopCount, 2); + je(label_k_rem_1, T_NEAR); + + Zmm zero = zmm6; + Zmm tmp = zmm5; + + vpxorq(zero, zero, zero); + for (int i = 0; i < um_vecs; i++) { + Zmm a = a_regs[i]; + vbroadcasti64x4(a, ptr[AO + isize * (16 * i - offset_a)]); + vpunpcklwd(tmp, a, zero); + vpunpckhwd(a, a, zero); + vshufi32x4(a, tmp, a, 0x44); + vshufi32x4(a, a, a, 0xD8); + } + + remainder_kernel(unroll_m, unroll_n, 1, 2); + + L_aligned(label_k_rem_1); + mov(LoopCount, K); + test(LoopCount, 1); + je(label_update_begin, T_NEAR); + + vpxorq(zero, zero, zero); + for (int i = 0; i < um_vecs; i++) { + Zmm a = a_regs[i]; + vbroadcasti32x4(a, ptr[AO + isize * (8 * i - offset_a)]); + vpunpcklbw(tmp, a, zero); + vpunpckhbw(a, a, zero); + vinsertf128(make_ymm(a), make_ymm(tmp), make_xmm(a), 1); + vpunpcklwd(tmp, a, zero); + vpunpckhwd(a, a, zero); + vshufi32x4(a, tmp, a, 0x44); + vshufi32x4(a, a, a, 0xD8); + } + + remainder_kernel(unroll_m, unroll_n, 1, 1); + + // Add offsets and update C. + L_aligned(label_update_begin); + + if (enable_offset_r) { + // Add row offsets. + mov(rax, coffset_ry); + for (int j = 0; j < unroll_n; j++) { + Zmm row_offset = zmm0; + + vbroadcastss(row_offset, ptr[rax + size * j]); + + for (int i = 0; i < um_vecs; i++) + vpaddd(c_regs[i][j], c_regs[i][j], row_offset); + } + add(coffset_ry, size * unroll_n); + } + + if (enable_offset_c) { + // Add column offsets. + mov(rax, coffset_cy); + for (int i = 0; i < um_vecs; i++) { + Zmm col_offset = zmm0; + + c_load(col_offset, ptr[rax + size * 16 * i], unroll_m); + + for (int j = 0; j < unroll_n; j++) + vpaddd(c_regs[i][j], c_regs[i][j], col_offset); + } + } + + Reg64 LDC3 = rax; + lea(LDC3, ptr[LDC + LDC * 2]); + + // C updates. + int c_off_j = 0; + for (int j = 0; j < unroll_n; j++) { + if (j > 0 && (j & 3) == 0) { + lea(CO1, ptr[CO1 + LDC * 4]); + c_off_j += 4; + } + + int jj = j - c_off_j; + + for (int i = 0; i < um_vecs; i++) { + Zmm c = c_regs[i][j]; + Zmm c_old = zmm0; + decltype(LDC * jj) ldc_mult = (jj == 3) ? LDC3 : LDC * jj; + + auto c_mem = ptr[CO1 + ldc_mult + size * 16 * i]; + + if (beta_zero) + c_store(c_mem, c, unroll_m); + else { + c_load(c_old, c_mem, unroll_m); + vpaddd(c_old, c, c_old); + c_store(c_mem, c_old, unroll_m); + } + + vpxorq(c, c, c); + } + } + + lea(CO1, ptr[CO1 + LDC * (unroll_n - c_off_j)]); +} + +// Outer loop. +void jit_avx512_core_gemm_s8u8s32_kern::outerloop(int unroll_x, int unroll_y, + Label *&cur_outerloop_label) +{ + Label label_m_loop, label_n_loop, label_n_remainder_loops[6]; + + L(*cur_outerloop_label); + cur_outerloop_label++; + if (unroll_x >= IGEMM_UNROLL_M) { + mov(J, M); + cmp(J, unroll_x); + jl(*cur_outerloop_label, T_NEAR); // Jump to next outerloop label. + } else { + test(J, unroll_x); + jle(*cur_outerloop_label, T_NEAR); + } + + L_aligned(label_m_loop); { + mov(CO1, C); + add(C, unroll_x * size); + + mov(BO, B); + + mov(AA, K); + imul(AA, AA, unroll_x * isize); + lea(AA, ptr[A + AA + isize * prefetch_size_a]); + + if (enable_offset_c) { + mov(rax, coffset_cx); + mov(coffset_cy, rax); + add(rax, unroll_x * size); + mov(coffset_cx, rax); + } + + if (enable_offset_r) { + mov(rax, coffset_rx); + mov(coffset_ry, rax); + } + + mov(I, N); + cmp(I, unroll_y); + jl(label_n_remainder_loops[0], T_NEAR); + + L_aligned(label_n_loop); { + innerloop(unroll_x, unroll_y); + sub(I, unroll_y); + cmp(I, unroll_y); + jge(label_n_loop, T_NEAR); + } + + align(16); + + int label_idx = 0; + for (int uy = 16; uy > 0; uy >>= 1) { + L(label_n_remainder_loops[label_idx++]); + if (unroll_y > uy) { + test(I, uy); + jle(label_n_remainder_loops[label_idx], T_NEAR); + + innerloop(unroll_x, uy); + align(16); + } + } + L(label_n_remainder_loops[label_idx]); + + mov(A, AO); + if (unroll_x >= IGEMM_UNROLL_M) { + sub(J, unroll_x); + cmp(J, unroll_x); + jge(label_m_loop); + } + } + + align(16); +} + +void jit_avx512_core_gemm_s8u8s32_kern::generate() +{ + // Prologue + preamble(); + sub(rsp, stack_alloc_size); + + if (is_windows) { + mov(A, arg_a); + mov(B, arg_b); + } + + mov(C, arg_c); + mov(LDC, arg_ldc); + + sub(A, -offset_a * isize); + sub(B, -offset_b * isize); + + mov(M, qword[M]); + mov(N, qword[N]); + mov(K, qword[K]); + + lea(LDC, ptr[LDC * size]); + + if (enable_offset_c) { + mov(rax, arg_coffset_c); + mov(coffset_cx, rax); + } + if (enable_offset_r) { + mov(rax, arg_coffset_r); + mov(coffset_rx, rax); + } + + for (int i = 0; i < (max_unroll_m >> 4); i++) { + for (int j = 0; j < max_unroll_n; j++) { + auto &c = c_regs[i][j]; + vpxorq(c, c, c); + } + } + + if (!vnni) { + mov(rax, 1); + movq(make_xmm(ones), rax); + vpbroadcastw(ones, make_xmm(ones)); + } + + Label outerloop_labels[8]; + Label *cur_outerloop_label = &outerloop_labels[0]; + + // Main m loop. + outerloop(IGEMM_UNROLL_M, IGEMM_UNROLL_N, cur_outerloop_label); + + // m remainder loops. + for (int um = 32; um > 0; um >>= 1) + if (IGEMM_UNROLL_M > um) + outerloop(um, IGEMM_UNROLL_N, cur_outerloop_label); + + L(*cur_outerloop_label); + + // Epilogue. + add(rsp, stack_alloc_size); + postamble(); +} + + +jit_avx512_core_gemm_s8u8s32_kern::jit_avx512_core_gemm_s8u8s32_kern(bool + beta_zero_, bool enable_offset_c_, bool enable_offset_r_) : + jit_generator(nullptr, 100000), arg_a(0), arg_b(0), arg_c(0), arg_ldc(0), + arg_coffset_c(0), arg_coffset_r(0), coffset_cx(0), coffset_cy(0), + coffset_rx(0), coffset_ry(0) +{ + beta_zero = beta_zero_; + enable_offset_c = enable_offset_c_; + enable_offset_r = enable_offset_r_; + vnni = mayiuse(avx512_core_vnni); + + // Assign integer registers + M = is_windows ? rcx : rdi; + N = is_windows ? rdx : rsi; + K = is_windows ? r8 : rdx; + A = is_windows ? rsi : r8; + B = r9; + C = r10; + LDC = r11; + I = r12; + J = r13; + LoopCount = rax; + AO = r14; + BO = r15; + CO1 = rbx; + CO2 = rbp; + AA = is_windows ? rdi : rcx; + + // Assign vector registers + dp_scratch = zmm6; + ones = zmm7; + for (int i = 0; i < (max_unroll_m >> 4); i++) + a_regs[i] = Zmm(i); + b_regs[0] = zmm4; + b_regs[1] = zmm5; + + int rn = 0; + for (int i = 0; i < (max_unroll_m >> 4); i++) + for (int j = 0; j < max_unroll_n; j++) + c_regs[i][j] = Zmm(8 + rn++); + + // Assign stack variables. + stack_alloc_size = 32; + auto args_offset = stack_alloc_size + get_size_of_abi_save_regs() + + 8 + (is_windows ? 48 : 0); + + arg_a = ptr[rsp + (args_offset - 16)]; + arg_b = ptr[rsp + (args_offset - 8)]; + arg_c = ptr[rsp + (args_offset + 0)]; + arg_ldc = ptr[rsp + (args_offset + 8)]; + arg_coffset_c = ptr[rsp + (args_offset + 16)]; + arg_coffset_r = ptr[rsp + (args_offset + 24)]; + + coffset_cx = qword[rsp + 0]; + coffset_cy = qword[rsp + 8]; + coffset_rx = qword[rsp + 16]; + coffset_ry = qword[rsp + 24]; + + generate(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp new file mode 100644 index 0000000000..e8efcc1cc8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp @@ -0,0 +1,101 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef IGEMM_KERNEL_GENERATOR_HPP +#define IGEMM_KERNEL_GENERATOR_HPP + +#include "jit_generator.hpp" + + +namespace mkldnn { +namespace impl { +namespace cpu { + +class jit_avx512_core_gemm_s8u8s32_kern : public jit_generator { +public: + jit_avx512_core_gemm_s8u8s32_kern(bool beta_zero_, bool enable_offset_c_, + bool enable_offset_r_); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemm_s8u8s32_kern); + +protected: + bool beta_zero; + bool enable_offset_c, enable_offset_r; + bool vnni; + + void prefetch_a(const Xbyak::Address &src) { + prefetcht0(src); + } + void prefetch_b(const Xbyak::Address &src) { + prefetcht0(src); + } + void prefetch_c(const Xbyak::Address &src) { + prefetchw(src); + } + void prefetch_x(const Xbyak::Address &src) { + prefetcht0(src); + } + + void c_load(const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems); + void c_store(const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems); + + void dot_product(const Xbyak::Xmm &dst, const Xbyak::Xmm &src1, + const Xbyak::Xmm &src2); + void kernel_loop(int unroll_m, int unroll_n, bool cfetch); + void remainder_kernel(int unroll_m, int unroll_n, int unroll_k, int bwidth); + void innerloop(int unroll_m, int unroll_n); + void outerloop(int unroll_x, int unroll_y, Xbyak::Label *&outerloop_label); + + void generate(); + + +private: + static const int IGEMM_UNROLL_M = 48; + static const int IGEMM_UNROLL_N = 8; + + static const int isize = 2; + static const int size = 4; + + // Prefetch configuration + static const int prefetch_size_a = 32 * 5; + static const int prefetch_size_b = 32 * 4; + + static const int offset_a = 256, offset_b = 256; + static const int max_unroll_m = 48, max_unroll_n = 8; + + // Integer register assignments + Xbyak::Reg64 M, N, K, A, B, C, LDC, I, J, LoopCount; + Xbyak::Reg64 AO, BO, CO1, CO2, AA; + + // Vector register assignments + Xbyak::Zmm dp_scratch, ones, a_regs[max_unroll_m >> 4], b_regs[2]; + Xbyak::Zmm c_regs[max_unroll_m >> 4][max_unroll_n]; + + // Stack variable assignments + int stack_alloc_size; + Xbyak::Address arg_a, arg_b, arg_c, arg_ldc, arg_coffset_c, arg_coffset_r; + Xbyak::Address coffset_cx, coffset_cy, coffset_rx, coffset_ry; + + void L_aligned(Xbyak::Label &label, int alignment = 16) { + align(alignment); + L(label); + } +}; + +} +} +} + +#endif /* header guard */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp new file mode 100644 index 0000000000..4f0b10dadd --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp @@ -0,0 +1,290 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "gemv.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg) { + + blas_t arg_gemv = *arg; + + if ((arg -> offsetc == FIX_OFFSET) && // Fix offset + (arg -> ao == 0) && + (arg -> bo == 0) && + (arg -> co[0] == 0) && + (*(arg -> alpha) == 1.0f) && + ((*(arg -> beta) == 1.0f) || *(arg -> beta) == 0.0f)) { + + if (arg -> n == 1) { + + if (arg -> transa == 1) { // A transpose + arg_gemv.n = arg -> k; + arg_gemv.ldc = 1; + arg_gemv.swap = 0; + if (arg -> transb == 0) { // B non transpose + arg_gemv.ldb = 1; + } + // B transpose arg_gemv.ldb = arg -> ldb + gemv_threading_driver(&arg_gemv); + return 1; + } + } + + if (arg -> m == 1) { + + if (arg -> transb == 0) { // B non transpose + arg_gemv.transa = 1; + arg_gemv.m = arg -> n; + arg_gemv.n = arg -> k; + arg_gemv.a = (int8_t *) arg -> b; + arg_gemv.lda = arg -> ldb; + arg_gemv.b = (uint8_t *) arg -> a; + arg_gemv.swap = 1; + if (arg -> transa == 0) { // A non transpose + arg_gemv.ldb = arg -> lda; + } + else { // A transpose + arg_gemv.ldb = 1; + } + gemv_threading_driver(&arg_gemv); + return 1; + } + } + } + + return 0; +} + + +int gemv_kernel_driver(blas_t *arg) { + + dim_t m = arg -> m; + dim_t n = arg -> n; + uint8_t *a = (uint8_t *) arg -> a; + dim_t lda = arg -> lda; + int8_t *b = (int8_t *) arg -> b; + float beta = *(arg -> beta); + + if (arg -> swap) { + arg -> gemv_u8s8s32_kernel(m, n, 1.0f, a, lda, b, beta, arg -> c); + } + else { + arg -> gemv_s8u8s32_kernel(arg -> m, arg -> n, 1.0f, arg -> a, + arg -> lda, arg -> b, *(arg -> beta), arg -> c); + } + + return 0; +} + +int gemv_threading_driver(blas_t *arg) { + + dim_t nthr_m, nthr_n = 1; + dim_t MB, NB, UM = 16, UN = 64; + dim_t BLOCKM = 192, BLOCKN = 3072; + int status; + dim_t i; + + dim_t nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + + uint8_t *new_x = NULL; + int32_t *tmp_y = NULL, *new_y = NULL; + + dim_t m = arg -> m, n = arg -> n; + + blas_t arg_seq = *arg; + float zero = 0.0f; + + nthr_m = std::min(std::max(m / BLOCKM, (dim_t) 1), nthr); + MB = m / nthr_m; + MB = (((MB / UM) * UM) == MB) ? MB : (MB / UM) * UM + UM; + nthr_m = (((m / MB) * MB) == m) ? m / MB : m / MB + 1; + nthr_m = std::min(std::max(nthr_m, (dim_t) 1), nthr); + + while ((nthr_m * (nthr_n + 1) <= nthr) && ((n / (nthr_n + 1)) >= BLOCKN)) { + nthr_n++; + } + + NB = n / nthr_n; + NB = (((NB / UN) * UN) == NB) ? NB : (NB / UN) * UN + UN; + nthr_n = (((n / NB) * NB) == n) ? n / NB : n / NB + 1; + nthr_n = std::min(std::max(nthr_n, (dim_t) 1), nthr / nthr_m); + + nthr = nthr_m * nthr_n; + + if (arg -> ldb != 1) { + new_x = (uint8_t *)malloc(n, 64); + if (new_x == NULL) + return 1; + for (i = 0; i < n; i++) { + new_x[i] = (arg -> b)[i * arg -> ldb]; + } + arg_seq.b = new_x; + arg_seq.ldb = 1; + } + else new_x = (uint8_t *) arg -> b; + + if (arg -> ldc != 1) { + new_y = (int32_t *) malloc(nthr_m * PADD_BYTESIZE_ONPAGE(MB, sizeof(int32_t)), 64); + if (new_y == NULL) { + if (arg -> ldb != 1) { + free(new_x); + } + return 1; + } + } + + // GEMV computation + if (nthr == 1) { + + if (arg -> ldc != 1) { + if (*(arg -> beta) != 0.0f) { + for (i = 0; i < m; i++) { + new_y[i] = arg -> c[i * arg -> ldc]; + } + } + } + + status = gemv_kernel_driver(&arg_seq); + + if (arg -> ldc != 1) { + for (i = 0; i < m; i++) { + arg -> c[i * arg -> ldc] = new_y[i]; + } + } + + if (arg -> ldb != 1) { + free(new_x); + } + if (arg -> ldc != 1) { + free(new_y); + } + return status; + } + + if (nthr_n > 1) { + tmp_y = (int32_t *) malloc((nthr_n - 1) * PADD_BYTESIZE_ONPAGE(m, sizeof(int32_t)), PAGESIZE); + if (tmp_y == NULL) { + if (arg -> ldb != 1) { + free(new_x); + } + return 1; + } + } + + parallel_nd((int) nthr, [&](const dim_t ithr) { + + dim_t m_from, m_to, myM; + dim_t n_from, n_to, myN; + + dim_t n_id, m_id; + dim_t loc_incy = 1; + int32_t *loc_y; + + blas_t arg_loc = arg_seq; + int j; + + m_id = ithr / nthr_n; + n_id = ithr % nthr_n; + + m_from = MB * m_id; + m_to = MB * (m_id + 1); + if ((m_to > m) || (m_id == nthr_m - 1)) + m_to = m; + + myM = m_to - m_from; + + n_from = NB * n_id; + n_to = NB * (n_id + 1); + if ((n_to > n) || (n_id == nthr_n - 1)) + n_to = n; + + myN = n_to - n_from; + + if (n_id != 0) { + arg_loc.beta = &zero; + loc_y = tmp_y + (NEXT_THR_STRIDE(m, sizeof(int32_t))) * (n_id - 1) + m_from; + } + else { + if (arg -> ldc == 1) { + loc_y = arg_seq.c + m_from; + } + else { + // need to copy the block of c in new_y + loc_y = new_y + m_id * NEXT_THR_STRIDE(MB, sizeof(int32_t)); + if (*(arg -> beta) != 0.0f) { + for (j = 0; j < myM; j++) { + loc_y[j] = arg -> c[(m_from + j) * arg -> ldc]; + } + } + } + } + + arg_loc.m = myM; + arg_loc.n = myN; + arg_loc.a = arg_seq.a + m_from * arg_seq.lda + n_from; + arg_loc.b = arg_seq.b + n_from; + arg_loc.c = loc_y; + arg_loc.ldc = loc_incy; + + gemv_kernel_driver(&arg_loc); + + if ((n_id == 0) && (arg -> ldc != 1)) { + for (j = 0; j < myM; j++) { + arg -> c[(m_from + j) * arg -> ldc] = loc_y[j]; + } + } + + }); + + if (nthr_n > 1) { + parallel_nd((int) nthr_m, [&](const dim_t ithr) { + + dim_t j, j_from, j_to, ii; + int32_t acc; + + j_from = MB * ithr; + j_to = MB * (ithr + 1); + if ((j_to > m) || (ithr == nthr - 1)) + j_to = m; + + for (j = j_from; j < j_to; j++) { + acc = 0; + for (ii = 0; ii < nthr_n - 1; ii++) { + acc += tmp_y[ii * NEXT_THR_STRIDE(m, sizeof(int32_t)) + j]; + } + (arg -> c)[j * arg -> ldc] += acc; + } + }); + free(tmp_y); + } + + if (arg -> ldb != 1) { + free(new_x); + } + + if (arg -> ldc != 1) { + free(new_y); + } + + return 0; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp new file mode 100644 index 0000000000..c57a8c1d12 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp @@ -0,0 +1,411 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp" + +#ifdef _WIN32 +#define is_windows 1 +#else +#define is_windows 0 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +void jit_avx512_core_gemv_s8u8s32_kern::vnni(Xbyak::Zmm acc, Xbyak::Zmm b, + Xbyak::Zmm a, Xbyak::Zmm tmp, + Xbyak::Zmm one, bool swap, + int use_vnni) { + + if (use_vnni) { + if (swap) + vpdpbusd(acc, a, b); + else + vpdpbusd(acc, b, a); + } + + else { + if (swap) + vpmaddubsw(tmp, a, b); + else + vpmaddubsw(tmp, b, a); + vpmaddwd(tmp, tmp, one); + vpaddd(acc, tmp, acc); + } + +} + +void jit_avx512_core_gemv_s8u8s32_kern::n_loop_body(int start_a_idx, int start_acc_idx, + int b_idx, int nreg_acc, + Xbyak::Reg64 A, Xbyak::Reg64 lda, + Xbyak::Reg64 X, Xbyak::Zmm tmp, + Xbyak::Zmm one, bool swap, int use_vnni, + int use_mask, Xbyak::Opmask mask_n) { + + int i; + int nreg_A = nreg_acc / 2 + (nreg_acc % 2); + + // load X + j + if (use_mask) + vmovdqu8(Xbyak::Zmm(b_idx) | mask_n | T_z, ptr[X]); + else + vmovdqu8(Xbyak::Zmm(b_idx), ptr[X]); + + xor_(r14, r14); + // load values of A + for (i = 0; i < nreg_A; i++) { + if (use_mask) + vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]); + else + vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]); + add(r14, lda); + } + + for (i = 0; i < nreg_A; i++) { + // vnni (acc, b, a, tmp, one, swap, use_vnni) + vnni(Xbyak::Zmm(start_acc_idx + i), Xbyak::Zmm(b_idx), + Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni); + } + + for (i = 0; i < nreg_A - (nreg_acc % 2); i++) { + if (use_mask) + vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]); + else + vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]); + add(r14, lda); + } + + for (i = 0; i < nreg_A - (nreg_acc % 2); i++) { + vnni(Xbyak::Zmm(start_acc_idx + i + nreg_A), Xbyak::Zmm(b_idx), + Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni); + } + +} + +void jit_avx512_core_gemv_s8u8s32_kern::shuffle_and_add(Xbyak::Zmm dest, Xbyak::Zmm A, + Xbyak::Zmm B, Xbyak::Zmm C, + Xbyak::Zmm D) { + + vshufi32x4(dest, A, C, 0x44); + vshufi32x4(A, A, C, 0xEE); + vpaddd(C, dest, A); // C = A0 + A2|A1 + A3|C0 + C2|C1 + C3 + + vshufi32x4(dest, B, D, 0x44); + vshufi32x4(B, B, D, 0xEE); + vpaddd(D, dest, B); // D = B0 + B2|B1 + B3|D0 + D2|D1 + D3 + + vshufi32x4(A, C, D, 0x88); + vshufi32x4(B, C, D, 0xDD); + vpaddd(dest, A, B); // dest = SAi|SBi|SCi|SDi + +} + +void jit_avx512_core_gemv_s8u8s32_kern::update_c(int nreg_acc, Xbyak::Reg64 Y, + int start_a_idx, int start_acc_idx, + Xbyak::Xmm beta, int use_mask, + Xbyak::Opmask mask_m) { + + int l, i, k, j, last_it; + Xbyak::Label store_label; + + l = 0; + for (k = 0; k < nreg_acc; k += 8) { + for (i = 0, j = k; i < 8; i += 4, j += 2) { + if (j < nreg_acc) { + // shuffle per block of 4 registers + shuffle_and_add(Xbyak::Zmm(start_a_idx + l), // dest + Xbyak::Zmm(start_acc_idx + j), // A = acc0 + Xbyak::Zmm(start_acc_idx + 1 + j), // B = acc1 + Xbyak::Zmm(start_acc_idx + 4 + j), // C = acc4 + Xbyak::Zmm(start_acc_idx + 5 + j)); // D = acc5 + + // extract low and high from dest and hadd + vextracti32x8(Xbyak::Ymm(start_a_idx + l + 1), Xbyak::Zmm(start_a_idx + l), 0); + vextracti32x8(Xbyak::Ymm(start_a_idx + l + 2), Xbyak::Zmm(start_a_idx + l), 1); + vphaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l + 1), + Xbyak::Ymm(start_a_idx + l + 2)); + } + l++; + } + + vphaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l - 2), + Xbyak::Ymm(start_a_idx + l - 1)); + + l++; + } + + // eventually add with C and store new value + vxorps(Xbyak::Ymm(start_a_idx), + Xbyak::Ymm(start_a_idx), + Xbyak::Ymm(start_a_idx)); + vucomiss(beta, Xbyak::Ymm(start_a_idx)); + je(store_label, T_NEAR); + + // beta = 1 + for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) { + // load Y and add + last_it = (k + 8) > nreg_acc; + if (use_mask && last_it) + vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8) | mask_m | T_z, ptr[Y + (k / 8) * 32]); + else + vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8), ptr[Y + (k / 8) * 32]); + + vpaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + k / 8)); + } + + // store + aligned_label(store_label); + for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) { + last_it = (k + 8) > nreg_acc; + if (use_mask && last_it) + vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l) | mask_m); + else + vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l)); + } + +} + +template +T jit_avx512_core_gemv_s8u8s32_kern::generate(int use_vnni) { + + Xbyak::Opmask mask_n = k1, mask_m = k2; + Xbyak::Label one_label, m_tail_label, m_loop_label, n_loop_label; + Xbyak::Label n_tail_label, update_c_label, end_label; + constexpr unsigned int n_labels = (1 << unroll_m) - 1; + Xbyak::Label m_tail_label_case[n_labels]; + Xbyak::Label n_loop_label_case[n_labels]; + Xbyak::Label n_tail_label_case[n_labels]; + Xbyak::Label update_c_label_case[n_labels]; + + int i, ii; + + Xbyak::Zmm one, tmp; + Xbyak::Reg64 n = abi_param2, m = abi_param1; + Xbyak::Reg64 A = is_windows ? abi_param4 : abi_param3; + Xbyak::Reg64 lda = is_windows ? abi_param3 : abi_param4; + Xbyak::Reg64 X = is_windows ? rdi : r8; + Xbyak::Xmm beta = xmm1; + Xbyak::Reg64 Y = is_windows ? rsi : r9; + + bool swap = !std::is_same::value; + + // Windows: read on the stack lda, X, beta, Y + + int zmm_idx = 1; + int nreg_acc = 1 << unroll_m; + int nreg_A = 1 << (unroll_m - 1); + int nreg_A_acc = nreg_acc + nreg_A; + + if (!use_vnni) { + // set a zmm register to one + tmp = Xbyak::Zmm(0); + one = Xbyak::Zmm(zmm_idx + 1); + zmm_idx += 2; // one + tmp + } + else { + beta = xmm0; + } + + preamble(); + + if (is_windows) { + mov(lda, ptr[rsp + get_size_of_abi_save_regs() + 40]); + mov(X, ptr[rsp + get_size_of_abi_save_regs() + 48]); + movss(beta, ptr[rsp + get_size_of_abi_save_regs() + 56]); + mov(Y, ptr[rsp + get_size_of_abi_save_regs() + 64]); + } + + if (use_vnni && !is_windows) { + movaps(beta, xmm1); + } + + mov(rax, (1 << unroll_n) - 1); + kmovq(k3, rax); + + and_(rax, n); // rax contains n & ((1 << unroll_n) - 1) + mov(rbx, 1); + shlx(rbx, rbx, rax); + sub(rbx, 1); + kmovq(mask_n, rbx); + // mask_n set (AVX512 only), can use rax and rbx again + + // set mask_m for update of the C matrix + // load/store on the C matrix use Ymm so tail according to Ymm size + mov(rax, 7); // 8 * 32 = 256 Ymm size + and_(rax, m); // rax contains m & 7 + mov(rbx, 1); + shlx(rbx, rbx, rax); + sub(rbx, 1); + kmovq(mask_m, rbx); + // mask_m set (AVX512 only), can use rax and rbx again + + // setup register of ones when VNNI instructions not available + if (!use_vnni) { + vmovdqu16(one, ptr[rip + one_label]); + } + + // M loop + // base pointer for A rax contains a + i * lda + // Loop stop when rax >= a + (m & mask_um) * lda = rbx + // loop increment r10 = um * lda + // rbp = Y + i + mov(rax, A); // i = 0 + mov(rbx, m); + and_(rbx, mask_um); + imul(rbx, lda); + add(rbx, A); + mov(r10, lda); + sal(r10, unroll_m); + mov(rbp, Y); + + // N loop + // base pointer for X r11 contains x + j + // Loop stop when r11 >= x + n & mask_un = r12 + // loop increment un + // r13 = rax + j = A + i * lda + j + mov(r12, n); + and_(r12, mask_un); + add(r12, X); + + // M loop + aligned_label(m_loop_label); + cmp(rax, rbx); + jge(m_tail_label, T_NEAR); + + // enter M loop + for(i = 0; i < nreg_acc; i++) { + vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A)); + } + + // N loop + mov(r11, X); // j = 0 + mov(r13, rax); + aligned_label(n_loop_label); + cmp(r11, r12); + jge(n_tail_label, T_NEAR); + + // enter N loop + + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc, + r13, lda, r11, tmp, one, swap, use_vnni, 0, mask_n); + + // increment rax with un + add(r11, 1 << unroll_n); + add(r13, 1 << unroll_n); + jmp(n_loop_label, T_NEAR); + // end N loop + + // N tail + aligned_label(n_tail_label); + + ktestq(mask_n, k3); + je(update_c_label, T_NEAR); + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc, + r13, lda, r11, tmp, one, swap, use_vnni, 1, mask_n); + + // update C matrix + aligned_label(update_c_label); + + update_c(nreg_acc, rbp, zmm_idx, zmm_idx + nreg_A, beta, 0, mask_m); + + // increment rax with um * lda + add(rax, r10); + add(rbp, 1 << (unroll_m + 2)); + jmp(m_loop_label, T_NEAR); + // end M loop + + // M tail + aligned_label(m_tail_label); + + // r10 will contain m_tail = m % unroll_m = m & (1 << unroll_m) - 1 + mov(r10, m); + and_(r10, (1 << unroll_m) - 1); + for (ii = 1; ii < 1 << unroll_m; ii++) { + aligned_label(m_tail_label_case[ii-1]); + cmp(r10, ii); + if (ii == (1 << unroll_m) - 1) + jne(end_label, T_NEAR); + else + jne(m_tail_label_case[ii], T_NEAR); + + // m_tail = i, use i accumulators + + for(i = 0; i < ii; i++) { + vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A)); + } + + // N loop + mov(r11, X); // j = 0 + mov(r13, rax); + aligned_label(n_loop_label_case[ii - 1]); + cmp(r11, r12); + jge(n_tail_label_case[ii - 1], T_NEAR); + + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13, + lda, r11, tmp, one, swap, use_vnni, 0, mask_n); + + // increment rax with un + add(r11, 1 << unroll_n); + add(r13, 1 << unroll_n); + jmp(n_loop_label_case[ii - 1], T_NEAR); + // end N loop + + // N tail + aligned_label(n_tail_label_case[ii - 1]); + ktestq(mask_n, k3); + je(update_c_label_case[ii - 1], T_NEAR); + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13, + lda, r11, tmp, one, swap, use_vnni, 1, mask_n); + + // update C matrix + aligned_label(update_c_label_case[ii - 1]); + update_c(ii, rbp, zmm_idx, zmm_idx + nreg_A, beta, 1, mask_m); + + if (ii < ((1 << unroll_m) - 1)) + jmp(end_label, T_NEAR); + } + + aligned_label(end_label); + + postamble(); + + if (!use_vnni) { + aligned_label(one_label); + for (i = 0; i < size_vec_reg/8; i++) + dq(0x0001000100010001); + } + + return (T) getCode(); +} + +template jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t +jit_avx512_core_gemv_s8u8s32_kern::generate(int); + +template jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t +jit_avx512_core_gemv_s8u8s32_kern::generate(int); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp new file mode 100644 index 0000000000..9ea23a5f56 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp @@ -0,0 +1,64 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +class jit_avx512_core_gemv_s8u8s32_kern : jit_generator { + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_s8u8s32_kern); + + // assumes untoll_{m,n} are a power of 2 + static constexpr unsigned int unroll_m = 4; // real unrolling factor is 2^unroll_m + const int mask_um = 0xFFFFFFF0; + static constexpr unsigned int unroll_n = 6; // real unrolling factor is 2^unroll_n + const int mask_un = 0xFFFFFFC0; + const int size_vec_reg = 64; // bytes + + void aligned_label(Xbyak::Label &label, int alignment = 16) { + align(alignment); + L(label); + } + + void vnni(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, bool, int); + void n_loop_body(int, int, int, int, Xbyak::Reg64, Xbyak::Reg64, + Xbyak::Reg64, Xbyak::Zmm, Xbyak::Zmm, bool, int, int, Xbyak::Opmask); + void shuffle_and_add(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm); + void update_c(int, Xbyak::Reg64, int, int, Xbyak::Xmm, int, Xbyak::Opmask); + +public: + jit_avx512_core_gemv_s8u8s32_kern() : jit_generator(nullptr, GEMM_CODE_SIZE) {}; + + // m, n, alpha, a, lda, x, beta, y + typedef void (*gemv_s8u8s32_kernel_t)(const dim_t, const dim_t, const float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + typedef void (*gemv_u8s8s32_kernel_t)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + template + T generate(int use_vnni); + +}; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp new file mode 100644 index 0000000000..544cd2ff25 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp @@ -0,0 +1,819 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_an_kern::jit_avx512_core_u8_copy_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l170; +Xbyak::Label l1f0; +Xbyak::Label l20; +Xbyak::Label l224; +Xbyak::Label l234; +Xbyak::Label l240; +Xbyak::Label l254; +Xbyak::Label l32c; +Xbyak::Label l34; +Xbyak::Label l388; +Xbyak::Label l3b0; +Xbyak::Label l3c0; +Xbyak::Label l3cc; +Xbyak::Label l3dc; +Xbyak::Label l454; +Xbyak::Label l48c; +Xbyak::Label l4a8; +Xbyak::Label l4b8; +Xbyak::Label l4c4; +Xbyak::Label l4d8; +Xbyak::Label l570; +Xbyak::Label l5c4; +Xbyak::Label l5f0; +Xbyak::Label l60c; +Xbyak::Label l61c; +Xbyak::Label l628; +Xbyak::Label l638; +Xbyak::Label l6b0; +Xbyak::Label l6f4; +Xbyak::Label l720; +Xbyak::Label l73c; +Xbyak::Label l74c; +Xbyak::Label l758; +Xbyak::Label l76c; +Xbyak::Label l804; +Xbyak::Label l858; +Xbyak::Label l88c; +Xbyak::Label l8a4; +Xbyak::Label l8b2; +Xbyak::Label l8bc; +Xbyak::Label l8cc; +Xbyak::Label l944; +Xbyak::Label l98c; +Xbyak::Label l9b0; +Xbyak::Label l9c8; +Xbyak::Label l9d8; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x30); + jl(l234, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x30); + mov(I, M); + sar(I, 0x2); + jle(l170, T_NEAR); + align(4); + +L(l34); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B-0x20], xmm4); + movdqu(xword[B-0x10], xmm2); + movdqu(xmm0, xword[A1-0x60]); + movdqu(xmm1, xword[A1+LDA*1-0x60]); + movdqu(xmm2, xword[A1+LDA*2-0x60]); + movdqu(xmm3, xword[A1+LDA3*1-0x60]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B], xmm0); + movdqu(xword[B+0x10], xmm1); + movdqu(xword[B+0x20], xmm4); + movdqu(xword[B+0x30], xmm2); + sub(B, -192); + dec(I); + jg(l34, T_NEAR); + align(4); + +L(l170); + test(M, 0x2); + jle(l1f0, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + movdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + movdqu(xmm4, xword[A1-0x70]); + movdqu(xmm5, xword[A1-0x60]); + add(A1, LDA); + movdqa(xmm6, xmm0); + punpcklbw(xmm0, xmm3); + punpckhbw(xmm6, xmm3); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm6); + movdqa(xmm6, xmm1); + punpcklbw(xmm1, xmm4); + punpckhbw(xmm6, xmm4); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x50], xmm6); + movdqa(xmm6, xmm2); + punpcklbw(xmm2, xmm5); + punpckhbw(xmm6, xmm5); + movdqu(xword[B-0x40], xmm2); + movdqu(xword[B-0x30], xmm6); + sub(B, -96); + align(4); + +L(l1f0); + test(M, 0x1); + jle(l224, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + movdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm2); + sub(B, -48); + align(4); + +L(l224); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + align(4); + +L(l234); + cmp(N, 0x20); + jl(l3c0, T_NEAR); + align(4); + +L(l240); + mov(A1, A); + add(A, 0x20); + mov(I, M); + sar(I, 0x2); + jle(l32c, T_NEAR); + align(4); + +L(l254); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B-0x20], xmm4); + movdqu(xword[B-0x10], xmm2); + sub(B, -128); + dec(I); + jg(l254, T_NEAR); + align(4); + +L(l32c); + test(M, 0x2); + jle(l388, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + movdqu(xmm3, xword[A1-0x70]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm2); + punpckhbw(xmm4, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm4); + movdqa(xmm4, xmm1); + punpcklbw(xmm1, xmm3); + punpckhbw(xmm4, xmm3); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x50], xmm4); + sub(B, -64); + align(4); + +L(l388); + test(M, 0x1); + jle(l3b0, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3b0); + sub(N, 0x20); + cmp(N, 0x20); + jge(l240, T_NEAR); + align(4); + +L(l3c0); + cmp(N, 0x10); + jl(l4b8, T_NEAR); + align(4); + +L(l3cc); + mov(A1, A); + add(A, 0x10); + mov(I, M); + sar(I, 0x2); + jle(l454, T_NEAR); + align(4); + +L(l3dc); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm1, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm1, xmm3); + movdqa(xmm3, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm3, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm1); + punpckhwd(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm3); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + sub(B, -64); + dec(I); + jg(l3dc, T_NEAR); + align(4); + +L(l454); + test(M, 0x2); + jle(l48c, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm2, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + align(4); + +L(l48c); + test(M, 0x1); + jle(l4a8, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l4a8); + sub(N, 0x10); + cmp(N, 0x10); + jge(l3cc, T_NEAR); + align(4); + +L(l4b8); + cmp(N, 0x8); + jl(l61c, T_NEAR); + align(4); + +L(l4c4); + mov(A1, A); + add(A, 0x8); + mov(I, M); + sar(I, 0x3); + jle(l570, T_NEAR); + align(4); + +L(l4d8); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(l4d8, T_NEAR); + align(4); + +L(l570); + test(M, 0x4); + jle(l5c4, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l5c4); + test(M, 0x2); + jle(l5f0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l5f0); + test(M, 0x1); + jle(l60c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l60c); + sub(N, 0x8); + cmp(N, 0x8); + jge(l4c4, T_NEAR); + align(4); + +L(l61c); + cmp(N, 0x4); + jl(l74c, T_NEAR); + align(4); + +L(l628); + mov(A1, A); + add(A, 0x4); + mov(I, M); + sar(I, 0x3); + jle(l6b0, T_NEAR); + align(4); + +L(l638); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l638, T_NEAR); + align(4); + +L(l6b0); + test(M, 0x4); + jle(l6f4, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l6f4); + test(M, 0x2); + jle(l720, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l720); + test(M, 0x1); + jle(l73c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l73c); + sub(N, 0x4); + cmp(N, 0x4); + jge(l628, T_NEAR); + align(4); + +L(l74c); + cmp(N, 0x2); + jl(l8b2, T_NEAR); + align(4); + +L(l758); + mov(A1, A); + add(A, 0x2); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l804, T_NEAR); + align(4); + +L(l76c); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l76c, T_NEAR); + align(4); + +L(l804); + test(M, 0x4); + jle(l858, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l858); + test(M, 0x2); + jle(l88c, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l88c); + test(M, 0x1); + jle(l8a4, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l8a4); + sub(N, 0x2); + cmp(N, 0x2); + jge(l758, T_NEAR); + align(4); + +L(l8b2); + cmp(N, 0x1); + jl(l9d8, T_NEAR); + align(4); + +L(l8bc); + mov(A1, A); + add(A, 0x1); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l944, T_NEAR); + align(4); + +L(l8cc); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l8cc, T_NEAR); + align(4); + +L(l944); + test(M, 0x4); + jle(l98c, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l98c); + test(M, 0x2); + jle(l9b0, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l9b0); + test(M, 0x1); + jle(l9c8, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l9c8); + sub(N, 0x1); + cmp(N, 0x1); + jge(l8bc, T_NEAR); + align(4); + +L(l9d8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp new file mode 100644 index 0000000000..1c11fc6cef --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp @@ -0,0 +1,2209 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_at_kern::jit_avx512_core_u8_copy_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1014; +Xbyak::Label l1390; +Xbyak::Label l159c; +Xbyak::Label l173c; +Xbyak::Label l18e4; +Xbyak::Label l1a7c; +Xbyak::Label l1a8c; +Xbyak::Label l1a98; +Xbyak::Label l1ab4; +Xbyak::Label l1c64; +Xbyak::Label l1d74; +Xbyak::Label l1e50; +Xbyak::Label l1f2c; +Xbyak::Label l1ffc; +Xbyak::Label l20; +Xbyak::Label l200c; +Xbyak::Label l2018; +Xbyak::Label l2034; +Xbyak::Label l2110; +Xbyak::Label l21a0; +Xbyak::Label l2210; +Xbyak::Label l2284; +Xbyak::Label l22f0; +Xbyak::Label l2300; +Xbyak::Label l230c; +Xbyak::Label l2324; +Xbyak::Label l2398; +Xbyak::Label l23e8; +Xbyak::Label l242c; +Xbyak::Label l2474; +Xbyak::Label l24b4; +Xbyak::Label l24c4; +Xbyak::Label l24d0; +Xbyak::Label l24e8; +Xbyak::Label l2520; +Xbyak::Label l254c; +Xbyak::Label l2578; +Xbyak::Label l25a8; +Xbyak::Label l25c8; +Xbyak::Label l25d6; +Xbyak::Label l25e0; +Xbyak::Label l25f0; +Xbyak::Label l260c; +Xbyak::Label l262c; +Xbyak::Label l264c; +Xbyak::Label l2668; +Xbyak::Label l2680; +Xbyak::Label l2690; +Xbyak::Label l44; +Xbyak::Label l58c; +Xbyak::Label l8b0; +Xbyak::Label lb14; +Xbyak::Label ld84; +Xbyak::Label lfdc; +Xbyak::Label lfec; +Xbyak::Label lff8; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x30); + jl(lfec, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + lea(I, ptr[I+LDA*8]); + lea(I, ptr[I+LDA*8]); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l58c, T_NEAR); + align(4); + +L(l44); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B+0x40], xmm1); + movdqu(xword[B+0x100], xmm4); + movdqu(xword[B+0x1c0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x50], xmm1); + movdqu(xword[B+0x110], xmm4); + movdqu(xword[B+0x1d0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x60], xmm1); + movdqu(xword[B+0x120], xmm4); + movdqu(xword[B+0x1e0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x70], xmm1); + movdqu(xword[B+0x130], xmm4); + movdqu(xword[B+0x1f0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x80], xmm1); + movdqu(xword[B+0x140], xmm4); + movdqu(xword[B+0x200], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x90], xmm1); + movdqu(xword[B+0x150], xmm4); + movdqu(xword[B+0x210], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0xa0], xmm1); + movdqu(xword[B+0x160], xmm4); + movdqu(xword[B+0x220], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0xb0], xmm1); + movdqu(xword[B+0x170], xmm4); + movdqu(xword[B+0x230], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B], xmm0); + movdqu(xword[B+0xc0], xmm1); + movdqu(xword[B+0x180], xmm4); + movdqu(xword[B+0x240], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x10], xmm0); + movdqu(xword[B+0xd0], xmm1); + movdqu(xword[B+0x190], xmm4); + movdqu(xword[B+0x250], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x20], xmm0); + movdqu(xword[B+0xe0], xmm1); + movdqu(xword[B+0x1a0], xmm4); + movdqu(xword[B+0x260], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x30], xmm0); + movdqu(xword[B+0xf0], xmm1); + movdqu(xword[B+0x1b0], xmm4); + movdqu(xword[B+0x270], xmm3); + sub(A1, -16); + sub(B, -768); + dec(I); + jg(l44, T_NEAR); + align(4); + +L(l58c); + test(M, 0x8); + jle(l8b0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x70], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x80], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x90], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0xa0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0xb0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B], xmm0); + movdqu(xword[B+0xc0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x10], xmm0); + movdqu(xword[B+0xd0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x20], xmm0); + movdqu(xword[B+0xe0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x30], xmm0); + movdqu(xword[B+0xf0], xmm1); + sub(A1, -8); + sub(B, -384); + align(4); + +L(l8b0); + test(M, 0x4); + jle(lb14, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x10], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x10], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x30], xmm0); + sub(A1, -4); + sub(B, -192); + align(4); + +L(lb14); + test(M, 0x2); + jle(ld84, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x50], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x40], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x30], xmm0); + sub(A1, -2); + sub(B, -96); + align(4); + +L(ld84); + test(M, 0x1); + jle(lfdc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x70], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x60], xmm0); + sub(B, -48); + align(4); + +L(lfdc); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + align(4); + +L(lfec); + cmp(N, 0x20); + jl(l1a8c, T_NEAR); + align(4); + +L(lff8); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l1390, T_NEAR); + align(4); + +L(l1014); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B], xmm1); + movdqu(xword[B+0x80], xmm4); + movdqu(xword[B+0x100], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x10], xmm1); + movdqu(xword[B+0x90], xmm4); + movdqu(xword[B+0x110], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x20], xmm1); + movdqu(xword[B+0xa0], xmm4); + movdqu(xword[B+0x120], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x30], xmm1); + movdqu(xword[B+0xb0], xmm4); + movdqu(xword[B+0x130], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x40], xmm1); + movdqu(xword[B+0xc0], xmm4); + movdqu(xword[B+0x140], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x50], xmm1); + movdqu(xword[B+0xd0], xmm4); + movdqu(xword[B+0x150], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0x60], xmm1); + movdqu(xword[B+0xe0], xmm4); + movdqu(xword[B+0x160], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0x70], xmm1); + movdqu(xword[B+0xf0], xmm4); + movdqu(xword[B+0x170], xmm3); + sub(A1, -16); + sub(B, -512); + dec(I); + jg(l1014, T_NEAR); + align(4); + +L(l1390); + test(M, 0x8); + jle(l159c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x10], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0x70], xmm1); + sub(A1, -8); + sub(B, -256); + align(4); + +L(l159c); + test(M, 0x4); + jle(l173c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x10], xmm0); + sub(A1, -4); + sub(B, -128); + align(4); + +L(l173c); + test(M, 0x2); + jle(l18e4, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x50], xmm0); + sub(A1, -2); + sub(B, -64); + align(4); + +L(l18e4); + test(M, 0x1); + jle(l1a7c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l1a7c); + sub(N, 0x20); + cmp(N, 0x20); + jge(lff8, T_NEAR); + align(4); + +L(l1a8c); + cmp(N, 0x10); + jl(l200c, T_NEAR); + align(4); + +L(l1a98); + mov(A1, A); + mov(I, LDA); + shl(I, 0x4); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l1c64, T_NEAR); + align(4); + +L(l1ab4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x40], xmm1); + movdqu(xword[B], xmm4); + movdqu(xword[B+0x40], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B+0x10], xmm4); + movdqu(xword[B+0x50], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x20], xmm1); + movdqu(xword[B+0x20], xmm4); + movdqu(xword[B+0x60], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B-0x10], xmm1); + movdqu(xword[B+0x30], xmm4); + movdqu(xword[B+0x70], xmm3); + sub(A1, -16); + sub(B, -256); + dec(I); + jg(l1ab4, T_NEAR); + align(4); + +L(l1c64); + test(M, 0x8); + jle(l1d74, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B-0x10], xmm1); + sub(A1, -8); + sub(B, -128); + align(4); + +L(l1d74); + test(M, 0x4); + jle(l1e50, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + sub(A1, -4); + sub(B, -64); + align(4); + +L(l1e50); + test(M, 0x2); + jle(l1f2c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x70], xmm0); + sub(A1, -2); + sub(B, -32); + align(4); + +L(l1f2c); + test(M, 0x1); + jle(l1ffc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l1ffc); + sub(N, 0x10); + cmp(N, 0x10); + jge(l1a98, T_NEAR); + align(4); + +L(l200c); + cmp(N, 0x8); + jl(l2300, T_NEAR); + align(4); + +L(l2018); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2110, T_NEAR); + align(4); + +L(l2034); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x40], xmm4); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + movdqu(xword[B-0x30], xmm4); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l2034, T_NEAR); + align(4); + +L(l2110); + test(M, 0x8); + jle(l21a0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l21a0); + test(M, 0x4); + jle(l2210, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l2210); + test(M, 0x2); + jle(l2284, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l2284); + test(M, 0x1); + jle(l22f0, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l22f0); + sub(N, 0x8); + cmp(N, 0x8); + jge(l2018, T_NEAR); + align(4); + +L(l2300); + cmp(N, 0x4); + jl(l24c4, T_NEAR); + align(4); + +L(l230c); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2398, T_NEAR); + align(4); + +L(l2324); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l2324, T_NEAR); + align(4); + +L(l2398); + test(M, 0x8); + jle(l23e8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l23e8); + test(M, 0x4); + jle(l242c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l242c); + test(M, 0x2); + jle(l2474, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2474); + test(M, 0x1); + jle(l24b4, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l24b4); + sub(N, 0x4); + cmp(N, 0x4); + jge(l230c, T_NEAR); + align(4); + +L(l24c4); + cmp(N, 0x2); + jl(l25d6, T_NEAR); + align(4); + +L(l24d0); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2520, T_NEAR); + align(4); + +L(l24e8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l24e8, T_NEAR); + align(4); + +L(l2520); + test(M, 0x8); + jle(l254c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l254c); + test(M, 0x4); + jle(l2578, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2578); + test(M, 0x2); + jle(l25a8, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l25a8); + test(M, 0x1); + jle(l25c8, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l25c8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l24d0, T_NEAR); + align(4); + +L(l25d6); + cmp(N, 0x1); + jl(l2690, T_NEAR); + align(4); + +L(l25e0); + mov(A1, A); + add(A, LDA); + mov(I, M); + sar(I, 0x4); + jle(l260c, T_NEAR); + align(4); + +L(l25f0); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l25f0, T_NEAR); + align(4); + +L(l260c); + test(M, 0x8); + jle(l262c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l262c); + test(M, 0x4); + jle(l264c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l264c); + test(M, 0x2); + jle(l2668, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l2668); + test(M, 0x1); + jle(l2680, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l2680); + sub(N, 0x1); + cmp(N, 0x1); + jge(l25e0, T_NEAR); + align(4); + +L(l2690); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp new file mode 100644 index 0000000000..56c36ee14a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp @@ -0,0 +1,564 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_bn_kern::jit_avx512_core_u8_copy_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l118; +Xbyak::Label l1a8; +Xbyak::Label l20; +Xbyak::Label l218; +Xbyak::Label l28c; +Xbyak::Label l2f8; +Xbyak::Label l308; +Xbyak::Label l314; +Xbyak::Label l32c; +Xbyak::Label l3a0; +Xbyak::Label l3c; +Xbyak::Label l3f0; +Xbyak::Label l434; +Xbyak::Label l47c; +Xbyak::Label l4bc; +Xbyak::Label l4cc; +Xbyak::Label l4d8; +Xbyak::Label l4f0; +Xbyak::Label l528; +Xbyak::Label l554; +Xbyak::Label l580; +Xbyak::Label l5b0; +Xbyak::Label l5d0; +Xbyak::Label l5de; +Xbyak::Label l5e8; +Xbyak::Label l5f8; +Xbyak::Label l614; +Xbyak::Label l634; +Xbyak::Label l654; +Xbyak::Label l670; +Xbyak::Label l688; +Xbyak::Label l698; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x8); + jl(l308, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l118, T_NEAR); + align(4); + +L(l3c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x40], xmm4); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + movdqu(xword[B-0x30], xmm4); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l3c, T_NEAR); + align(4); + +L(l118); + test(M, 0x8); + jle(l1a8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l1a8); + test(M, 0x4); + jle(l218, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l218); + test(M, 0x2); + jle(l28c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l28c); + test(M, 0x1); + jle(l2f8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2f8); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l308); + cmp(N, 0x4); + jl(l4cc, T_NEAR); + align(4); + +L(l314); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l3a0, T_NEAR); + align(4); + +L(l32c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l32c, T_NEAR); + align(4); + +L(l3a0); + test(M, 0x8); + jle(l3f0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3f0); + test(M, 0x4); + jle(l434, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l434); + test(M, 0x2); + jle(l47c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l47c); + test(M, 0x1); + jle(l4bc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l4bc); + sub(N, 0x4); + cmp(N, 0x4); + jge(l314, T_NEAR); + align(4); + +L(l4cc); + cmp(N, 0x2); + jl(l5de, T_NEAR); + align(4); + +L(l4d8); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l528, T_NEAR); + align(4); + +L(l4f0); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l4f0, T_NEAR); + align(4); + +L(l528); + test(M, 0x8); + jle(l554, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l554); + test(M, 0x4); + jle(l580, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l580); + test(M, 0x2); + jle(l5b0, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l5b0); + test(M, 0x1); + jle(l5d0, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l5d0); + sub(N, 0x2); + cmp(N, 0x2); + jge(l4d8, T_NEAR); + align(4); + +L(l5de); + cmp(N, 0x1); + jl(l698, T_NEAR); + align(4); + +L(l5e8); + mov(A1, A); + add(A, LDA); + mov(I, M); + sar(I, 0x4); + jle(l614, T_NEAR); + align(4); + +L(l5f8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l5f8, T_NEAR); + align(4); + +L(l614); + test(M, 0x8); + jle(l634, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l634); + test(M, 0x4); + jle(l654, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l654); + test(M, 0x2); + jle(l670, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l670); + test(M, 0x1); + jle(l688, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l688); + sub(N, 0x1); + cmp(N, 0x1); + jge(l5e8, T_NEAR); + align(4); + +L(l698); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp new file mode 100644 index 0000000000..53e99d94de --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp @@ -0,0 +1,501 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_bt_kern::jit_avx512_core_u8_copy_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l120; +Xbyak::Label l14c; +Xbyak::Label l168; +Xbyak::Label l178; +Xbyak::Label l184; +Xbyak::Label l194; +Xbyak::Label l20; +Xbyak::Label l20c; +Xbyak::Label l250; +Xbyak::Label l27c; +Xbyak::Label l298; +Xbyak::Label l2a8; +Xbyak::Label l2b4; +Xbyak::Label l2c8; +Xbyak::Label l34; +Xbyak::Label l360; +Xbyak::Label l3b4; +Xbyak::Label l3e8; +Xbyak::Label l400; +Xbyak::Label l40e; +Xbyak::Label l418; +Xbyak::Label l428; +Xbyak::Label l4a0; +Xbyak::Label l4e8; +Xbyak::Label l50c; +Xbyak::Label l524; +Xbyak::Label l534; +Xbyak::Label lcc; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x8); + jl(l178, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x8); + mov(I, M); + sar(I, 0x3); + jle(lcc, T_NEAR); + align(4); + +L(l34); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(l34, T_NEAR); + align(4); + +L(lcc); + test(M, 0x4); + jle(l120, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l120); + test(M, 0x2); + jle(l14c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l14c); + test(M, 0x1); + jle(l168, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l168); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l178); + cmp(N, 0x4); + jl(l2a8, T_NEAR); + align(4); + +L(l184); + mov(A1, A); + add(A, 0x4); + mov(I, M); + sar(I, 0x3); + jle(l20c, T_NEAR); + align(4); + +L(l194); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l194, T_NEAR); + align(4); + +L(l20c); + test(M, 0x4); + jle(l250, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l250); + test(M, 0x2); + jle(l27c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l27c); + test(M, 0x1); + jle(l298, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l298); + sub(N, 0x4); + cmp(N, 0x4); + jge(l184, T_NEAR); + align(4); + +L(l2a8); + cmp(N, 0x2); + jl(l40e, T_NEAR); + align(4); + +L(l2b4); + mov(A1, A); + add(A, 0x2); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l360, T_NEAR); + align(4); + +L(l2c8); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l2c8, T_NEAR); + align(4); + +L(l360); + test(M, 0x4); + jle(l3b4, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3b4); + test(M, 0x2); + jle(l3e8, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3e8); + test(M, 0x1); + jle(l400, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l400); + sub(N, 0x2); + cmp(N, 0x2); + jge(l2b4, T_NEAR); + align(4); + +L(l40e); + cmp(N, 0x1); + jl(l534, T_NEAR); + align(4); + +L(l418); + mov(A1, A); + add(A, 0x1); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l4a0, T_NEAR); + align(4); + +L(l428); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l428, T_NEAR); + align(4); + +L(l4a0); + test(M, 0x4); + jle(l4e8, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l4e8); + test(M, 0x2); + jle(l50c, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l50c); + test(M, 0x1); + jle(l524, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l524); + sub(N, 0x1); + cmp(N, 0x1); + jge(l418, T_NEAR); + align(4); + +L(l534); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp new file mode 100644 index 0000000000..49a312fc88 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp @@ -0,0 +1,1283 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_an_kern::jit_avx512_core_u8_copy_sum_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1024; +Xbyak::Label l1090; +Xbyak::Label l10d4; +Xbyak::Label l10fc; +Xbyak::Label l111a; +Xbyak::Label l1124; +Xbyak::Label l113c; +Xbyak::Label l11d4; +Xbyak::Label l1234; +Xbyak::Label l1278; +Xbyak::Label l129c; +Xbyak::Label l12bc; +Xbyak::Label l20; +Xbyak::Label l2a0; +Xbyak::Label l3c0; +Xbyak::Label l438; +Xbyak::Label l480; +Xbyak::Label l48c; +Xbyak::Label l4c8; +Xbyak::Label l5c; +Xbyak::Label l6a8; +Xbyak::Label l7b4; +Xbyak::Label l850; +Xbyak::Label l89c; +Xbyak::Label l8a8; +Xbyak::Label l8d0; +Xbyak::Label l9d0; +Xbyak::Label la64; +Xbyak::Label lab8; +Xbyak::Label lae8; +Xbyak::Label laf4; +Xbyak::Label lb14; +Xbyak::Label lc30; +Xbyak::Label lcc8; +Xbyak::Label ld1c; +Xbyak::Label ld54; +Xbyak::Label ld78; +Xbyak::Label ld84; +Xbyak::Label ld9c; +Xbyak::Label le58; +Xbyak::Label lebc; +Xbyak::Label lef8; +Xbyak::Label lf1c; +Xbyak::Label lf3c; +Xbyak::Label lf48; +Xbyak::Label lf60; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x30); + jl(l480, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x30); + vxorps(ymm8, ymm8, ymm8); + vxorps(ymm9, ymm9, ymm9); + vxorps(ymm10, ymm10, ymm10); + vxorps(ymm11, ymm11, ymm11); + vxorps(ymm12, ymm12, ymm12); + vxorps(ymm13, ymm13, ymm13); + vxorps(ymm14, ymm14, ymm14); + vxorps(ymm15, ymm15, ymm15); + mov(I, M); + sar(I, 0x2); + jle(l2a0, T_NEAR); + align(4); + +L(l5c); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1+LDA*1-0x80]); + vmovdqu(xmm2, xword[A1+LDA*2-0x80]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x80]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovdqu(xword[B-0x80], xmm0); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovdqu(xword[B-0x60], xmm2); + vmovdqu(xword[B-0x50], xmm3); + vmovdqu(xmm0, xword[A1-0x70]); + vmovdqu(xmm1, xword[A1+LDA*1-0x70]); + vmovdqu(xmm2, xword[A1+LDA*2-0x70]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x70]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovdqu(xword[B-0x40], xmm0); + vmovdqu(xword[B-0x30], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovdqu(xword[B-0x20], xmm2); + vmovdqu(xword[B-0x10], xmm3); + vmovdqu(xmm0, xword[A1-0x60]); + vmovdqu(xmm1, xword[A1+LDA*1-0x60]); + vmovdqu(xmm2, xword[A1+LDA*2-0x60]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x60]); + lea(A1, ptr[A1+LDA*4]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovdqu(xword[B], xmm0); + vmovdqu(xword[B+0x10], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vmovdqu(xword[B+0x20], xmm2); + vmovdqu(xword[B+0x30], xmm3); + sub(B, -192); + dec(I); + jg(l5c, T_NEAR); + align(4); + +L(l2a0); + test(M, 0x2); + jle(l3c0, T_NEAR); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1-0x70]); + vmovdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + vmovdqu(xmm6, xword[A1-0x80]); + vmovdqu(xmm4, xword[A1-0x70]); + vmovdqu(xmm5, xword[A1-0x60]); + add(A1, LDA); + vpunpcklbw(xmm3, xmm0, xmm6); + vpunpckhbw(xmm0, xmm0, xmm6); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm8, ymm8, ymm7); + vmovdqu(xword[B-0x80], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x70], xmm0); + vpunpcklbw(xmm3, xmm1, xmm4); + vpunpckhbw(xmm0, xmm1, xmm4); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm10, ymm10, ymm7); + vmovdqu(xword[B-0x60], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x50], xmm0); + vpunpcklbw(xmm3, xmm2, xmm5); + vpunpckhbw(xmm0, xmm2, xmm5); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm12, ymm12, ymm7); + vmovdqu(xword[B-0x40], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x30], xmm0); + sub(B, -96); + align(4); + +L(l3c0); + test(M, 0x1); + jle(l438, T_NEAR); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1-0x70]); + vmovdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm8, ymm8, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x80], xmm0); + vpmovsxbd(ymm7, xmm1); + vpaddd(ymm10, ymm10, ymm7); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbd(ymm7, xmm2); + vpaddd(ymm12, ymm12, ymm7); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x60], xmm2); + sub(B, -48); + align(4); + +L(l438); + mov(A1, qword[ARG_BIAS]); + vmovdqu(yword[A1], ymm8); + vmovdqu(yword[A1+0x20], ymm9); + vmovdqu(yword[A1+0x40], ymm10); + vmovdqu(yword[A1+0x60], ymm11); + vmovdqu(yword[A1+0x80], ymm12); + vmovdqu(yword[A1+0xa0], ymm13); + add(qword[ARG_BIAS], 0xc0); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + vzeroupper(); + align(4); + +L(l480); + cmp(N, 0x20); + jl(l89c, T_NEAR); + align(4); + +L(l48c); + mov(A1, A); + add(A, 0x20); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + pxor(xmm12, xmm12); + pxor(xmm13, xmm13); + pxor(xmm14, xmm14); + pxor(xmm15, xmm15); + mov(I, M); + sar(I, 0x2); + jle(l6a8, T_NEAR); + align(4); + +L(l4c8); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm4); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm2); + sub(B, -128); + dec(I); + jg(l4c8, T_NEAR); + align(4); + +L(l6a8); + test(M, 0x2); + jle(l7b4, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + movdqu(xmm3, xword[A1-0x70]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm2); + punpckhbw(xmm4, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm4); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm4); + movdqa(xmm4, xmm1); + punpcklbw(xmm1, xmm3); + punpckhbw(xmm4, xmm3); + pmovsxbw(xmm5, xmm1); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm13, xmm6); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x50], xmm4); + sub(B, -64); + align(4); + +L(l7b4); + test(M, 0x1); + jle(l850, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + pmovsxbd(xmm5, xmm1); + paddd(xmm12, xmm5); + pshufd(xmm6, xmm1, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm13, xmm6); + pshufd(xmm5, xmm1, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm14, xmm5); + pshufd(xmm6, xmm1, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l850); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + movdqu(xword[A1+0x40], xmm12); + movdqu(xword[A1+0x50], xmm13); + movdqu(xword[A1+0x60], xmm14); + movdqu(xword[A1+0x70], xmm15); + add(qword[ARG_BIAS], 0x80); + sub(N, 0x20); + cmp(N, 0x20); + jge(l48c, T_NEAR); + align(4); + +L(l89c); + cmp(N, 0x10); + jl(lae8, T_NEAR); + align(4); + +L(l8a8); + mov(A1, A); + add(A, 0x10); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + mov(I, M); + sar(I, 0x2); + jle(l9d0, T_NEAR); + align(4); + +L(l8d0); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm1, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm1, xmm3); + movdqa(xmm3, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm3, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm1); + punpckhwd(xmm2, xmm1); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm3); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + sub(B, -64); + dec(I); + jg(l8d0, T_NEAR); + align(4); + +L(l9d0); + test(M, 0x2); + jle(la64, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm2, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm2, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + pmovsxbw(xmm5, xmm2); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + align(4); + +L(la64); + test(M, 0x1); + jle(lab8, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(lab8); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + add(qword[ARG_BIAS], 0x40); + sub(N, 0x10); + cmp(N, 0x10); + jge(l8a8, T_NEAR); + align(4); + +L(lae8); + cmp(N, 0x8); + jl(ld78, T_NEAR); + align(4); + +L(laf4); + mov(A1, A); + add(A, 0x8); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x3); + jle(lc30, T_NEAR); + align(4); + +L(lb14); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(lb14, T_NEAR); + align(4); + +L(lc30); + test(M, 0x4); + jle(lcc8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(lcc8); + test(M, 0x2); + jle(ld1c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(ld1c); + test(M, 0x1); + jle(ld54, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(ld54); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(laf4, T_NEAR); + align(4); + +L(ld78); + cmp(N, 0x4); + jl(lf3c, T_NEAR); + align(4); + +L(ld84); + mov(A1, A); + add(A, 0x4); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x3); + jle(le58, T_NEAR); + align(4); + +L(ld9c); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(ld9c, T_NEAR); + align(4); + +L(le58); + test(M, 0x4); + jle(lebc, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(lebc); + test(M, 0x2); + jle(lef8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(lef8); + test(M, 0x1); + jle(lf1c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(lf1c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(ld84, T_NEAR); + align(4); + +L(lf3c); + cmp(N, 0x2); + jl(l111a, T_NEAR); + align(4); + +L(lf48); + mov(A1, A); + add(A, 0x2); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l1024, T_NEAR); + align(4); + +L(lf60); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(lf60, T_NEAR); + align(4); + +L(l1024); + test(M, 0x4); + jle(l1090, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l1090); + test(M, 0x2); + jle(l10d4, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l10d4); + test(M, 0x1); + jle(l10fc, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l10fc); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(lf48, T_NEAR); + align(4); + +L(l111a); + cmp(N, 0x1); + jl(l12bc, T_NEAR); + align(4); + +L(l1124); + mov(A1, A); + add(A, 0x1); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l11d4, T_NEAR); + align(4); + +L(l113c); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l113c, T_NEAR); + align(4); + +L(l11d4); + test(M, 0x4); + jle(l1234, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l1234); + test(M, 0x2); + jle(l1278, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l1278); + test(M, 0x1); + jle(l129c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l129c); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l1124, T_NEAR); + align(4); + +L(l12bc); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp new file mode 100644 index 0000000000..a4f4ff09c6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp @@ -0,0 +1,3163 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_at_kern::jit_avx512_core_u8_copy_sum_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1750; +Xbyak::Label l1b6c; +Xbyak::Label l1e14; +Xbyak::Label l20; +Xbyak::Label l2068; +Xbyak::Label l226c; +Xbyak::Label l22b8; +Xbyak::Label l22c4; +Xbyak::Label l22f4; +Xbyak::Label l26b4; +Xbyak::Label l28cc; +Xbyak::Label l2a2c; +Xbyak::Label l2b5c; +Xbyak::Label l2c64; +Xbyak::Label l2c94; +Xbyak::Label l2ca0; +Xbyak::Label l2cc8; +Xbyak::Label l2eac; +Xbyak::Label l2fc0; +Xbyak::Label l3078; +Xbyak::Label l3118; +Xbyak::Label l319c; +Xbyak::Label l31c0; +Xbyak::Label l31cc; +Xbyak::Label l31ec; +Xbyak::Label l32e4; +Xbyak::Label l3378; +Xbyak::Label l33dc; +Xbyak::Label l3434; +Xbyak::Label l347c; +Xbyak::Label l349c; +Xbyak::Label l34a8; +Xbyak::Label l34c8; +Xbyak::Label l3558; +Xbyak::Label l35b0; +Xbyak::Label l35f4; +Xbyak::Label l3638; +Xbyak::Label l366c; +Xbyak::Label l368a; +Xbyak::Label l3694; +Xbyak::Label l36a8; +Xbyak::Label l36ec; +Xbyak::Label l3728; +Xbyak::Label l3760; +Xbyak::Label l3794; +Xbyak::Label l37b8; +Xbyak::Label l37d8; +Xbyak::Label l5cc; +Xbyak::Label l6c; +Xbyak::Label l968; +Xbyak::Label lc80; +Xbyak::Label lf1c; +Xbyak::Label lf64; +Xbyak::Label lf70; +Xbyak::Label lfb4; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x30); + jl(lf64, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + lea(I, ptr[I+LDA*8]); + lea(I, ptr[I+LDA*8]); + add(A, I); + vxorps(ymm8, ymm8, ymm8); + vxorps(ymm9, ymm9, ymm9); + vxorps(ymm10, ymm10, ymm10); + vxorps(ymm11, ymm11, ymm11); + vxorps(ymm12, ymm12, ymm12); + vxorps(ymm13, ymm13, ymm13); + vxorps(ymm14, ymm14, ymm14); + vxorps(ymm15, ymm15, ymm15); + mov(I, M); + sar(I, 0x3); + jle(l5cc, T_NEAR); + align(4); + +L(l6c); + vmovq(xmm0, qword[A1-0x80]); + vmovq(xmm1, qword[A1+LDA*1-0x80]); + vmovq(xmm2, qword[A1+LDA*2-0x80]); + vmovq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x80], xmm0); + vmovdqu(xword[B+0x40], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x70], xmm2); + vmovdqu(xword[B+0x50], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x60], xmm0); + vmovdqu(xword[B+0x60], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x50], xmm2); + vmovdqu(xword[B+0x70], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x40], xmm0); + vmovdqu(xword[B+0x80], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x30], xmm2); + vmovdqu(xword[B+0x90], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x20], xmm0); + vmovdqu(xword[B+0xa0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x10], xmm2); + vmovdqu(xword[B+0xb0], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B], xmm0); + vmovdqu(xword[B+0xc0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B+0x10], xmm2); + vmovdqu(xword[B+0xd0], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x20], xmm0); + vmovdqu(xword[B+0xe0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B+0x30], xmm2); + vmovdqu(xword[B+0xf0], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + sub(A1, -8); + sub(B, -384); + dec(I); + jg(l6c, T_NEAR); + align(4); + +L(l5cc); + test(M, 0x4); + jle(l968, T_NEAR); + vmovd(xmm0, dword[A1-0x80]); + vmovd(xmm1, dword[A1+LDA*1-0x80]); + vmovd(xmm2, dword[A1+LDA*2-0x80]); + vmovd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x80], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x60], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x50], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x40], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x30], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x20], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x10], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x10], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B+0x20], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x30], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + sub(A1, -4); + sub(B, -192); + align(4); + +L(l968); + test(M, 0x2); + jle(lc80, T_NEAR); + mov(ax, word[A1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovdqu(xword[B-0x50], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovdqu(xword[B-0x40], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vmovdqu(xword[B-0x30], xmm0); + sub(A1, -2); + sub(B, -96); + align(4); + +L(lc80); + test(M, 0x1); + jle(lf1c, T_NEAR); + mov(al, byte[A1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm8, ymm8, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm10, ymm10, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x70], xmm0); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm12, ymm12, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x60], xmm0); + sub(B, -48); + align(4); + +L(lf1c); + mov(A1, qword[ARG_BIAS]); + vmovdqu(yword[A1], ymm8); + vmovdqu(yword[A1+0x20], ymm9); + vmovdqu(yword[A1+0x40], ymm10); + vmovdqu(yword[A1+0x60], ymm11); + vmovdqu(yword[A1+0x80], ymm12); + vmovdqu(yword[A1+0xa0], ymm13); + add(qword[ARG_BIAS], 0xc0); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + vzeroupper(); + align(4); + +L(lf64); + cmp(N, 0x20); + jl(l22b8, T_NEAR); + align(4); + +L(lf70); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + add(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + pxor(xmm12, xmm12); + pxor(xmm13, xmm13); + pxor(xmm14, xmm14); + pxor(xmm15, xmm15); + mov(I, M); + sar(I, 0x4); + jle(l1750, T_NEAR); + align(4); + +L(lfb4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x80], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x100], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x90], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x110], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0xa0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x120], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0xb0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x130], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x40], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0xc0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x140], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0xd0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x150], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0xe0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x160], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0xf0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x170], xmm3); + sub(A1, -16); + sub(B, -512); + dec(I); + jg(lfb4, T_NEAR); + align(4); + +L(l1750); + test(M, 0x8); + jle(l1b6c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x70], xmm1); + sub(A1, -8); + sub(B, -256); + align(4); + +L(l1b6c); + test(M, 0x4); + jle(l1e14, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + sub(A1, -4); + sub(B, -128); + align(4); + +L(l1e14); + test(M, 0x2); + jle(l2068, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm13, xmm6); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x50], xmm0); + sub(A1, -2); + sub(B, -64); + align(4); + +L(l2068); + test(M, 0x1); + jle(l226c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm12, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm13, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm14, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l226c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + movdqu(xword[A1+0x40], xmm12); + movdqu(xword[A1+0x50], xmm13); + movdqu(xword[A1+0x60], xmm14); + movdqu(xword[A1+0x70], xmm15); + add(qword[ARG_BIAS], 0x80); + sub(N, 0x20); + cmp(N, 0x20); + jge(lf70, T_NEAR); + align(4); + +L(l22b8); + cmp(N, 0x10); + jl(l2c94, T_NEAR); + align(4); + +L(l22c4); + mov(A1, A); + mov(I, LDA); + shl(I, 0x4); + add(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + mov(I, M); + sar(I, 0x4); + jle(l26b4, T_NEAR); + align(4); + +L(l22f4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x40], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x50], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x20], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x60], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x10], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x70], xmm3); + sub(A1, -16); + sub(B, -256); + dec(I); + jg(l22f4, T_NEAR); + align(4); + +L(l26b4); + test(M, 0x8); + jle(l28cc, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x10], xmm1); + sub(A1, -8); + sub(B, -128); + align(4); + +L(l28cc); + test(M, 0x4); + jle(l2a2c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + sub(A1, -4); + sub(B, -64); + align(4); + +L(l2a2c); + test(M, 0x2); + jle(l2b5c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm0); + sub(A1, -2); + sub(B, -32); + align(4); + +L(l2b5c); + test(M, 0x1); + jle(l2c64, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l2c64); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + add(qword[ARG_BIAS], 0x40); + sub(N, 0x10); + cmp(N, 0x10); + jge(l22c4, T_NEAR); + align(4); + +L(l2c94); + cmp(N, 0x8); + jl(l31c0, T_NEAR); + align(4); + +L(l2ca0); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x4); + jle(l2eac, T_NEAR); + align(4); + +L(l2cc8); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l2cc8, T_NEAR); + align(4); + +L(l2eac); + test(M, 0x8); + jle(l2fc0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l2fc0); + test(M, 0x4); + jle(l3078, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l3078); + test(M, 0x2); + jle(l3118, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l3118); + test(M, 0x1); + jle(l319c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l319c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l2ca0, T_NEAR); + align(4); + +L(l31c0); + cmp(N, 0x4); + jl(l349c, T_NEAR); + align(4); + +L(l31cc); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l32e4, T_NEAR); + align(4); + +L(l31ec); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l31ec, T_NEAR); + align(4); + +L(l32e4); + test(M, 0x8); + jle(l3378, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3378); + test(M, 0x4); + jle(l33dc, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l33dc); + test(M, 0x2); + jle(l3434, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3434); + test(M, 0x1); + jle(l347c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l347c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l31cc, T_NEAR); + align(4); + +L(l349c); + cmp(N, 0x2); + jl(l368a, T_NEAR); + align(4); + +L(l34a8); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l3558, T_NEAR); + align(4); + +L(l34c8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pshufd(xmm6, xmm2, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l34c8, T_NEAR); + align(4); + +L(l3558); + test(M, 0x8); + jle(l35b0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l35b0); + test(M, 0x4); + jle(l35f4, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l35f4); + test(M, 0x2); + jle(l3638, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3638); + test(M, 0x1); + jle(l366c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(byte[B-0x7f], al); + sub(B, -2); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + align(4); + +L(l366c); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l34a8, T_NEAR); + align(4); + +L(l368a); + cmp(N, 0x1); + jl(l37d8, T_NEAR); + align(4); + +L(l3694); + mov(A1, A); + add(A, LDA); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l36ec, T_NEAR); + align(4); + +L(l36a8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l36a8, T_NEAR); + align(4); + +L(l36ec); + test(M, 0x8); + jle(l3728, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3728); + test(M, 0x4); + jle(l3760, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3760); + test(M, 0x2); + jle(l3794, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l3794); + test(M, 0x1); + jle(l37b8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l37b8); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l3694, T_NEAR); + align(4); + +L(l37d8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp new file mode 100644 index 0000000000..c7f1393c9d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp @@ -0,0 +1,821 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_bn_kern::jit_avx512_core_u8_copy_sum_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l20; +Xbyak::Label l22c; +Xbyak::Label l340; +Xbyak::Label l3f8; +Xbyak::Label l48; +Xbyak::Label l498; +Xbyak::Label l51c; +Xbyak::Label l540; +Xbyak::Label l54c; +Xbyak::Label l56c; +Xbyak::Label l664; +Xbyak::Label l6f8; +Xbyak::Label l75c; +Xbyak::Label l7b4; +Xbyak::Label l7fc; +Xbyak::Label l81c; +Xbyak::Label l828; +Xbyak::Label l848; +Xbyak::Label l8d8; +Xbyak::Label l930; +Xbyak::Label l974; +Xbyak::Label l9b8; +Xbyak::Label l9ec; +Xbyak::Label la0a; +Xbyak::Label la14; +Xbyak::Label la28; +Xbyak::Label la6c; +Xbyak::Label laa8; +Xbyak::Label lae0; +Xbyak::Label lb14; +Xbyak::Label lb38; +Xbyak::Label lb58; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x8); + jl(l540, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x4); + jle(l22c, T_NEAR); + align(4); + +L(l48); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l48, T_NEAR); + align(4); + +L(l22c); + test(M, 0x8); + jle(l340, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l340); + test(M, 0x4); + jle(l3f8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l3f8); + test(M, 0x2); + jle(l498, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l498); + test(M, 0x1); + jle(l51c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l51c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l540); + cmp(N, 0x4); + jl(l81c, T_NEAR); + align(4); + +L(l54c); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l664, T_NEAR); + align(4); + +L(l56c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l56c, T_NEAR); + align(4); + +L(l664); + test(M, 0x8); + jle(l6f8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l6f8); + test(M, 0x4); + jle(l75c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l75c); + test(M, 0x2); + jle(l7b4, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l7b4); + test(M, 0x1); + jle(l7fc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l7fc); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l54c, T_NEAR); + align(4); + +L(l81c); + cmp(N, 0x2); + jl(la0a, T_NEAR); + align(4); + +L(l828); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l8d8, T_NEAR); + align(4); + +L(l848); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pshufd(xmm6, xmm2, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l848, T_NEAR); + align(4); + +L(l8d8); + test(M, 0x8); + jle(l930, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l930); + test(M, 0x4); + jle(l974, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l974); + test(M, 0x2); + jle(l9b8, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l9b8); + test(M, 0x1); + jle(l9ec, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(byte[B-0x7f], al); + sub(B, -2); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + align(4); + +L(l9ec); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l828, T_NEAR); + align(4); + +L(la0a); + cmp(N, 0x1); + jl(lb58, T_NEAR); + align(4); + +L(la14); + mov(A1, A); + add(A, LDA); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(la6c, T_NEAR); + align(4); + +L(la28); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(la28, T_NEAR); + align(4); + +L(la6c); + test(M, 0x8); + jle(laa8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(laa8); + test(M, 0x4); + jle(lae0, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(lae0); + test(M, 0x2); + jle(lb14, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(lb14); + test(M, 0x1); + jle(lb38, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(lb38); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(la14, T_NEAR); + align(4); + +L(lb58); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp new file mode 100644 index 0000000000..afe4f1713e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp @@ -0,0 +1,647 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_bt_kern::jit_avx512_core_u8_copy_sum_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l15c; +Xbyak::Label l1f4; +Xbyak::Label l20; +Xbyak::Label l248; +Xbyak::Label l280; +Xbyak::Label l2a4; +Xbyak::Label l2b0; +Xbyak::Label l2c8; +Xbyak::Label l384; +Xbyak::Label l3e8; +Xbyak::Label l40; +Xbyak::Label l424; +Xbyak::Label l448; +Xbyak::Label l468; +Xbyak::Label l474; +Xbyak::Label l48c; +Xbyak::Label l550; +Xbyak::Label l5bc; +Xbyak::Label l600; +Xbyak::Label l628; +Xbyak::Label l646; +Xbyak::Label l650; +Xbyak::Label l668; +Xbyak::Label l700; +Xbyak::Label l760; +Xbyak::Label l7a4; +Xbyak::Label l7c8; +Xbyak::Label l7e8; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x8); + jl(l2a4, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x8); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x3); + jle(l15c, T_NEAR); + align(4); + +L(l40); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(l40, T_NEAR); + align(4); + +L(l15c); + test(M, 0x4); + jle(l1f4, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l1f4); + test(M, 0x2); + jle(l248, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l248); + test(M, 0x1); + jle(l280, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l280); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l2a4); + cmp(N, 0x4); + jl(l468, T_NEAR); + align(4); + +L(l2b0); + mov(A1, A); + add(A, 0x4); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x3); + jle(l384, T_NEAR); + align(4); + +L(l2c8); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l2c8, T_NEAR); + align(4); + +L(l384); + test(M, 0x4); + jle(l3e8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l3e8); + test(M, 0x2); + jle(l424, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l424); + test(M, 0x1); + jle(l448, T_NEAR); + movd(xmm0, dword[A1-0x80]); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l448); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l2b0, T_NEAR); + align(4); + +L(l468); + cmp(N, 0x2); + jl(l646, T_NEAR); + align(4); + +L(l474); + mov(A1, A); + add(A, 0x2); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l550, T_NEAR); + align(4); + +L(l48c); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l48c, T_NEAR); + align(4); + +L(l550); + test(M, 0x4); + jle(l5bc, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l5bc); + test(M, 0x2); + jle(l600, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l600); + test(M, 0x1); + jle(l628, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l628); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l474, T_NEAR); + align(4); + +L(l646); + cmp(N, 0x1); + jl(l7e8, T_NEAR); + align(4); + +L(l650); + mov(A1, A); + add(A, 0x1); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l700, T_NEAR); + align(4); + +L(l668); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l668, T_NEAR); + align(4); + +L(l700); + test(M, 0x4); + jle(l760, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l760); + test(M, 0x2); + jle(l7a4, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l7a4); + test(M, 0x1); + jle(l7c8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l7c8); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l650, T_NEAR); + align(4); + +L(l7e8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp new file mode 100644 index 0000000000..4fc11afcbc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp @@ -0,0 +1,116 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "../f32/ref_gemm_f32.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co) { + + if (*M == 0 || *N == 0 || *K == 0) + return mkldnn_success; + + bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); + bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); + bool AisN = (*transa == 'N' || *transa == 'n'); + bool BisN = (*transb == 'N' || *transb == 'n'); + + int m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC; + size_t sizeA = AisN ? lda * k : lda * m; + size_t sizeB = BisN ? ldb * n : ldb * k; + size_t sizeC = ldc * n; + + double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K); + double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K); + double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K); + + if (utils::any_null(dA, dB, dC)) { + free(dA); + free(dB); + free(dC); + return mkldnn_out_of_memory; + } + + auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; }; + auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; }; + + auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; }; + auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; }; + + const int a_rows = AisN ? m : k; + const int a_cols = AisN ? k : m; + mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) { + da_setter(i, j, + static_cast(ia_accessor(i, j)) + static_cast(ao[0])); + }); + + const int b_rows = BisN ? k : n; + const int b_cols = BisN ? n : k; + mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) { + db_setter(i, j, + static_cast(ib_accessor(i, j)) + static_cast(bo[0])); + }); + double one = 1.0, zero = 0.0; + ref_gemm(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero, + dC, LDC, nullptr); + + auto i2d = [=] (int32_t v) { return static_cast(v); }; + auto f2d = [=] (float v) { return static_cast(v); }; + + mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) { + double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]); + double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc])) + + f2d(*alpha) * dC[i + j * ldc] + coffset; + C[i + j * ldc] = math::out_round(math::saturate(val)); + }); + + free(dA); + free(dB); + free(dC); + return mkldnn_success; +} + +template mkldnn_status_t ref_gemm_s8x8s32( + const char *transa, const char *transb, const char *offsetc, + const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const uint8_t *B, const int *LDB, const int8_t *bo, + const float *beta, int32_t *C, const int *LDC, const int32_t *co); + +template mkldnn_status_t ref_gemm_s8x8s32( + const char *transa, const char *transb, const char *offsetc, + const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const int8_t *B, const int *LDB, const int8_t *bo, + const float *beta, int32_t *C, const int *LDC, const int32_t *co); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp new file mode 100644 index 0000000000..6c0370ae99 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp @@ -0,0 +1,38 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_GEMM_S8X8S32_HPP +#define REF_GEMM_S8X8S32_HPP + +#include + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp new file mode 100644 index 0000000000..de1035f3b2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp @@ -0,0 +1,180 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common.hpp" +#include "nstl.hpp" +#include "math_utils.hpp" + +#include "../gemm.hpp" +#include "jit_avx512_core_gemm_s8u8s32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +void compensation_init(const char *offsetC, int32_t *compensation, int len, + const int32_t *oc) { + bool OCisC = (*offsetC == 'C' || *offsetC == 'c'); + bool OCisF = (*offsetC == 'F' || *offsetC == 'f'); + + if (OCisF && (*oc) != 0) { + for (int i = 0; i < len; i++) + compensation[i] = *oc; + } else if (OCisC) { + for (int i = 0; i < len; i++) + compensation[i] = oc[i]; + } else { + parallel_nd(len, [=](int i) { compensation[i] = 0; }); + } +} + +void compensation_compute(bool transa, int m, int k, float alpha, + const int8_t *a, int lda, int32_t *compensation) { + if (!transa) { + const int L2_cache_size = get_cache_size(2, true); + const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1); + const int npanels = k / blocking_factor; + const bool has_tile = k % blocking_factor > 0; + + parallel_nd(npanels, m, [&](int j, int i) { + int32_t val = 0; + for (int jb = 0; jb < blocking_factor; jb++) { + val += a[(i + (ptrdiff_t)j * blocking_factor * lda) + + (ptrdiff_t)jb * lda]; + } + if (alpha != 1.0f) { + val = math::out_round(math::saturate( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + fetch_and_add(&compensation[i], val); + }); + + if (has_tile) { + parallel_nd(m, [=](int i) { + int32_t val = 0; + for (int j = npanels * blocking_factor; j < k; j++) { + val += a[i + (ptrdiff_t)j * lda]; + } + if (alpha != 1.0f) { + val = math::out_round(math::saturate( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + fetch_and_add(&compensation[i], val); + }); + } + } else { + parallel_nd(m, [=](int i) { + int32_t val = 0; + for (int j = 0; j < k; j++) { + val += a[j + (ptrdiff_t)i * lda]; + } + if (alpha != 1.0f) { + val = math::out_round(math::saturate( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + compensation[i] += val; + }); + } +} + +void copy_and_shift_b(bool transb, int k, int n, uint8_t *b_u8, int ldb_u8, + const int8_t *b_s8, int ldb_s8) { + const int b_cols = transb ? k : n; + + parallel_nd(b_cols, [=](int j) { + const int b_rows = transb ? n : k; + + uint8_t *pb_u8 = b_u8 + j * ldb_u8; + const int8_t *pb_s8 = b_s8 + j * ldb_s8; + + for (int i = 0; i < b_rows; i++) { + (*pb_u8) = (*pb_s8) + 128; + pb_u8++; + pb_s8++; + } + }); +} + +/** + * gemm_s8s8s32 operation is defined as follows: + * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset + compensation + * + * where + * - compensation is a vector of length m that contains computed compensation + * that may contain C_offset if applicable. The compensation is applied inside + * gemm_s8u8s32 as a C_offset + * - B_shift is a k-by-n matrix, every element of B_shift is equal to 128 + * + * What is the compensation: + * In order to prepare the matrix B for gemm_s8u8s32 call the B_shift is applied: + * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset = + * alpha * op(A) * op(B) + alpha * op(A) * B_shift + beta * C + C_offset + * compensation = -alpha * op(A) * B_shift + * Since B_shift is a matrix, every element of which is equal to 128 then + * - if op(A) = A: compensation contains sum of the elements in each row + * scaled by -128 * alpha + * - if op(A) = A**T: compensation contains sum of the elements in each column + * scaled by -128 * alpha + * + * The rest of parameters is described in mkldnn.h + */ +mkldnn_status_t simple_gemm_s8s8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const int8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc) { + if (*oa != 0 || *ob != 0) return mkldnn_unimplemented; + + int M = *m, N = *n, K = *k; + bool transa = (*transA == 'T' || *transA == 't'); + bool transb = (*transB == 'T' || *transB == 't'); + int ld = transb ? N : K; + + uint8_t *b_u8 = (uint8_t *)malloc(sizeof(uint8_t) * K * N, 64); + int32_t *compensation = (int32_t *)malloc(sizeof(int32_t) * M, 64); + + if (utils::any_null(b_u8, compensation)) { + free(b_u8); + free(compensation); + return mkldnn_out_of_memory; + } + + compensation_init(offsetC, compensation, M, oc); + compensation_compute(transa, M, K, *alpha, a, *lda, compensation); + copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb); + + gemm_s8x8s32(transA, transB, "C", m, n, k, alpha, a, lda, oa, b_u8, + &ld, ob, beta, c, ldc, compensation); + + if ((*offsetC == 'R' || *offsetC == 'r')) + parallel_nd(M, N, + [=](int i, int j) { c[i + (ptrdiff_t)j * *ldc] += oc[j]; }); + + free(b_u8); + free(compensation); + + return mkldnn_success; +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp new file mode 100644 index 0000000000..03a3d2f7e0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp @@ -0,0 +1,37 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SIMPLE_GEMM_S8S8S32_HPP +#define SIMPLE_GEMM_S8S8S32_HPP + +#include +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t simple_gemm_s8s8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const int8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc); +} +} +} + +#endif // SIMPLE_GEMM_S8S8S32_HPP diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp new file mode 100644 index 0000000000..604a728b47 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp @@ -0,0 +1,307 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "gemm_convolution.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "ref_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + auto col = scratchpad(ctx).get(key_conv_gemm_col); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int M = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * M; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + assert(IMPLICATION( + jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow)); + assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); + + const int K = jcp.ic * jcp.ks; + const int N = jcp.oc; + + if (jcp.im2col_sz && jcp.id != 1) + parallel_nd(jcp.im2col_sz * jcp.nthr, + [&](ptrdiff_t i) { col[i] = (data_t)0; }); + + const int nb_oh = div_up(jcp.oh, jcp.oh_block); + const int nb_ow = div_up(jcp.ow, jcp.ow_block); + const size_t work_amount = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow; + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + + int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 }; + size_t start = 0, end = 0; + + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, + nb_oh, owb, nb_ow); + for (size_t iwork = start; iwork < end; ++iwork) { + int oh = ohb * jcp.oh_block; + int ow = owb * jcp.ow_block; + const data_t *_src = src + (n * jcp.ngroups + g) * src_step; + const data_t *_weights = weights + g * weights_g_size; + data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step; + const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh); + const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow); + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::im2col( + jcp, _src, _col, oh, h_step, ow, w_step); + else + jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od); + } + + const data_t one = 1.0; + + const int m = h_step * w_step; + const int LDA = jcp.im2col_sz ? m : M; + data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow; + + extended_sgemm("N", "N", &m, &N, &K, &one, + jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K, + &this->beta_, _dst, &M); + + data_t *d = _dst; + if (eltwise_) { + // fast branch for ReLU case + if (eltwise_->alg_ == alg_kind::eltwise_relu) { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + if (d_[oS] < 0) d_[oS] *= eltwise_->alpha_; + } + }); + } else { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + d_[oS] = eltwise_->compute_scalar(d_[oS]); + } + }); + } + } else if (jcp.with_bias) { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = bias[g * jcp.oc + oc]; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + } + }); + } + nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, nb_oh, + owb, nb_ow); + } + }); +} + +void gemm_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + auto col = scratchpad(ctx).get(key_conv_gemm_col); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int M = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * M; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + const int m = jcp.os; + const int K = jcp.oc; + const int N = jcp.ic * jcp.ks; + const int LDC = jcp.im2col_sz ? m : M; + + const size_t work_amount = (size_t)jcp.ngroups * jcp.mb; + + if (jcp.id > 1) { + const ptrdiff_t diff_src_sz = (ptrdiff_t)(work_amount * src_step); + parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[i] = (data_t)0; }); + } + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + + int g{0}, n{0}; + size_t start = 0, end = 0; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb); + for (size_t iwork = start; iwork < end; ++iwork) { + + data_t *_diff_src = diff_src + (n * jcp.ngroups + g)*src_step; + const data_t *_weights = weights + g * weights_g_size; + for (int od = 0; od < jcp.od; ++od) { + const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g) + *dst_step + od * m; + + const data_t zero = 0.0, one = 1.0; + extended_sgemm("N", "T", &m, &N, &K, &one, _diff_dst, &M, + _weights, &N, &zero, + jcp.im2col_sz ? _col:_diff_src + od * m, &LDC); + + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::col2im(jcp, _col, + _diff_src); + else + jit_gemm_convolution_utils::col2im_3d(jcp, _col, + _diff_src, od); + } + } + nd_iterator_step(g, jcp.ngroups, n, jcp.mb); + } + }); +} + +void gemm_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto col = scratchpad(ctx).get(key_conv_gemm_col); + auto wei_reduction = scratchpad(ctx).get(key_conv_wei_reduction); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int K = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * K; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + const int k = jcp.os; + const int N = jcp.oc; + const int M = jcp.ic * jcp.ks; + const int LDA = jcp.im2col_sz ? k : K; + + parallel_nd(jcp.im2col_sz * jcp.nthr, + [&](ptrdiff_t i) { col[i] = (data_t)0; }); + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + int ithr_g, nthr_g, ithr_mb, nthr_mb; + size_t g_start{0}, g_end{0}, mb_start{0}, mb_end{0}; + + const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1; + jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups, + mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb); + + assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1)); + const int need_reduction = nthr_mb != 1; + + if (ithr_g != -1 && ithr_mb != -1) { + balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end); + balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end); + + assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0)); + + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + data_t *weights_reduce_base = wei_reduction + + ithr_g * nthr_mb * weights_g_size; + data_t *weights_reduce = weights_reduce_base + + ithr_mb * weights_g_size; + + for (size_t g = g_start; g < g_end; ++g) { + data_t *_diff_weights = need_reduction + ? weights_reduce : (diff_weights + g * weights_g_size); + for (size_t mb = mb_start; mb < mb_end; ++mb) { + const data_t *_src = src + (mb*jcp.ngroups+g)*src_step; + for (int od = 0; od < jcp.od; ++od) { + const data_t *_diff_dst = diff_dst + + (mb*jcp.ngroups+g)*dst_step + od * k; + + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::im2col( + jcp, _src, _col, 0, jcp.oh, 0, jcp.ow); + else + jit_gemm_convolution_utils::im2col_3d(jcp, _src, + _col, od); + } + + const data_t zero = 0.0, one = 1.0; + extended_sgemm( + "T", "N", &M, &N, &k, &one, + jcp.im2col_sz ? _col : _src + od * k, + &LDA, _diff_dst, &K, + mb == mb_start && od == 0 ? &zero : &one, + _diff_weights, &M); + } + } + } + if (need_reduction) { + mkldnn_thr_barrier(); + data_t *weights_base = diff_weights + g_start * weights_g_size; + jit_gemm_convolution_utils::bwd_weights_reduction_par( + ithr_mb, nthr_mb, jcp, weights_reduce_base, weights_base); + } + } else + if (need_reduction) { mkldnn_thr_barrier(); } + }); + + if (jcp.with_bias) { + parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) { + data_t db = 0; + size_t offset_ = (size_t)g * dst_step + (size_t)oc * K; + for (int mb = 0; mb < jcp.mb; ++mb) + { + size_t offset = offset_ + (size_t)mb * jcp.ngroups * dst_step; + for (int od = 0; od < jcp.od; ++od) + for (int oh = 0; oh < jcp.oh; ++oh) + PRAGMA_OMP_SIMD(reduction(+:db)) + for (int ow = 0; ow < jcp.ow; ++ow) { + db += diff_dst[offset]; + offset++; + } + } + diff_bias[g*jcp.oc+oc] = db; + }); + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp new file mode 100644 index 0000000000..302e46369a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp @@ -0,0 +1,250 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_GEMM_CONVOLUTION_HPP +#define CPU_JIT_GEMM_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "gemm_convolution_utils.hpp" +#include "gemm/gemm.hpp" +#include "ref_eltwise.hpp" + +#include "cpu_convolution_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct gemm_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && post_ops_ok() + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), weights_md(0), dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + + bool post_ops_ok() const { + auto const &po = attr()->post_ops_; + auto is_eltwise = [&](int idx) + { return po.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); }; + + switch (po.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + return false; + } + }; + + gemm_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + , eltwise_(nullptr) + { + const auto &post_ops = pd()->attr()->post_ops_; + const data_t one = 1.0, zero = 0.0; + beta_ = post_ops.find(primitive_kind::sum) >= 0 ? one : zero; + + const int entry_idx = post_ops.find(primitive_kind::eltwise); + if (entry_idx != -1) eltwise_ = new ref_eltwise_scalar_fwd_t( + post_ops.entry_[entry_idx].eltwise); + } + + ~gemm_convolution_fwd_t() { delete eltwise_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + data_t beta_; + + ref_eltwise_scalar_fwd_t* eltwise_; +}; + +struct gemm_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && memory_desc_matches_tag(*diff_src_md(), dat_tag()) + && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), diff_src_md(), weights_md(0), diff_dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + }; + + gemm_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct gemm_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*diff_weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), diff_weights_md(0), diff_dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + }; + + gemm_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp new file mode 100644 index 0000000000..f133b1e62b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp @@ -0,0 +1,771 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "cpu_isa_traits.hpp" + +#include "gemm_convolution_utils.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; +using namespace prop_kind; +using namespace data_type; + +namespace jit_gemm_convolution_utils { + +void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, + int od) +{ + const size_t OHW = jcp.oh * jcp.ow; + const size_t im_step = jcp.ih * jcp.iw * jcp.id; + const size_t col_step = jcp.ks * OHW; + + parallel_nd(jcp.ic, [&](int ic) { + const float *__restrict im_loc = im + ic * im_step; + float *__restrict col_loc = col + ic * col_step; + int id = od * jcp.stride_d - jcp.f_pad; + for (int kd = 0; kd < jcp.kd; ++kd) { + float *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW; + if (id < 0 || id >= jcp.id) { + int ih_ = -jcp.t_pad; + for (int kh = 0; kh < jcp.kh; ++kh) { + int ih = ih_; + for (int oh = 0; oh < jcp.oh; ++oh) { + if (ih < 0 || ih >= jcp.ih) { + ih += jcp.stride_h; + continue; + } + int iw_ = -jcp.l_pad; + for (int kw = 0; kw < jcp.kw; ++kw) { + int iw = iw_; + for (int ow = 0; ow < jcp.ow; ++ow) { + if (iw < 0 || iw >= jcp.iw) { + iw += jcp.stride_w; + continue; + } + + const size_t col_idx = kw * OHW + oh * jcp.ow + + ow; + + col_[col_idx] = 0; + iw += jcp.stride_w; + } + iw_ += (1 + jcp.dilate_w); + } + ih += jcp.stride_h; + } + ih_ += (1 + jcp.dilate_h); + col_ += jcp.kw * OHW; + } + } else { + const float *__restrict im_ = im_loc + id * jcp.ih * jcp.iw; + int ih_ = -jcp.t_pad; + for (int kh = 0; kh < jcp.kh; ++kh) { + int ih = ih_; + for (int oh = 0; oh < jcp.oh; ++oh) { + if (ih < 0 || ih >= jcp.ih) { + ih += jcp.stride_h; + continue; + } + int iw_ = -jcp.l_pad; + for (int kw = 0; kw < jcp.kw; ++kw) { + int iw = iw_; + for (int ow = 0; ow < jcp.ow; ++ow) { + if (iw < 0 || iw >= jcp.iw) { + iw += jcp.stride_w; + continue; + } + + const size_t col_idx = kw * OHW + oh * jcp.ow + + ow; + const size_t im_idx = ih * jcp.iw + iw; + + col_[col_idx] = im_[im_idx]; + iw += jcp.stride_w; + } + iw_ += (1 + jcp.dilate_w); + } + ih += jcp.stride_h; + } + ih_ += (1 + jcp.dilate_h); + col_ += jcp.kw * OHW; + } + } + id += (1 + jcp.dilate_d); + } + }); +} + +/* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */ +void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, + float *__restrict col, int hs, int hb, int ws, int wb) { + const size_t im_step = jcp.is; + const size_t col_step = jcp.ks * hb * wb; + if (jcp.stride_w == 1) { + // Generated code is more optimized for stride_w == 1 + // because innermost loop is by width + auto ker = [&](int ic, int kh, int kw, int oh) { + const float *__restrict im_ = im + ic * im_step; + float *__restrict col_ + = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb; + + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) { + for (int ow = 0; ow < wb; ++ow) + col_[ow] = 0.f; + } else { + for (int ow = 0; ow < wb; ++ow) { + const int iw = ow + ws - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) + col_[ow] = 0.f; + else { + const size_t im_idx = ih * jcp.iw + iw; + col_[ow] = im_[im_idx]; + } + } + } + }; + + if (jcp.outer_threading) { + for (int ic = 0; ic < jcp.ic; ic++) + for (int kh = 0; kh < jcp.kh; kh++) + for (int kw = 0; kw < jcp.kw; kw++) + for (int oh = 0; oh < hb; oh++) + ker(ic, kh, kw, oh); + } + else { + parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker); + } + } else if (jcp.ic == 1) { + parallel_nd(jcp.kh, hb, [&](int kh, int oh) { + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < wb; ++ow) { + const size_t col_idx + = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; + col[col_idx] = 0; + } + } + else + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < wb; ++ow) { + const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + const size_t col_idx + = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; + const size_t im_idx = ih * jcp.iw + iw; + if (iw < 0 || iw >= jcp.iw) + col[col_idx] = 0; + else + col[col_idx] = im[im_idx]; + } + } + }); + } else { + + parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, + [&](int ic, int kh, int kw, int oh) { + const float *__restrict im_ = im + ic * im_step; + float *__restrict col_ = col + ic * col_step + + ((kh * jcp.kw + kw) * hb + oh) * wb; + + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) { + for (int ow = 0; ow < wb; ++ow) + col_[ow] = 0.f; + } else { + for (int ow = 0; ow < wb; ++ow) { + const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + const size_t im_idx = ih * jcp.iw + iw; + if (iw < 0 || iw >= jcp.iw) + col_[ow] = 0.f; + else + col_[ow] = im_[im_idx]; + } + } + }); + } +} + +inline int limit(int low, int upper, int value) { + return nstl::max(low, nstl::min(upper, value)); +} + +/* col[kh][kw][ic][oh][ow] <-- im2col_u8(im[ih][iw][ic]) */ +template +void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, + T *__restrict imtr, uint8_t *__restrict col, int hs, int hb, int ws, + int wb) { + uint8_t shift = jcp.signed_input ? 128 : 0; + const int dh = 1 + jcp.dilate_h; + const int dw = 1 + jcp.dilate_w; + const int sh = jcp.stride_h; + const int sw = jcp.stride_w; + const int im_iw_stride = jcp.ic * jcp.ngroups; + const int im_ih_stride = jcp.iw * im_iw_stride; + const int tp = jcp.t_pad; + const int lp = jcp.l_pad; + + if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) { + /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */ + const int hp = hs - tp; + const int wp = ws - lp; + const int ih_start = limit(0, jcp.ih, hp); + const int ih_end = limit(0, jcp.ih, hp + hb + jcp.kh); + const int iw_start = limit(0, jcp.iw, wp); + const int iw_end = limit(0, jcp.iw, wp + wb + jcp.kw); + + const int ihb = ih_end - ih_start; + const int iwb = iw_end - iw_start; + + const int imtr_ic_stride = ihb * iwb; + const ptrdiff_t imtr_idx_shift = ih_start * iwb + iw_start; + for (int ic = 0; ic < jcp.ic; ic++) { + const ptrdiff_t imtr_idx_ic = ic * imtr_ic_stride - imtr_idx_shift; + for (int ih = ih_start; ih < ih_end; ih++) { + const ptrdiff_t im_idx_ih = ic + ih * im_ih_stride; + const ptrdiff_t imtr_idx_ih = imtr_idx_ic + ih * iwb; + for (int iw = iw_start; iw < iw_end; iw++) + imtr[imtr_idx_ih + iw] = im[im_idx_ih + iw * im_iw_stride]; + } + } + + const int col_ic_str = hb * wb; + const int col_kw_stride = jcp.ic * col_ic_str; + const int col_kh_stride = jcp.kw * col_kw_stride; + + const int oh_init = ih_start - hp; + const int ow_init = iw_start - wp; + for (int kh = 0; kh < jcp.kh; kh++) { + const ptrdiff_t col_idx_kh = kh * col_kh_stride; + const int oh_kh = oh_init - kh; + const int oh_start = limit(0, hb, oh_kh); + const int oh_end = limit(0, hb, oh_kh + ihb); + for (int kw = 0; kw < jcp.kw; kw++) { + const ptrdiff_t col_idx_kw + = col_idx_kh + kw * jcp.ic * col_ic_str; + const int ow_kw = ow_init - kw; + const int imtr_shift = oh_kh * iwb + ow_kw; + const int ow_start = limit(0, wb, ow_kw); + const int ow_end = limit(0, wb, ow_kw + iwb); + for (int ic = 0; ic < jcp.ic; ic++) { + const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str; + const int imtr_idx_ic = ic * imtr_ic_stride - imtr_shift; + for (int oh = 0; oh < oh_start; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + for (int ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + for (int oh = oh_start; oh < oh_end; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb; + for (int ow = 0; ow < ow_start; ++ow) + col[col_idx_oh + ow] = shift; + for (int ow = ow_start; ow < ow_end; ++ow) + col[col_idx_oh + ow] + = imtr[imtr_idx_oh + ow] + shift; + for (int ow = ow_end; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + for (int oh = oh_end; oh < hb; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + for (int ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + } + } + } + } else { + parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb, + [&](int kh, int kw, int ic, int oh) { + const int hp = tp - kh * dh; + const int ih = (oh + hs) * sh - hp; + const ptrdiff_t col_idx_base + = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh) * wb; + if (ih < 0 || ih >= jcp.ih) + for (int ow = 0; ow < wb; ow++) + col[col_idx_base + ow] = shift; + else { + const int wp = lp - kw * dw; + const int ow_start = limit(0, wb, div_up(wp, sw) - ws); + const int ow_end + = limit(0, wb, div_up(jcp.iw + wp, sw) - ws); + for (int ow = 0; ow < ow_start; ow++) + col[col_idx_base + ow] = shift; + const int iw_base = ws * sw - wp; + const ptrdiff_t im_idx_base = ih * im_ih_stride + ic; + for (int ow = ow_start; ow < ow_end; ow++) { + const int iw = iw_base + ow * sw; + const ptrdiff_t im_idx + = im_idx_base + iw * im_iw_stride; + col[col_idx_base + ow] = im[im_idx] + shift; + } + for (int ow = ow_end; ow < wb; ow++) + col[col_idx_base + ow] = shift; + } + }); + } +} + +template void im2col_u8(const jit_gemm_conv_conf_t &jcp, + const int8_t *__restrict im, int8_t *__restrict imtr, + uint8_t *__restrict col, int hs, int hb, int ws, int wb); +template void im2col_u8(const jit_gemm_conv_conf_t &jcp, + const uint8_t *__restrict im, uint8_t *__restrict imtr, + uint8_t *__restrict col, int hs, int hb, int ws, int wb); + +/* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */ +void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col, + int32_t *__restrict im) +{ + parallel(0, [&](const int ithr, const int nthr) { + int h_nthr = nstl::min(jcp.ih, nthr); + int w_nthr = nstl::min(jcp.iw, nthr / h_nthr); + int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0; + if (ithr < h_nthr * w_nthr) { + h_ithr = ithr / w_nthr; + w_ithr = ithr % w_nthr; + balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e); + balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e); + } else { + h_ithr = w_ithr = -ithr; + h_s = h_e = w_s = w_e = -1; + } + + for (int ih = h_s; ih < h_e; ++ih) { + for (int iw = w_s; iw < w_e; ++iw) { + PRAGMA_OMP_SIMD() + for (int ic = 0; ic < jcp.ic; ++ic) { + im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0; + } + } + } + + // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh] + for (int oh = 0; oh < jcp.oh; ++oh) { + for (int ow = 0; ow < jcp.ow; ++ow) { + for (int kh = 0; kh < jcp.kh; ++kh) { + const int ih = oh * jcp.stride_h + - jcp.t_pad + kh * (1 + jcp.dilate_h); + if (ih < h_s || ih >= h_e) continue; + + for (int kw = 0; kw < jcp.kw; ++kw) { + const int iw = ow * jcp.stride_w + - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < w_s || iw >= w_e) continue; + + const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh + + kh) * jcp.kw + kw) * jcp.ic; + const size_t im_idx + = (ih * jcp.iw + iw) * jcp.ic; + PRAGMA_OMP_SIMD() + for (int ic = 0; ic < jcp.ic; ++ic) { + im[im_idx + ic] += col[col_idx + ic]; + } + } + } + } + } + }); +} + +void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im, + int od) +{ + parallel_nd(jcp.ic, [&](int ic) { + const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os; + float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id; + + int id = od * jcp.stride_d - jcp.f_pad; + for (int kd = 0; kd < jcp.kd; ++kd) { + if (id < 0 || id >= jcp.id) { + col_ += jcp.kh * jcp.kw * jcp.os; + id += (1 + jcp.dilate_d); + continue; + } + + float *__restrict im_ = im_ic + id * jcp.ih * jcp.iw; + + for (int oh = 0; oh < jcp.oh; ++oh) { + for (int kh = 0; kh < jcp.kh; ++kh) { + const int ih = oh * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) continue; + + for (int ow = 0; ow < jcp.ow; ++ow) { + for (int kw = 0; kw < jcp.kw; ++kw) { + const int iw = ow * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) continue; + + const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow; + const size_t im_idx = ih*jcp.iw + iw; + im_[im_idx] += col_[col_idx]; + }} + }} + + col_ += jcp.kh * jcp.kw * jcp.os; + id += (1 + jcp.dilate_d); + } + }); +} + +void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im) { + const size_t col_step = jcp.ks * jcp.os; + const size_t im_step = jcp.ih * jcp.iw; + const int iS = jcp.ih * jcp.iw; + + parallel_nd(jcp.ic, [&](int ic) { + float *__restrict im_ = im + ic * im_step; + const float *__restrict col_ = col + ic * col_step; + PRAGMA_OMP_SIMD() + for (int is = 0; is < iS; ++is) im_[is] = 0.; + + for (int kh = 0; kh < jcp.kh; ++kh) { + for (int oh = 0; oh < jcp.oh; ++oh) { + const int ih = + oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) continue; + + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < jcp.ow; ++ow) { + const int iw = + ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) continue; + + const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow; + const size_t im_idx = ih*jcp.iw + iw; + im_[im_idx] += col_[col_idx]; + } + } + } + } + }); +} + +status_t init_conf(jit_gemm_conv_conf_t &jcp, + memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, int max_threads) { + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + const int is_1d = ndims == 3; + const int is_3d = ndims == 5; + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.id = is_3d ? src_d.dims()[2] : 1; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.od = is_3d ? dst_d.dims()[2] : 1; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = is_3d ? cd.padding[0][0] : 0; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_d = is_3d ? cd.strides[0] : 1; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.dilate_d = is_3d ? cd.dilates[0] : 0; + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef + || cd.diff_bias_desc.format_kind != format_kind::undef; + + jcp.is = jcp.ih * jcp.iw; + jcp.os = jcp.oh * jcp.ow; + jcp.ks = jcp.kh * jcp.kw * jcp.kd; + + jcp.signed_input = src_d.data_type() == data_type::s8; + + jcp.im2col_sz = !everyone_is(true, + jcp.ow == jcp.iw, jcp.oh == jcp.ih, jcp.od == jcp.id, + jcp.stride_w == 1, jcp.stride_h == 1, jcp.stride_d == 1, + jcp.ks == 1, !jcp.signed_input) + ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os : 0; + + jcp.outer_threading = false; + + bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8) + && weights_d.data_type() == s8; + + const int vlen = mayiuse(avx512_common) + ? cpu_isa_traits::vlen + : mayiuse(avx) + ? cpu_isa_traits::vlen + : mayiuse(sse42) ? cpu_isa_traits::vlen : 4; + const int simd_w = vlen / (is_int8_conv ? 1 : 4); + + const bool is_bwd_d = jcp.prop_kind == backward_data; + const bool is_bwd_w = jcp.prop_kind == backward_weights; + const bool is_fwd = !is_bwd_d && !is_bwd_w; + jcp.oh_block = is_fwd ? jcp.oh : jcp.ih; + jcp.ow_block = is_fwd ? jcp.ow : jcp.iw; + + using namespace memory_tracking::names; + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + + // TODO: maybe mitigate blocking restriction + const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw; + const int L2 = get_cache_size(2, true) + / (is_int8_conv ? sizeof(int8_t) : sizeof(float)); + bool is_blocking_applicable = true + && is_fwd && jcp.im2col_sz + && jcp.id == 1 && jcp.od == 1 + && jcp.dilate_h == 0 && jcp.dilate_w == 0 + && !is_depthwise + && wei_size < L2/2; + if (is_blocking_applicable) { + // looking for oh and ow blocking + int h_block{ jcp.oh_block }, w_block{ jcp.ow_block }; + const int ic = jcp.ic; + const int oc = jcp.oc; + const int iw = jcp.iw; + const int ow = jcp.ow; + const int oh = jcp.oh; + const int os = oh * ow; + + // 1. cache requirement + int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow); + if (is_int8_conv) { + // Heuristic rule: gemm needed a lot of memory for internal usage + row_size *= 5; + // memory for accumulators + row_size += oc * ow * sizeof(uint32_t); + // memory for transposition + row_size += ic * iw; + } + + h_block = nstl::max(1, nstl::min(oh, div_up(L2, row_size))); + if (h_block == 1) { + int col_size = ic * jcp.ks + 2 * (ic + oc); + if (is_int8_conv) { + col_size *= 5; + col_size += oc * sizeof(uint32_t); + col_size += ic; + } + w_block = nstl::max(1, nstl::min(ow, div_up(L2, col_size))); + } + + // 2. threading requirement + if (h_block != oh) + h_block = nstl::max(1, rnd_dn(h_block, 4)); + if (w_block != ow) + w_block = nstl::max(1, rnd_dn(w_block, simd_w)); + + float thr_eff = 0.f; + float thr_eff_treshold = 0.9f; + if (w_block == ow) { + do { + int nb_h = div_up(oh, h_block); + size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h; + float disb = (float)oh / rnd_up(oh, h_block); + thr_eff = (float)work / rnd_up(work, max_threads); + thr_eff = (thr_eff + disb) / 2.f; + if (thr_eff >= thr_eff_treshold) + break; + h_block = rnd_dn(h_block - 4, 4); + } while (h_block > 0); + } + if (thr_eff < thr_eff_treshold) // we didn't find suitable h_block + { + h_block = 1; + int nb_h = oh; + do { + int nb_w = div_up(ow, w_block); + size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w; + float disb = (float)ow / rnd_up(ow, w_block); + thr_eff = (float)work_amount / rnd_up(work_amount, max_threads); + thr_eff = (thr_eff + disb) / 2.f; + if (thr_eff > thr_eff_treshold) + break; + w_block = rnd_dn(w_block - simd_w, simd_w); + } while (w_block > 0); + } + h_block = nstl::max(1, h_block); + w_block = nstl::max(1, w_block); + const size_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) { + jcp.oh_block = h_block; + jcp.ow_block = w_block; + jcp.outer_threading = true; + } + // updating jcp.im2col_sz + if (jcp.oh_block != 1) + jcp.ow_block = ow; + jcp.im2col_sz = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block; + } + // For threading selection in bwd_d we do: + // 1. Rough estimation of efficiency for inner and outer threading. + // 2. Gemm size estimation in assumption that it does not work + // so effectively for small sizes. + // 64K - this is heuristic gemm size per thread threshold. + const int gemm_thrld = 64 * 1024; + + if (is_int8_conv) { + if (is_fwd) { + if (!jcp.outer_threading) { + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + const size_t outer_work = jcp.ngroups * jcp.mb; + const float outer_thr_eff + = (float)outer_work / rnd_up(outer_work, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + jcp.outer_threading = (is_depthwise + || (jcp.is / max_threads < 64 && jcp.mb != 1)) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(int8_t) * jcp.nthr * jcp.im2col_sz); + scratchpad.book(key_conv_int_dat_in_acc_dt, + sizeof(int32_t) * jcp.nthr * jcp.oh_block * jcp.ow_block * jcp.oc); + scratchpad.book(key_conv_gemm_imtr, + sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic); + } else if (is_bwd_d) { + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + const size_t outer_work = jcp.ngroups * jcp.mb; + const float outer_thr_eff + = (float)outer_work / rnd_up(outer_work, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + jcp.outer_threading = (is_depthwise + || (jcp.is / max_threads < 64 && jcp.mb != 1)) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(int32_t) * jcp.nthr * jcp.im2col_sz); + scratchpad.book(key_conv_int_dat_in_acc_dt, + sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic); + } else if (is_bwd_w) { + assert(!"unimplemented prop_kind"); + return status::unimplemented; + } + } else { + if (is_fwd) { + if (!jcp.outer_threading) { + const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od; + const float outer_thr_eff = (float)outer_work_amount + / rnd_up(outer_work_amount, max_threads); + const size_t inner_work_amount + = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w); + const float inner_thr_eff = (float)inner_work_amount + / rnd_up(inner_work_amount, max_threads); + jcp.outer_threading = jcp.os / max_threads < 512 + && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } + } else if (is_bwd_d) { + const size_t outer_work_amount = jcp.ngroups * jcp.mb; + const float outer_thr_eff = (float)outer_work_amount + / rnd_up(outer_work_amount, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff = (float)inner_work + / rnd_up(inner_work, max_threads); + jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64) + && (jcp.mb != 1 || jcp.ngroups > 2) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } else if (is_bwd_w) + jcp.outer_threading = jcp.os / max_threads < 256 + && (jcp.mb != 1 || jcp.ngroups > 2); + + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(float) * jcp.nthr * jcp.im2col_sz); + + if (is_bwd_w) { + jcp.need_wei_reduction = mkldnn_thr_syncable() + ? jcp.mb != 1 && jcp.nthr != 1 : false; + scratchpad.book(key_conv_wei_reduction, + sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size()); + } + } + + return status::success; +} + +void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g, + int &nthr_g, int &ithr_mb, int &nthr_mb) { + nthr_g = nstl::min(ngroups, nthr); + nthr_mb = nstl::min(mb, nthr / nthr_g); + if (ithr / nthr_mb >= ngroups) { + ithr_g = ithr_mb = -1; + } else { + ithr_g = ithr / nthr_mb; + ithr_mb = ithr % nthr_mb; + } +} + +void bwd_weights_reduction_par(int ithr, int nthr, + const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws, + float *weights) { + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + size_t weights_start{0}, weights_end{0}; + balance211(weights_g_size, nthr, ithr, weights_start, weights_end); + + for (int i = 0; i < nthr; ++i) { + const float *ws_i = weights_reduce_ws + i * weights_g_size; + for (size_t s = weights_start; s < weights_end; ++s) + weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s]; + } +} + +}; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp new file mode 100644 index 0000000000..e006789344 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp @@ -0,0 +1,66 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP +#define CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_engine.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace jit_gemm_convolution_utils { + +void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, + int od); +void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, + float *__restrict col, int hs, int hb, int ws, int wb); +template +void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, + T* __restrict imtr, uint8_t *__restrict col, + int hs, int hb, int ws, int wb); + +void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col, + int32_t *__restrict im); +void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im, + int od); +void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im); + +status_t init_conf(jit_gemm_conv_conf_t &jcp, + memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, int max_threads); + +void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, + int &ithr_g, int &nthr_g, int &ithr_mb, int &nthr_mb); +void bwd_weights_reduction_par(int ithr, int nthr, + const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws, + float *weights); + +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp new file mode 100644 index 0000000000..2872122f0d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp @@ -0,0 +1,156 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" + +#include "gemm_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::primitive_kind; + +template +void gemm_inner_product_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = !memory_desc_matches_one_of_tag( + *pd()->weights_md(), hwio, dhwio, io); + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + + float alpha = 1.0, beta = 0.0; + extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &MB, &IC, &alpha, weights, + wei_tr ? &IC : &OC, src, &IC, &beta, dst, &OC, bias); + + if (do_relu) { + float nslope = post_ops.entry_[0].eltwise.alpha; + parallel_nd(MB, OC, [&](int mb, int oc) { + size_t dst_off = mb * OC + oc; + if (dst[dst_off] < 0) + dst[dst_off] *= nslope; + }); + } +} + +template +void gemm_inner_product_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->weights_md(), hwio, dhwio, io); + + float alpha = 1.0, beta = 0.0; + extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha, weights, + wei_tr ? &OC : &IC, diff_dst, &OC, &beta, diff_src, &IC); +} + +template +void gemm_inner_product_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + diff_dst += diff_dst_d.offset0(); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->diff_weights_md(), hwio, dhwio, io); + + float alpha = 1.0, beta = 0.0; + if (wei_tr) + extended_sgemm("N", "T", &OC, &IC, &MB, &alpha, diff_dst, &OC, src, &IC, + &beta, diff_weights, &OC); + else + extended_sgemm("N", "T", &IC, &OC, &MB, &alpha, src, &IC, diff_dst, &OC, + &beta, diff_weights, &IC); + + if (diff_bias) { + diff_bias += diff_bias_d.offset0(); + constexpr int blksize = 8; + const int OC_blocks = OC / blksize; + const int rem_OC = OC % blksize; + parallel(0, [&](const int ithr, const int nthr) { + int oc_st{0}, oc_e{0}; + balance211(OC_blocks, nthr, ithr, oc_st, oc_e); + oc_st = oc_st * blksize; + oc_e = oc_e * blksize; + + PRAGMA_OMP_SIMD() + for (int oc = oc_st; oc < oc_e; ++oc) { + diff_bias[oc] = diff_dst[oc]; + } + + for (int mb = 1; mb < MB; ++mb) { + PRAGMA_OMP_SIMD() + for (int oc = oc_st; oc < oc_e; ++oc) { + diff_bias[oc] += diff_dst[mb * OC + oc]; + } + } + + if (rem_OC != 0 && ithr == nthr-1) { + for (int oc = OC_blocks * blksize; oc < OC; oc++) + diff_bias[oc] = diff_dst[oc]; + for (int mb = 1; mb < MB; ++mb) { + for (int oc = OC_blocks * blksize; oc < OC; oc++) { + diff_bias[oc] += diff_dst[mb * OC + oc]; + } + } + } + }); + } +} + +template struct gemm_inner_product_fwd_t; +template struct gemm_inner_product_bwd_data_t; +template struct gemm_inner_product_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp new file mode 100644 index 0000000000..acf0a49b9a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp @@ -0,0 +1,157 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_GEMM_INNER_PRODUCT_HPP +#define CPU_GEMM_INNER_PRODUCT_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "gemm/gemm.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct gemm_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_fwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type, + src_md()->data_type, + weights_md()->data_type, + dst_md()->data_type, + with_bias() ? weights_md(1)->data_type : data_type) + && attr()->output_scales_.has_default_values() + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_ == 1, + attr()->post_ops_.entry_[0].is_relu(true, false)) + && dense_gemm_consitency_check(src_md(), weights_md(), + dst_md()); + return ok ? status::success : status::unimplemented; + } + }; + + gemm_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct gemm_inner_product_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_data_pd_t { + using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_data_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_data + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + diff_src_md()->data_type, + weights_md()->data_type, + diff_dst_md()->data_type) + && attr()->has_default_values() + && dense_gemm_consitency_check(diff_src_md(), weights_md(), + diff_dst_md()); + return ok ? status::success : status::unimplemented; + } + }; + + gemm_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct gemm_inner_product_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_weights_pd_t { + using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_weights_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_weights + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + src_md()->data_type, + diff_weights_md()->data_type, + diff_dst_md()->data_type, + with_bias() ? diff_weights_md(1)->data_type : data_type) + && attr()->has_default_values() + && dense_gemm_consitency_check(src_md(), diff_weights_md(), + diff_dst_md()); + + return ok ? status::success : status::unimplemented; + } + }; + + gemm_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp new file mode 100644 index 0000000000..fed7e4d693 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp @@ -0,0 +1,740 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "math_utils.hpp" + +#include "simple_q10n.hpp" + +#include "gemm/gemm.hpp" +#include "gemm_x8s8s32x_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace mkldnn::impl::memory_tracking::names; + +template +void _gemm_x8s8s32x_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src_base = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst_base = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + assert(IMPLICATION( + jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow)); + assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base, + scratchpad); + }); +} + +template +_gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::pp_ker_t( + const pd_t *pd) + : ker_(nullptr) + , jcp_(pd->jcp_) + , OC_(pd->jcp_.oc) + , OS_(pd->jcp_.os) + , bias_data_type_(data_type::undef) + , bias_data_type_size_(0) + , scale_idx_mult_(0) + , do_bias_(false) + , do_relu_(false) + , do_sum_(false) +{ + using namespace types; + + const auto dst_md = memory_desc_wrapper(pd->dst_md()); + dst_os_stride_ = dst_md.blk_off(0, 0, 0, 1); + + scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1)); + + auto &post_ops = pd->attr()->post_ops_; + + int entry_idx = -1; + for (int idx = 0; idx < post_ops.len_; ++idx) { + const auto &e = post_ops.entry_[idx]; + if (e.is_relu(true, false)) { + entry_idx = idx; + break; + } + } + do_relu_ = entry_idx >= 0; + + do_signed_scaling_ = jcp_.signed_input; + + do_sum_ = post_ops.contain(primitive_kind::sum, 0); + do_bias_ = pd->with_bias(); + bias_data_type_ = pd->desc()->bias_desc.data_type; + if (do_bias_) { + assert(bias_data_type_ != data_type::undef); + bias_data_type_size_ = data_type_size(bias_data_type_); + } + const size_t vlen_start + = cpu_isa_traits::vlen / sizeof(float); + + for (size_t i = vlen_start; i > 0; i--) { + if (OC_ % i == 0) { + vlen_ = i; + break; + } + } + + if (!mayiuse(avx512_core)) + // use fallback code for older CPUs + return; + else + generate(); +} + +template +void _gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::generate() +{ + using namespace Xbyak; + using namespace utils; + + // TODO: clean-up + Reg64 reg_param = abi_param1; + Reg64 reg_dst = rdx; + Reg64 reg_acc = rax; + Reg64 reg_bias = rbx; + Reg64 reg_scales = rsi; + + Reg64 reg_len = r8; + Reg64 reg_tmp = rcx; // intentional for shifting purposes + Reg64 reg_oc_offset = r9; + Reg64 reg_rem_mask_short = r10; + Reg64 reg_rem_mask_vlen = r11; + Opmask kreg_rem_mask_short = k1; + Opmask kreg_rem_mask_vlen = k3; + Opmask kreg_relu_cmp = k2; + + const size_t vlen = vlen_; + + Zmm vreg_zero = Zmm(0); + Zmm vreg_scale = Zmm(1); + Zmm vreg_nslope = Zmm(2); + Zmm vreg_sum_scale = Zmm(3); + Zmm vreg_signed_scale = Zmm(4); + + size_t def_unroll = 4; + size_t max_unroll = 12; + size_t zmm_step = 2; + if (do_sum_) { + max_unroll = 8; + zmm_step = 3; + } + + auto vreg_dst = [&](int idx) { + return Zmm(5 + idx * zmm_step + 0); + }; + auto vreg_bias = [&](int idx) { + return Zmm(5 + idx * zmm_step + 1); + }; + auto vreg_prev_dst = [&](int idx) { + return Zmm(5 + idx * zmm_step + 2); + }; + + preamble(); + +#define PARAM_OFF(x) offsetof(ker_args, x) + mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); + mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]); + mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); + mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]); + mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); + mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); + vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]); + vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]); + vbroadcastss(vreg_signed_scale, ptr[reg_param + PARAM_OFF(signed_scale)]); + if (scale_idx_mult_ == 0) + vbroadcastss(vreg_scale, dword[reg_scales]); + +#undef PARAM_OFF + + mov(reg_rem_mask_vlen, 1); + shl(reg_rem_mask_vlen, vlen); + sub(reg_rem_mask_vlen, 1); + kmovq(kreg_rem_mask_vlen, reg_rem_mask_vlen); + + if (do_relu_ || dst_type == data_type::u8) + vxorps(vreg_zero, vreg_zero, vreg_zero); + + // Load accumulated value, convert to float, apply sum (if any), + // bias (if any), scaling, and relu (if any); + // then convert to destination type and store + auto compute = [&](size_t offset, int idx, bool apply_mask) { + auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)]; + + if (scale_idx_mult_ > 0) { + assert(scale_idx_mult_ == 1); + auto scale_addr = ptr[reg_scales + offset * sizeof(float)]; + auto vreg_scale_ = vreg_scale; + if (apply_mask) + vreg_scale_ = vreg_scale_ | kreg_rem_mask_short; + else + vreg_scale_ = vreg_scale_ | kreg_rem_mask_vlen; + vmovups(vreg_scale_, scale_addr); + } + + auto vreg_dst_ = vreg_dst(idx); + if (apply_mask) + vreg_dst_ = vreg_dst_ | kreg_rem_mask_short; + else + vreg_dst_ = vreg_dst_ | kreg_rem_mask_vlen; + vcvtdq2ps(vreg_dst_, acc_addr); + + if (do_signed_scaling_) + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_signed_scale); + + if (do_bias_) { + auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_]; + auto vreg_bias_ = vreg_bias(idx); + if (apply_mask) + vreg_bias_ = vreg_bias_ | kreg_rem_mask_short; + else + vreg_bias_ = vreg_bias_ | kreg_rem_mask_vlen; + + switch (bias_data_type_) { + case data_type::s8: + vpmovsxbd(vreg_bias_, bias_addr); + break; + case data_type::u8: + vpmovzxbd(vreg_bias_, bias_addr); + break; + case data_type::s32: + case data_type::f32: + vmovups(vreg_bias_, bias_addr); + break; + default: assert(!"unimplemented"); + } + if (bias_data_type_ != data_type::f32) + vcvtdq2ps(vreg_bias(idx), vreg_bias(idx)); + vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx)); + } + + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale); + + auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)]; + + if (do_sum_) + { + auto vreg_prev_dst_ = vreg_prev_dst(idx); + if (apply_mask) + vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_short; + else + vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_vlen; + + switch (dst_type) { + case data_type::f32: + case data_type::s32: vmovups(vreg_prev_dst_, dst_addr); break; + case data_type::s8: vpmovsxbd(vreg_prev_dst_, dst_addr); break; + case data_type::u8: vpmovzxbd(vreg_prev_dst_, dst_addr); break; + default: assert(!"unsupported data type"); + } + if (dst_type != data_type::f32) + vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx)); + + vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale); + } + + if (do_relu_) { + vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os); + vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope); + } + + if (dst_type != data_type::f32) { + vcvtps2dq(vreg_dst(idx), vreg_dst(idx)); + } + + if (dst_type == data_type::u8) + vpmaxsd(vreg_dst(idx), vreg_dst(idx), vreg_zero); + + switch (dst_type) { + case data_type::s8: + vpmovsdb(dst_addr, vreg_dst_); + break; + case data_type::u8: + vpmovusdb(dst_addr, vreg_dst_); + break; + case data_type::f32: + case data_type::s32: + vmovups(dst_addr, vreg_dst_); + break; + default: assert(!"unimplemented"); + } + }; + + // Advance all pointers by an immediate + auto advance_ptrs_imm = [&](size_t offset) { + add(reg_dst, offset * sizeof(dst_data_t)); + add(reg_acc, offset * sizeof(acc_data_t)); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + add(reg_scales, offset * sizeof(float)); + } + if (do_bias_) + add(reg_bias, offset * bias_data_type_size_); + }; + + // Advance all pointers by a value stored in a register + auto advance_ptrs_reg = [&](Reg64 offset) { + lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]); + lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); + } + if (do_bias_) + lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]); + }; + + // Rewind pointers that point to data that is indexed by output channel + // (bias or per-oc scaling factors) + auto rewind_ptrs = [&]() { + if (do_bias_) + sub(reg_bias, OC_ * bias_data_type_size_); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + sub(reg_scales, OC_ * sizeof(float)); + } + add(reg_dst, (dst_os_stride_ - OC_) * sizeof(dst_data_t)); + }; + + // <--------- OC ---------------> + // + // ^ ................+..............+-------------+....................... + // | . : not accessed |Prologue loop| . + // | . +--------------+-------------+ . + // . | | . + // O . | Main loop (unrolled) | . + // S . | | . + // . +--------------+-------------+ . + // | . | Epilogue loop|not accessed : . + // v ................+--------------+.............+....................... + + Label prologue_end; + cmp(reg_oc_offset, 0); + je(prologue_end, T_NEAR); + + // Prologue loop + { + mov(reg_tmp, OC_); + sub(reg_tmp, reg_oc_offset); + cmp(reg_tmp, reg_len); + cmovg(reg_tmp, reg_len); + sub(reg_len, reg_tmp); + + Label prologue_loop, prologue_loop_tail, prologue_loop_end; + cmp(reg_tmp, vlen); + jle(prologue_loop_tail, T_NEAR); + L(prologue_loop); { + compute(0, 0, false); + advance_ptrs_imm(vlen); + sub(reg_tmp, vlen); + cmp(reg_tmp, vlen); + jge(prologue_loop, T_NEAR); + } + + L(prologue_loop_tail); + mov(reg_rem_mask_short, 1); + // cl == reg_tmp because reg_tmp <= vlen here + shl(reg_rem_mask_short, cl); + sub(reg_rem_mask_short, 1); + jz(prologue_loop_end, T_NEAR); + + kmovq(kreg_rem_mask_short, reg_rem_mask_short); + compute(0, 0, true); + advance_ptrs_reg(reg_tmp); + + L(prologue_loop_end); + rewind_ptrs(); + } + L(prologue_end); + + // Main loop + Label main_loop_end; + { + cmp(reg_len, OC_); + jle(main_loop_end, T_NEAR); + + Label main_loop; + L(main_loop); { + size_t OC_loop, OC_tail; + if (OC_ < max_unroll * vlen) { + // Fully unroll small loops + OC_loop = 0; + OC_tail = OC_; + } + else { + OC_loop = vlen * def_unroll; + OC_tail = OC_ % OC_loop; + } + + assert(!!OC_loop || !!OC_tail); + + if (OC_tail % vlen) { + int vlen_tail = OC_tail % vlen; + unsigned tail_mask = (1 << vlen_tail) - 1; + mov(reg_tmp, tail_mask); + kmovq(kreg_rem_mask_short, reg_tmp); + } + + if (OC_loop) { + mov(reg_tmp, rnd_dn(OC_, OC_loop)); + Label oc_loop; + L(oc_loop); { + for (size_t offset = 0; offset < OC_loop; offset += vlen) + compute(offset, offset / vlen, false); + advance_ptrs_imm(OC_loop); + sub(reg_tmp, OC_loop); + jnz(oc_loop); + } + } + + if (OC_tail) { + for (size_t offset = 0; offset < OC_tail; offset += vlen) { + bool use_mask = (offset + vlen) > OC_tail; + compute(offset, offset / vlen, use_mask); + } + advance_ptrs_imm(OC_tail); + } + + rewind_ptrs(); + sub(reg_len, OC_); + cmp(reg_len, OC_); + jge(main_loop, T_NEAR); + } + } + L(main_loop_end); + + // Epilogue loop + Label epilogue_end; + { + cmp(reg_len, 0); + je(epilogue_end, T_NEAR); + + Label epilogue_loop, epilogue_loop_tail; + cmp(reg_len, vlen); + jle(epilogue_loop_tail, T_NEAR); + L(epilogue_loop); { + compute(0, 0, false); + sub(reg_len, vlen); + advance_ptrs_imm(vlen); + cmp(reg_len, vlen); + jge(epilogue_loop, T_NEAR); + } + + L(epilogue_loop_tail); + mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift + mov(reg_rem_mask_short, 1); + shl(reg_rem_mask_short, cl); // reg_tmp == rcx and reg_tail < vlen + sub(reg_rem_mask_short, 1); + jz(epilogue_end, T_NEAR); + kmovq(kreg_rem_mask_short, reg_rem_mask_short); + compute(0, 0, true); + } + + L(epilogue_end); + + postamble(); + + ker_ = getCode(); +} + +template +void _gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::operator () + (dst_data_t *dst, const acc_data_t *acc, const char *bias, + const float *scales, float nslope, float sum_scale, float signed_scale, + int g, size_t start, size_t end) +{ + using math::get_bias; + + if (end <= start) + return; + + if (ker_) { + // JIT + ker_args args; + size_t oc_offset = start % OC_; + size_t os_offset = start / OC_; + args.acc = acc + start; + args.dst = dst + os_offset * dst_os_stride_ + oc_offset; + args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_; + args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset); + args.nslope = nslope; + args.sum_scale = sum_scale; + args.signed_scale = signed_scale; + args.len = end - start; + args.oc_offset = oc_offset; + ker_(&args); + } + else { + // Fallback + const size_t first_oc = start % OC_; + const size_t last_oc = (end - 1) % OC_; + const size_t first_os = start / OC_; + const size_t last_os = (end - 1) / OC_; + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * dst_os_stride_ + oc; + + float d = (float)(acc[acc_off]); + if (jcp_.signed_input) + d *= signed_scale; + + if (do_bias_) + d += get_bias(bias, g * jcp_.oc + oc, + bias_data_type_); + + d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_]; + if (do_sum_) + d += sum_scale * dst[dst_off]; + if (do_relu_ && d < 0) + d *= nslope; + dst[dst_off] = qz_a1b0()(d); + } + } + } +}; + +template +void _gemm_x8s8s32x_convolution_fwd_t:: +execute_forward_thr(const int ithr, const int nthr, const src_data_t *src_base, + const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base, + const memory_tracking::grantor_t &scratchpad) const { + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const auto src_md = memory_desc_wrapper(pd()->src_md()); + const size_t src_mb_stride = src_md.blk_off(1); + const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic; + + const auto wei_md = memory_desc_wrapper(pd()->weights_md(0)); + const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0; + + const auto dst_md = memory_desc_wrapper(pd()->dst_md()); + const size_t dst_mb_stride = dst_md.blk_off(1); + const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc; + + const float *scales = pd()->attr()->output_scales_.scales_; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_sum = post_ops.contain(primitive_kind::sum, 0); + const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0; + + float nslope = 0; + for (int idx = 0; idx < post_ops.len_; ++idx) { + const auto &e = post_ops.entry_[idx]; + if (e.is_relu(true, false)) { + nslope = e.eltwise.alpha; + break; + } + } + + auto col = scratchpad.get(key_conv_gemm_col) + + (ptrdiff_t)ithr * jcp.im2col_sz; + src_data_t *__restrict imtr = scratchpad.get(key_conv_gemm_imtr) + + (ptrdiff_t)ithr * jcp.is * jcp.ic; + auto acc = scratchpad.get(key_conv_int_dat_in_acc_dt) + + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc; + + const ptrdiff_t offset = (ptrdiff_t)jcp.ngroups * jcp.ks * jcp.ic * jcp.oc; + const int32_t *_wei_comp = (const int32_t *)(wei_base + offset); + + int g{ 0 }, n{ 0 }, ohb{ 0 }, owb{ 0 }; + size_t start = 0, end = 0; + + const int nb_oh = div_up(jcp.oh, jcp.oh_block); + const int nb_ow = div_up(jcp.ow, jcp.ow_block); + const size_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, + nb_oh, owb, nb_ow); + + for (size_t iwork = start; iwork < end; ++iwork) { + int oh = ohb * jcp.oh_block; + int ow = owb * jcp.ow_block; + const src_data_t *__restrict src = src_base + n * src_mb_stride + + g * src_g_stride; + const wei_data_t *__restrict wei = wei_base + g * wei_g_stride; + dst_data_t *__restrict dst = + dst_base + n * dst_mb_stride + g * dst_g_stride; + const int32_t *wei_comp = _wei_comp + g * jcp.oc; + const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh); + const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow); + + if (jcp.im2col_sz) + jit_gemm_convolution_utils::im2col_u8( + jcp, src, imtr, col, oh, h_step, ow, w_step); + + const int M = jcp.oc; + const int K = jcp.ks * jcp.ic; + const int N = h_step * w_step; + const int LDA = M * jcp.ngroups; + const int LDB = jcp.im2col_sz ? N : K; + const char *BT = jcp.im2col_sz ? "T" : "N"; + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + const float onef = 1.0, zerof = 0.0; + gemm_s8x8s32("N", BT, jcp.signed_input ? "C" : "F", + &M, &N, &K, &onef, wei, &LDA, &off_a, + jcp.im2col_sz ? col : (uint8_t *)src, &LDB, &off_b, + &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c); + + auto wei_adj_scale = + (wei_md.extra().flags | memory_extra_flags::scale_adjust) + ? wei_md.extra().scale_adjust : 1.f; + + parallel(0, [&](int ithr, int nthr) { + size_t start, end; + balance211((size_t)N * jcp.oc, nthr, ithr, start, end); + (*pp_ker_)(dst + (oh * jcp.ow + ow) * pp_ker_->dst_os_stride_, + acc, bia_base, scales, nslope, sum_scale, + 1.f / wei_adj_scale, g, start, end); + }); + + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, + owb, nb_ow); + } +} + +template +void _gemm_u8s8s32x_convolution_bwd_data_t:: +execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + auto scratchpad = this->scratchpad(ctx); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base, + bia_base, diff_src_base, scratchpad); + }); +} + +template +void _gemm_u8s8s32x_convolution_bwd_data_t:: +execute_backward_data_thr(const int ithr, const int nthr, + const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base, + const char *bia_base, diff_src_data_t *diff_src_base, + const memory_tracking::grantor_t &scratchpad) const +{ + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md()); + const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1); + const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc; + + const auto wei_md = memory_desc_wrapper(pd()->weights_md(0)); + const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0; + + const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md()); + const size_t diff_src_mb_stride = diff_src_md.blk_off(1); + const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic; + const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1); + + /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */ + const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1); + const float *scales = pd()->attr()->output_scales_.scales_; + const size_t work_amount = jcp.ngroups * jcp.mb; + + auto col = scratchpad.get(key_conv_gemm_col) + + (ptrdiff_t)ithr * jcp.im2col_sz; + auto acc = scratchpad.get(key_conv_int_dat_in_acc_dt) + + (ptrdiff_t)ithr * jcp.is * jcp.ic; + + int n{0}, g{0}; + size_t start = 0, end = 0; + + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups); + + for (size_t iwork = start; iwork < end; ++iwork) { + const diff_dst_data_t *diff_dst = diff_dst_base + + n * diff_dst_mb_stride + g * diff_dst_g_stride; + const wei_data_t *wei = wei_base + g * wei_g_stride; + diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride + + g * diff_src_g_stride; + + const int M = jcp.ks * jcp.ic; + const int N = jcp.os; + const int K = jcp.oc; + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + const float onef = 1.0, zerof = 0.0; + const int LD = K * jcp.ngroups; + + gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef, + wei, &LD, &off_a, diff_dst, &LD, &off_b, + &zerof, jcp.im2col_sz ? col : acc, &M, &off_c); + + if (jcp.im2col_sz) + jit_gemm_convolution_utils::col2im_s32(jcp, col, acc); + + parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) { + float d = (float)acc[is * jcp.ic + ic]; + if (jcp.with_bias) + d += get_bias(bia_base, g * jcp.ic + ic, + pd()->desc()->bias_desc.data_type); + d *= scales[(g * jcp.ic + ic) * scale_idx_mult]; + const size_t diff_src_off = is * diff_src_os_stride + ic; + diff_src[diff_src_off] = + qz_a1b0()(d); + }); + nd_iterator_step(n, jcp.mb, g, jcp.ngroups); + } +} + +using namespace data_type; + +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; + +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; + +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp new file mode 100644 index 0000000000..9e77b890d5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp @@ -0,0 +1,266 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_X8S8S32X_CONVOLUTION_HPP +#define GEMM_X8S8S32X_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" +#include "gemm_convolution_utils.hpp" + +#include "gemm/gemm.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR, + _gemm_x8s8s32x_convolution_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, s8, data_type::undef, dst_type, + s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, f32, s32, s8, u8)) + && !has_zero_dim_memory() + && set_default_formats_common( + dat_tag(), format_tag::any, dat_tag()) + && post_ops_ok() + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*dst_md(), dat_tag()) + && set_or_check_wei_format(); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), weights_md(0), dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + bool set_or_check_wei_format() { + using namespace format_tag; + + const bool is_src_s8 = src_md_.data_type == data_type::s8; + + memory_desc_t want_wei_md = weights_md_; + memory_desc_init_by_tag(want_wei_md, with_groups() ? hwigo : hwio); + + if (is_src_s8) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups() ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md_.format_kind == format_kind::any) { + weights_md_ = want_wei_md; + return true; + } + + return weights_md_ == want_wei_md; + } + + bool post_ops_ok() const { + using namespace mkldnn::impl::primitive_kind; + auto const &po = attr()->post_ops_; + auto is_relu = [&](int idx) { + return po.entry_[idx].is_relu(true, false); }; + + switch (po.len_) { + case 0: return true; + case 1: return is_relu(0) || po.contain(sum, 0); + case 2: return po.contain(sum, 0) && is_relu(1); + default: return false; + } + return false; + } + }; + + _gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true), pp_ker_(nullptr) + { pp_ker_ = new pp_ker_t(pd()); } + ~_gemm_x8s8s32x_convolution_fwd_t() { delete pp_ker_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + // XXX: this is throwaway code that will become unnecessary when we have a + // sufficiently advanced igemm jit generator that supports quantization, + // relu, and whatnot + class pp_ker_t : jit_generator { + public: + DECLARE_CPU_JIT_AUX_FUNCTIONS( + _gemm_x8s8s32x_convolution_fwd_t::pp_kernel); + pp_ker_t(const pd_t *pd); + + void operator()(dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, + float nslope, float sum_scale, float signed_scale, + int g, size_t start, size_t end); + + size_t dst_os_stride_; + + private: + void generate(); + + struct ker_args { + dst_data_t *dst; + const acc_data_t *acc; + const char *bias; + const float *scales; + float nslope; + float sum_scale; + float signed_scale; + size_t len; + size_t oc_offset; + }; + void(*ker_)(const ker_args *args); + + const jit_gemm_conv_conf_t &jcp_; + size_t OC_; + size_t OS_; + data_type_t bias_data_type_; + size_t bias_data_type_size_; + size_t scale_idx_mult_; + bool do_bias_; + bool do_relu_; + bool do_sum_; + bool do_signed_scaling_; + size_t vlen_; + }; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src_base, const wei_data_t *wei_base, + const char *bia_base, dst_data_t *dst_base, + const memory_tracking::grantor_t &scratchpad) const; + + int nthr_; + pp_ker_t *pp_ker_; + +}; + +template +struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t{ + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR, + _gemm_u8s8s32x_convolution_bwd_data_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(dst_type, s8, data_type::undef, u8, s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, f32, s32, s8, u8)) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && attr()->post_ops_.has_default_values() + && memory_desc_matches_tag(*diff_src_md(), dat_tag()) + && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), diff_src_md(), weights_md(), diff_dst_md(), + mkldnn_get_max_threads()); + } + + virtual bool support_bias() const override { return true; } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + format_tag_t wei_tag() const { + return with_groups() ? format_tag::hwigo : format_tag::hwio; + } + }; + + _gemm_u8s8s32x_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_src_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + void execute_backward_data_thr(const int ithr, const int nthr, + const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base, + const char *bia_base, diff_src_data_t *diff_src_base, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp new file mode 100644 index 0000000000..1e435a233a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp @@ -0,0 +1,453 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "simple_q10n.hpp" + +#include "gemm/gemm.hpp" +#include "gemm_x8s8s32x_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace math; +using namespace format_tag; +using namespace memory_tracking::names; + +template +gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::pp_kernel_t( + const pd_t *pd, bool dst_is_acc) + : ker_(nullptr), OC_(pd->OC()) + , bias_data_type_(data_type::undef), bias_data_type_size_(0) + , scale_idx_mult_(0), do_bias_(false), do_relu_(false) +{ + using namespace types; + + scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1)); + + auto &post_ops = pd->attr()->post_ops_; + do_relu_ = post_ops.len_ == 1; + do_bias_ = pd->with_bias(); + bias_data_type_ = pd->desc()->bias_desc.data_type; + if (do_bias_) { + assert(bias_data_type_ != data_type::undef); + bias_data_type_size_ = data_type_size(bias_data_type_); + } + + if (!mayiuse(avx512_core)) + // use fallback code for older CPUs since they do not have optimized + // x8s8s32 GEMM anyways. The configuration variables above are used by + // the fallback code. + return; + else + generate(); +} + +template +void gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::generate() +{ + using namespace Xbyak; + using namespace utils; + + // TODO: clean-up + Reg64 reg_param = abi_param1; + Reg64 reg_dst = rdx; + Reg64 reg_acc = rax; + Reg64 reg_bias = rbx; + Reg64 reg_scales = rsi; + + Reg64 reg_len = r8; + Reg64 reg_tmp = rcx; // intentional for shifting purposes + Reg64 reg_oc_offset = r9; + Reg64 reg_rem_mask = r10; + Opmask kreg_rem_mask = k1; + Opmask kreg_relu_cmp = k2; + + const size_t vlen = cpu_isa_traits::vlen / sizeof(float); + + Zmm vreg_zero = Zmm(0); + Zmm vreg_scale = Zmm(1); + Zmm vreg_nslope = Zmm(2); + + auto vreg_dst = [&](int idx) { return Zmm(3 + idx * 2 + 0); }; + auto vreg_bias = [&](int idx) { return Zmm(3 + idx * 2 + 1); }; + + preamble(); + +#define PARAM_OFF(x) offsetof(ker_args, x) + mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); + mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]); + mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); + mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]); + mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); + mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); + vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]); + if (scale_idx_mult_ == 0) + vbroadcastss(vreg_scale, dword[reg_scales]); +#undef PARAM_OFF + + if (do_relu_ || dst_type == data_type::u8) + vxorps(vreg_zero, vreg_zero, vreg_zero); + + // Load accumulated value, convert to float, apply bias (if any), scaling, + // and relu (if any); then convert to destination type and store + auto compute = [&](size_t offset, int idx, bool apply_mask) { + auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)]; + + if (scale_idx_mult_ > 0) { + assert(scale_idx_mult_ == 1); + auto scale_addr = ptr[reg_scales + offset * sizeof(float)]; + auto vreg_scale_ = vreg_scale; + if (apply_mask) + vreg_scale_ = vreg_scale_ | kreg_rem_mask; + vmovups(vreg_scale, scale_addr); + } + + auto vreg_dst_ = vreg_dst(idx); + if (apply_mask) + vreg_dst_ = vreg_dst_ | kreg_rem_mask; + vcvtdq2ps(vreg_dst_, acc_addr); + + if (do_bias_) { + auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_]; + auto vreg_bias_ = vreg_bias(idx); + if (apply_mask) + vreg_bias_ = vreg_bias_ | kreg_rem_mask; + + switch (bias_data_type_) { + case data_type::s8: + vpmovsxbd(vreg_bias_, bias_addr); + break; + case data_type::u8: + vpmovzxbd(vreg_bias_, bias_addr); + break; + case data_type::s32: + case data_type::f32: + vmovups(vreg_bias_, bias_addr); + break; + default: assert(!"unimplemented"); + } + if (bias_data_type_ != data_type::f32) + vcvtdq2ps(vreg_bias(idx), vreg_bias(idx)); + vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx)); + } + + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale); + if (do_relu_) { + vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os); + vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope); + } + + if (dst_type == data_type::u8) + vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_zero); + + if (dst_type != data_type::f32) { + vcvtps2dq(vreg_dst(idx), vreg_dst(idx)); + } + + auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)]; + switch (dst_type) { + case data_type::s8: + vpmovsdb(dst_addr, vreg_dst_); + break; + case data_type::u8: + vpmovusdb(dst_addr, vreg_dst_); + break; + case data_type::f32: + case data_type::s32: + vmovups(dst_addr, vreg_dst_); + break; + default: assert(!"unimplemented"); + } + }; + + // Advance all pointers by an immediate + auto advance_ptrs_imm = [&](size_t offset) { + add(reg_dst, offset * sizeof(dst_data_t)); + add(reg_acc, offset * sizeof(acc_data_t)); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + add(reg_scales, offset * sizeof(float)); + } + if (do_bias_) + add(reg_bias, offset * bias_data_type_size_); + }; + + // Advance all pointers by a value stored in a register + auto advance_ptrs_reg = [&](Reg64 offset) { + lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]); + lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); + } + if (do_bias_) + lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]); + }; + + // Rewind pointers that point to data that is indixed by output channel + // (bias or per-oc scaling factors) + auto rewind_ptrs = [&]() { + if (do_bias_) + sub(reg_bias, OC_ * bias_data_type_size_); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + sub(reg_scales, OC_ * sizeof(float)); + } + }; + + // <-------------------- OC -------------------------------> + // + // ^ +....................+----------------------------------+ + // | : not accessed | Prologue loop | + // | +--------------------+----------------------------------+ + // | | + // M | Main loop (unrolled) | + // B | | + // +--------------------------------+----------------------+ + // | | Epilogue loop | not accessed : + // v +--------------------------------+......................+ + + Label prologue_end; + cmp(reg_oc_offset, 0); + je(prologue_end, T_NEAR); + + // Prologue loop + { + mov(reg_tmp, OC_); + sub(reg_tmp, reg_oc_offset); + cmp(reg_tmp, reg_len); + cmovg(reg_tmp, reg_len); + sub(reg_len, reg_tmp); + + Label prologue_loop, prologue_loop_tail, prologue_loop_end; + cmp(reg_tmp, vlen); + jle(prologue_loop_tail, T_NEAR); // Skips for reg_tmp == 16 too (?) + L(prologue_loop); { + compute(0, 0, false); + advance_ptrs_imm(vlen); + sub(reg_tmp, vlen); + cmp(reg_tmp, vlen); + jge(prologue_loop, T_NEAR); + } + + L(prologue_loop_tail); + mov(reg_rem_mask, 1); + shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here + sub(reg_rem_mask, 1); + jz(prologue_loop_end, T_NEAR); + + kmovq(kreg_rem_mask, reg_rem_mask); + compute(0, 0, true); + advance_ptrs_reg(reg_tmp); + + L(prologue_loop_end); + rewind_ptrs(); + } + L(prologue_end); + + // Main loop + Label main_loop_end; + { + cmp(reg_len, OC_); + jle(main_loop_end, T_NEAR); + + Label main_loop; + L(main_loop); { + size_t def_unroll = 4; + size_t max_unroll = 13; + + size_t OC_loop, OC_tail; + if (OC_ < max_unroll * vlen) { + // Fully unroll small loops + OC_loop = 0; + OC_tail = OC_; + } else { + OC_loop = vlen * def_unroll; + OC_tail = OC_ % OC_loop; + } + + assert(!!OC_loop || !!OC_tail); + + if (OC_tail % vlen) { + int vlen_tail = OC_tail % vlen; + unsigned tail_mask = (1 << vlen_tail) - 1; + mov(reg_tmp, tail_mask); + kmovq(kreg_rem_mask, reg_tmp); + } + + if (OC_loop) { + mov(reg_tmp, rnd_dn(OC_, OC_loop)); + Label oc_loop; + L(oc_loop); { + for (size_t offset = 0; offset < OC_loop; offset += vlen) + compute(offset, offset / vlen, false); + advance_ptrs_imm(OC_loop); + sub(reg_tmp, OC_loop); + jnz(oc_loop); + } + } + + if (OC_tail) { + for (size_t offset = 0; offset < OC_tail; offset += vlen) { + bool use_mask = (offset + vlen) > OC_tail; + compute(offset, offset / vlen, use_mask); + } + advance_ptrs_imm(OC_tail); + } + + rewind_ptrs(); + sub(reg_len, OC_); + cmp(reg_len, OC_); + jge(main_loop, T_NEAR); + } + } + L(main_loop_end); + + // Epilogue loop + Label epilogue_end; + { + cmp(reg_len, 0); + je(epilogue_end, T_NEAR); + + Label epilogue_loop, epilogue_loop_tail; + cmp(reg_len, vlen); + jle(epilogue_loop_tail, T_NEAR); // Skips for reg_len == 16 (?) + L(epilogue_loop); { + compute(0, 0, false); + sub(reg_len, vlen); + advance_ptrs_imm(vlen); + cmp(reg_len, vlen); + jge(epilogue_loop, T_NEAR); + } + + L(epilogue_loop_tail); + mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift + mov(reg_rem_mask, 1); + shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16 + sub(reg_rem_mask, 1); + jz(epilogue_end, T_NEAR); + kmovq(kreg_rem_mask, reg_rem_mask); + compute(0, 0, true); + } + + L(epilogue_end); + + postamble(); + + ker_ = getCode(); +} + +template +void gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::operator ()( + dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, float nslope, + size_t start, size_t end) +{ + using math::get_bias; + + if (end <= start) + return; + + if (ker_) { + // JIT + ker_args args; + size_t oc_offset = start % OC_; + args.dst = dst + start; + args.acc = acc + start; + args.bias = bias + oc_offset * bias_data_type_size_; + args.scales = scales + scale_idx_mult_ * oc_offset; + args.nslope = nslope; + args.len = end - start; + args.oc_offset = oc_offset; + ker_(&args); + } else { + // Fallback + size_t oc = start % OC_; + for (size_t i = start; i < end; i++) { + float d = (float)acc[i]; + float b = get_bias(bias, oc, bias_data_type_); + d = d + b; + d *= scales[oc * scale_idx_mult_]; + if (do_relu_ && d < 0) + d *= nslope; + dst[i] = qz_a1b0()(d); + oc = (oc == OC_ - 1) ? 0 : oc + 1; + } + } +}; + +template +void gemm_x8s8s32x_inner_product_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->weights_md(), oiw, oihw, oidhw, oi); + + const int M = OC; + const int N = MB; + const int K = pd()->IC_total_padded(); + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + + const float *scales = pd()->attr()->output_scales_.scales_; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f; + + acc_data_t *acc = pd()->dst_is_acc_ + ? (acc_data_t *)dst + : scratchpad(ctx).template get(key_iprod_int_dat_in_acc_dt); + + const float onef = 1.0, zerof = 0.0; + gemm_s8x8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef, weights, + wei_tr ? &K : &M, &off_a, src, &K, &off_b, &zerof, acc, &M, &off_c); + + if (!pd()->attr()->has_default_values() || !pd()->dst_is_acc_ + || pd()->with_bias()) { + const bool force_sequential = MB * OC < 2000; + parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) { + size_t start, end; + balance211((size_t)OC * MB, nthr, ithr, start, end); + (*pp_kernel_)(dst, acc, bias, scales, nslope, start, end); + }); + } +} + +using namespace data_type; + +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp new file mode 100644 index 0000000000..ac6a5c8f85 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp @@ -0,0 +1,166 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_X8S8S32X_INNER_PRODUCT_HPP +#define GEMM_X8S8S32X_INNER_PRODUCT_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "gemm/gemm.hpp" +#include "jit_generator.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct gemm_x8s8s32x_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T(src_type == data_type::u8 + ? IGEMM_S8U8S32_IMPL_STR + : IGEMM_S8S8S32_IMPL_STR, + gemm_x8s8s32x_inner_product_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == src_type + && dst_md()->data_type == dst_type + && weights_md()->data_type == s8 + && IMPLICATION(with_bias(), utils::one_of( + weights_md(1)->data_type, f32, s32, s8, u8)) + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_, + attr()->post_ops_.entry_[0].is_relu(true, false)) + && dense_gemm_consitency_check(src_md(), weights_md(), + dst_md()); + if (!ok) return status::unimplemented; + + dst_is_acc_ = utils::one_of(dst_type, s32, f32); + + init_scratchpad(); + + return status::success; + } + + bool dst_is_acc_; + + protected: + status_t set_default_params() { + using namespace format_tag; + if (src_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md_, + utils::pick(ndims() - 2, nc, nwc, nhwc, ndhwc))); + } + if (dst_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_md_, nc)); + if (weights_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md_, + utils::pick(ndims() - 2, io, wio, hwio, dhwio))); + } + return inner_product_fwd_pd_t::set_default_params(); + } + + private: + void init_scratchpad() { + if (!dst_is_acc_) { + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book( + memory_tracking::names::key_iprod_int_dat_in_acc_dt, + sizeof(acc_data_t) * MB() * OC()); + } + } + }; + + gemm_x8s8s32x_inner_product_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + { pp_kernel_ = new pp_kernel_t(apd, pd()->dst_is_acc_); } + ~gemm_x8s8s32x_inner_product_fwd_t() { delete pp_kernel_; } + + typedef typename prec_traits::type data_t; + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + // XXX: this is throwaway code that will become unnecessary when we have a + // sufficiently advanced igemm jit generator that supports quantization, + // relu, and whatnot + class pp_kernel_t: jit_generator { + public: + DECLARE_CPU_JIT_AUX_FUNCTIONS( + gemm_x8s8s32x_inner_product_fwd_t::pp_kernel); + pp_kernel_t(const pd_t *pd, bool dst_is_acc); + + void operator()(dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, float nslope, + size_t start, size_t end); + private: + void generate(); + + struct ker_args { + dst_data_t *dst; + const acc_data_t *acc; + const char *bias; + const float *scales; + float nslope; + size_t len; + size_t oc_offset; + }; + void (*ker_)(const ker_args *args); + + size_t OC_; + data_type_t bias_data_type_; + size_t bias_data_type_size_; + size_t scale_idx_mult_; + bool do_bias_; + bool do_relu_; + }; + + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + pp_kernel_t *pp_kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp new file mode 100644 index 0000000000..6fa251d465 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp @@ -0,0 +1,674 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_avx2_1x1_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, reg_bcast_loop_work); + + Label bcast_loop, bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + generate_reduce_loop(load_loop_blk, jcp.ur); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + generate_reduce_loop(load_loop_blk, jcp.ur_tail); + L(bcast_loop_tail_out); + } +} + +void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop( + int load_loop_blk, int ur) +{ + auto vreg_load = [=](int i) { + return Ymm(ur * load_loop_blk + i); + }; + + auto vreg_accum = [=](int i, int j) { + return Ymm(j * load_loop_blk + i); + }; + + auto bias_ptr = [=](int i) { + return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i]; + }; + + auto bcast_ptr = [=](int u, int j) { + assert(j < jcp.ur); + assert(u <= jcp.reduce_loop_unroll); + size_t offt; + if (one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)) + { + assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data) + ? jcp.oc_block : jcp.ic_block); + auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is; + offt = (u == jcp.reduce_loop_unroll) + ? (height + j) * jcp.reduce_loop_unroll + : j * jcp.reduce_loop_unroll + u; + } else + offt = u * jcp.ic_block + j; + return ptr[aux_reg_bcast_data + sizeof(float) * offt]; + }; + + auto load_ptr = [=](int u, int i) { + size_t offt; + size_t u0 = u % jcp.reduce_loop_unroll; + size_t u1 = u / jcp.reduce_loop_unroll; + switch (jcp.prop_kind) { + case backward_data: + offt = (i * jcp.oc_block + u0) * jcp.ic_block; + break; + case backward_weights: + offt = (i * jcp.os + u0) * jcp.oc_block; + break; + default: + offt = (i * jcp.ic + u0) * jcp.oc_block; + } + return ptr[aux_reg_load_data + + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt]; + }; + + auto output_ptr = [=](int i, int j) { + switch (jcp.prop_kind) { + case backward_data: + return ptr[aux_reg_output_data + + (i * jcp.is + j) * jcp.ic_block * sizeof(float)]; + case backward_weights: + return ptr[aux_reg_output_data + + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale + + sizeof(float) * jcp.oc_block * j]; + default: + return ptr[aux_reg_output_data + + (i * jcp.os + j) * jcp.oc_block * sizeof(float)]; + } + }; + + auto init = [=]() { + Label init_done, init_zero; + + if (jcp.with_bias && one_of(jcp.prop_kind, forward_training, + forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero); + + for (int i = 0; i < load_loop_blk; i++) + for (int j = 0; j < ur; ++j) + vmovups(vreg_accum(i, j), bias_ptr(i)); + jmp(init_done); + } + + L(init_zero); + for (int i = 0; i < load_loop_blk; ++i) + for (int j = 0; j < ur; ++j) { + auto r = vreg_accum(i, j); + vxorps(r, r, r); + } + + L(init_done); + for (int i = 0; i < load_loop_blk; ++i) + vmovups(vreg_load(i), load_ptr(0, i)); + vbroadcastss(vreg_bcast, bcast_ptr(0, 0)); + }; + + auto store = [=]() { + Label store_noadd; + + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + auto r = vreg_accum(i, j); + vaddps(r, r, output_ptr(i, j)); + } + + L(store_noadd); + + if (jcp.with_eltwise) { + assert(ur * load_loop_blk < 14); + + Label store_norelu; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_norelu, T_NEAR); + + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + L(store_norelu); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + vmovups(output_ptr(i, j), vreg_accum(i, j)); + } + }; + + auto fma_block = [=](bool last_block) { + for (int u = 0; u < jcp.reduce_loop_unroll; ++u) { + for (int j = 0; j < ur; ++j) { + for (int i = 0; i < load_loop_blk; ++i) { + if (mayiuse(avx2)) + vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast); + else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support + vmulps(vtmp, vreg_bcast, vreg_load(i)); + vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp); + } + if (j == ur - 1 && !(last_block + && u == jcp.reduce_loop_unroll - 1)) + vmovups(vreg_load(i), load_ptr(u + 1, i)); + } + if (j < ur - 1) + vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1)); + } + if (!last_block || u < jcp.reduce_loop_unroll - 1) + vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0)); + } + }; + + Label reduce_loop, reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} + +void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) +{ + if (!jcp.with_bias || jcp.prop_kind != backward_weights) + return; + + Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out; + Label diff_bias_load; + + auto diff_bias_ptr = [=](int i) { + return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)]; + }; + + auto load_ptr = [=](int u, int i) { + return ptr[aux_reg_load_data + + (i * jcp.os + u) * jcp.oc_block * sizeof(float)]; + }; + + auto diff_bias_reg = [=](int i) { return Ymm(i); }; + + mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]); + cmp(reg_diff_bias_data, 0); + je(diff_bias_loop_out, T_NEAR); + + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(diff_bias_load, T_NEAR); + + for (int i = 0; i < load_loop_blk; ++i) { + auto r = diff_bias_reg(i); + vxorps(r, r, r); + } + jmp(diff_bias_init_out, T_NEAR); + + L(diff_bias_load); + for (int i = 0; i < load_loop_blk; ++i) + vmovups(diff_bias_reg(i), diff_bias_ptr(i)); + + L(diff_bias_init_out); + mov(aux_reg_load_data, reg_load_data); + mov(reduce_loop_iter, reg_reduce_loop_work); + L(diff_bias_loop); { + for(int u = 0; u < jcp.reduce_loop_unroll; ++u) + for (int i = 0; i < load_loop_blk; ++i) + vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i)); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jnz(diff_bias_loop, T_NEAR); + } + + for (int i = 0; i < load_loop_blk; i++) + vmovups(diff_bias_ptr(i), diff_bias_reg(i)); + add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + + L(diff_bias_loop_out); +} + +void jit_avx2_1x1_conv_kernel_f32::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + if (jcp.with_bias) { + if (jcp.prop_kind == backward_weights) { + sub(rsp, stack_space_needed); + mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + } else + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + } + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto generate_load_loop_body = [=] (int load_loop_blk) { + generate_bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + add(reg_output_data, + load_loop_blk * jcp.os * jcp.oc_block * sizeof(float)); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.is * jcp.ic_block * sizeof(float)); + break; + case backward_weights: + for (int i = 0; i < load_loop_blk; i++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + Label load_loop_blk_8; + Label load_loop_blk_16; + Label load_loop_blk_24; + Label load_loop_blk_end; + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16, T_NEAR); + + cmp(reg_load_loop_work, 16); + jle(load_loop_blk_16, T_NEAR); + + L(load_loop_blk_24); { + generate_diff_bias_loop(3); + generate_load_loop_body(3); + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16); + cmp(reg_load_loop_work, 24); + jge(load_loop_blk_24); + } + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + L(load_loop_blk_16); { + generate_diff_bias_loop(2); + generate_load_loop_body(2); + cmp(reg_load_loop_work, 16); + jge(load_loop_blk_16); + } + + L(load_loop_blk_8); { + cmp(reg_load_loop_work, 0); + je(load_loop_blk_end, T_NEAR); + generate_diff_bias_loop(1); + generate_load_loop_body(1); + } + + L(load_loop_blk_end); + + if (jcp.with_bias && jcp.prop_kind == backward_weights) + add(rsp, 8); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx2_1x1_conv_kernel_f32::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(avx)) return status::unimplemented; + + // TODO (Roma): this code is duplicated from the generic kernel; maybe the + // configuration struct could do some stuff below + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu) + return status::unimplemented; + } + + const int is_bwd_d = jcp.prop_kind == backward_data; + + format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c; + format_tag_t wei_tag = with_groups + ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o, + gOIhw8o8i) + : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, + OIhw8o8i); + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + + const int simd_w = 8; + + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + args_ok = true + && jcp.ih == jcp.oh && jcp.iw == jcp.ow + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + // TODO: remove this restriction + // optimized 1x1 bwd_w does not support Intel AVX + if (jcp.prop_kind == backward_weights && !mayiuse(avx2)) + return status::unimplemented; + + jcp.ic_block = jcp.oc_block = simd_w; + + jcp.ur = mayiuse(avx2) ? 4 : 3; // Intel AVX support + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.is * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + load_blocking = 120; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 192; + reduce_blocking = 128; // affects L1$ utilization + } else if (jcp.prop_kind == backward_data) { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.os; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.os * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.ic * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.load_loop_iter_step = jcp.ic_block; + + load_blocking = 96; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 196; + reduce_blocking = 64; // affects L1$ utilization + } else if (jcp.prop_kind == backward_weights) { + jcp.reduce_dim = jcp.os; + jcp.reduce_block = 1; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); + jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float); + jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + while (true) { + if (load_blocking <= 32) break; + else if (load_blocking % 2 == 0) load_blocking /= 2; + else if (load_blocking % 3 == 0) load_blocking /= 3; + else break; + } + load_blocking *= jcp.load_block; + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + while (true) { + if (bcast_blocking <= 9) break; + else if (bcast_blocking % 2 == 0) bcast_blocking /= 2; + else if (bcast_blocking % 3 == 0) bcast_blocking /= 3; + else break; + } + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + reduce_blocking = 128; // affects L1$ utilization + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + + assert(jcp.bcast_block % jcp.ur == 0); + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +void jit_avx2_1x1_conv_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp new file mode 100644 index 0000000000..bfdeb2b18d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp @@ -0,0 +1,110 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX2_1x1_CONV_KERNEL_F32_HPP +#define JIT_AVX2_1x1_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_1x1_conv_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_1x1_conv_kernel_f32) + + jit_avx2_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode(); + } + + ~jit_avx2_1x1_conv_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using ymm_t = const Xbyak::Ymm; + + reg64_t reg_bcast_data = rax; + reg64_t reg_load_data = rsi; + reg64_t reg_output_data = rbx; + reg64_t aux_reg_bcast_data = rdx; + reg64_t aux1_reg_bcast_data = abi_not_param1; + reg64_t aux_reg_load_data = abi_param1; + reg64_t aux_reg_output_data = rbp; + reg64_t reg_load_loop_work = r9; + reg64_t reg_bcast_loop_work = r10; + reg64_t reg_reduce_loop_work = r11; + reg64_t load_loop_iter = r13; + reg64_t bcast_loop_iter = r14; + reg64_t reduce_loop_iter = r15; + reg64_t imm_addr64 = reduce_loop_iter; + reg64_t reg_reduce_pos_flag = r8; + reg64_t reg_output_stride = r12; + reg64_t reg_bias_data = r12; + reg64_t reg_diff_bias_data = bcast_loop_iter; + + int reg_diff_bias_data_stack_offt = 0; + int stack_space_needed = 8; + + ymm_t vreg_bcast = ymm_t(15); + ymm_t vtmp = ymm_t(14); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + void generate_bcast_loop(int load_loop_blk); + void generate_reduce_loop(int load_loop_blk, int ur); + void generate_diff_bias_loop(int load_loop_blk); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp new file mode 100644 index 0000000000..f116ac9056 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp @@ -0,0 +1,545 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx2_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + +/* convolution forward */ + +void jit_avx2_1x1_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).get(key_conv_rtus_space); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + const int ndims = dst_d.ndims(); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto ker = [&](const int ithr, const int nthr) { + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int iwork = start; + while (iwork < end) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + int bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, end - iwork); + + const int os = osb * os_block; + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); + rp.os = p.bcast_dim; + + int ocb = 0; + while (ocb < jcp.nb_load) { + const int load_step = step(jcp.nb_load_blocking, + jcp.nb_load - ocb, jcp.nb_load_blocking_max); + + const int _ocb = g * nb_oc + ocb; + p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + load_step * jcp.oc_block); + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + + p.output_data = &dst[dst_off]; + + p.bias_data = &bias[_ocb * jcp.oc_block]; + + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + p.first_last_flag = 0 + | (icb == 0 ? FLAG_REDUCE_FIRST : 0) + | (icb + nb_ic_blocking >= nb_ic + ? FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic, + nb_ic_blocking * jcp.ic_block); + rp.icb = p.reduce_dim / jcp.reduce_block; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + const int _icb = g * nb_ic + icb; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + + if (ocb == 0) { + rp.src = src + data_blk_off(src_d, n, _icb, ih, iw); + rtus_driver_->ker_(&rp); + } + + p.bcast_data = rp.ws; + } else + p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw); + + kernel_->jit_ker(&p); + } + + ocb += load_step; + } + + iwork += bcast_step; + } + }; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad(ctx).get(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, ker); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +/* convolution backward wtr data */ + +void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).get(key_conv_rtus_space); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + const int ndims = diff_dst_d.ndims(); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int nb_ic = jcp.nb_load; + const int nb_oc = jcp.nb_reduce; + const int os_block = jcp.bcast_block; + const int nb_oc_blocking = jcp.nb_reduce_blocking; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto ker = [&](const int ithr, const int nthr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int load_step = 0; + for (int icb = 0; icb < jcp.nb_load; icb += load_step) { + load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb, + jcp.nb_load_blocking_max); + + p.load_dim = this_block_size(icb * jcp.ic_block, jcp.ic, + load_step * jcp.ic_block); + rp.icb = p.load_dim / jcp.ic_block; + + int bcast_step; + for (int iwork = start; iwork < end; iwork += bcast_step) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, end - iwork); + + const int os = osb * os_block; + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + const int _icb = g * nb_ic + icb; + rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw); + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_; + p.output_data = rp.ws; + } else + p.output_data = rp.src; + + for (int ocb = 0; ocb < jcp.nb_reduce; + ocb += jcp.nb_reduce_blocking) { + const int _ocb = g * nb_oc + ocb; + size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, + ow); + p.bcast_data = &diff_dst[diff_dst_off]; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; + + p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + nb_oc_blocking * jcp.oc_block); + + kernel_->jit_ker(&p); + } + + if (pd()->rtus_.reduce_src_) + rtus_driver_->ker_(&rp); + } + } + }; + + parallel(0, ker); +} + +/* convolution backward wtr weights */ + +jit_avx2_1x1_convolution_bwd_weights_t::jit_avx2_1x1_convolution_bwd_weights_t( + const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , rtus_driver_(nullptr) +{ + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + reducer_weights_ = + new cpu_reducer_2d_t(pd()->reducer_wei_conf_); + reducer_bias_ = new cpu_reducer_t(pd()->reducer_bia_conf_); + init_rtus_driver(this); +} + +void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto scratchpad = this->scratchpad(ctx); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get(key_conv_rtus_space); + + data_t *diff_bias = pd()->wants_padded_bias() + ? scratchpad.get(key_conv_padded_bias) : diff_bias_in; + + auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_wei); + auto rw = this->reducer_weights_; + rw->init(reducer_wei_scratchpad); + + const int ndims = diff_dst_d.ndims(); + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int nb_ic = jcp.nb_bcast; + const int nb_ic_blocking = jcp.nb_bcast_blocking; + const int bcast_work = div_up(nb_ic, nb_ic_blocking); + + const int nb_oc = jcp.nb_load; + const int nb_oc_blocking = jcp.nb_load_blocking; + const int load_work = div_up(nb_oc, nb_oc_blocking); + + const int sp_dim = jcp.reduce_dim; + const int mb_sp_work = jcp.mb * sp_dim; + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image, + data_t *store_to, size_t store_to_ld, const data_t *diff_dst, + const data_t *src, int ithr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + p.output_stride = store_to_ld * sizeof(float); + const int sp_step_def = jcp.nb_reduce_blocking * jcp.reduce_block; + + int oc_b_step = 0; + for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) { + oc_b_step = step(12, nb_oc_blocking - oc_b, 18); + p.load_dim = oc_b_step * jcp.oc_block; + + int ic_b_step = 0; + for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) { + ic_b_step = step(12, nb_ic_blocking - ic_b, 18); + p.bcast_dim = ic_b_step * jcp.ic_block; + rp.icb = p.bcast_dim / jcp.ic_block; + + p.output_data = store_to + oc_b * store_to_ld + + ic_b * jcp.ic_block * jcp.oc_block; + + /* spatial reduction */ + int sp_step = 0; + for (int sp = sp_start; sp < sp_end; sp += sp_step) { + sp_step = step(sp_step_def, sp_end - sp, 192); + p.reduce_dim = sp_step; + rp.os = p.reduce_dim; + + p.first_last_flag = sp == sp_start && first_image + ? FLAG_REDUCE_FIRST : 0; + + p.load_data = diff_dst + + (oc_b * jcp.reduce_dim + sp) * jcp.oc_block; + + if (pd()->rtus_.reduce_src_) { + const int oh = sp / jcp.ow; + const int ow = sp % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + (ic_b * jcp.is + sp) * jcp.ic_block; + if (ndims == 3) + rp.src = src + + iw * src_d.blocking_desc().strides[2]; + else + rp.src = src + + ih * src_d.blocking_desc().strides[2] + + iw * src_d.blocking_desc().strides[3]; + + if (oc_b == 0) + rtus_driver_->ker_(&rp); + + p.bcast_data = rp.ws; + } else + p.bcast_data = src + + (ic_b * jcp.reduce_dim + sp) * jcp.ic_block; + + kernel_->jit_ker(&p); + } + } + } + }; + + auto ker = [&](const int ithr, const int nthr) { + assert(nthr == rw->balancer().nthr_); + + const int w_njobs = rw->balancer().ithr_njobs(ithr); + if (w_njobs == 0) return; + + /* setup: independent work (oc, ic) */ + const int w_job_start = rw->balancer().ithr_job_off(ithr); + int g{0}, load_i{0}, bcast_i{0}; + nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work, + bcast_i, bcast_work); + + /* setup: reduction work (mb, sp) */ + int mb_sp_start{0}, mb_sp_end{0}; + balance211(mb_sp_work, rw->balancer().nthr_per_group_, + rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end); + int img_start{0}, sp_start{0}; + nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim); + + /* independent work */ + for (int iwork = 0; iwork < w_njobs; ++iwork) { + const int oc_b = nb_oc_blocking * load_i; + const int ic_b = nb_ic_blocking * bcast_i; + + const int _ic_b = g * nb_ic + ic_b; + const int _oc_b = g * nb_oc + oc_b; + + data_t *store_to; + size_t store_to_ld; + + if (rw->balancer().nthr_per_group_ == 1) { + const size_t off = pd()->with_groups() + ? diff_weights_d.blk_off(g, oc_b, ic_b) + : diff_weights_d.blk_off(oc_b, ic_b); + store_to = &diff_weights[off]; + store_to_ld = jcp.ic * jcp.oc_block; + } else { + const size_t off = iwork * rw->balancer().job_size_; + store_to = + rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off; + store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block; + } + + /* reduction work */ + int img = img_start; + int sp = sp_start; + int sp_step = 0; + for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step) + { + sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp); + + const bool first_image = img == img_start; + oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to, + store_to_ld, &diff_dst[diff_dst_d.blk_off(img, _oc_b)], + &src[src_d.blk_off(img, _ic_b)], ithr); + + sp = 0; + img += 1; + } + + nd_iterator_step(g, jcp.ngroups, load_i, load_work, bcast_i, + bcast_work); + } + rw->reduce(ithr, diff_weights, reducer_wei_scratchpad); + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * nb_oc + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = + rb->get_local_ptr(ithr, diff_bias, reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 8; ++o) d_bias[o] = 0.; + + for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 8; ++o) + d_bias[o] += d_dst[o]; + d_dst += 8; + } + + nd_iterator_step(g, jcp.ngroups, ocb, nb_oc); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(0, [&](const int ithr, const int nthr) { + ker(ithr, nthr); + if (pd()->with_bias()) + ker_bias(ithr, nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + for (int oc = 0; oc < jcp.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp new file mode 100644 index 0000000000..9762242173 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp @@ -0,0 +1,344 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_1x1_CONVOLUTION_HPP +#define CPU_JIT_AVX2_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx2_1x1_conv_kernel_f32.hpp" +#include "jit_uni_1x1_conv_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_1x1_convolution_fwd_t: public cpu_primitive_t { + // TODO: (Roma) Code duplication duplication! Remove with templates + // (maybe...)! + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *src_d, *weights_md(), *dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx2_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + rtus_driver_t *rtus_driver_; +}; + +struct jit_avx2_1x1_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *diff_src_d = diff_src_md(); + rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), + *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i) + : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , rtus_driver_(nullptr) + { + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx2_1x1_convolution_bwd_data_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + rtus_driver_t *rtus_driver_; +}; + +struct jit_avx2_1x1_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, diff_dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), + *attr()); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_wei); + reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + cpu_reducer_t::conf_t reducer_bia_conf_; + cpu_reducer_2d_t::conf_t reducer_wei_conf_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + + private: + void init_balancers() { + const int ic_block = jcp_.bcast_block; + const int nb_ic = jcp_.nb_bcast; + const int nb_ic_blocking = jcp_.nb_bcast_blocking; + const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking); + + const int oc_block = jcp_.load_block; + const int nb_oc = jcp_.nb_load; + const int nb_oc_blocking = jcp_.nb_load_blocking; + const int load_work = utils::div_up(nb_oc, nb_oc_blocking); + + const int job_size + = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block; + const int njobs_x = bcast_work; + const int njobs_y = jcp_.ngroups * load_work; + + const int max_threads = mkldnn_get_max_threads(); + const size_t max_buffer_size = max_threads * job_size * 8; + + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(max_threads, + oc_block, jcp_.ngroups * jcp_.oc / oc_block, + jcp_.mb, max_buffer_size)); + } + + reducer_wei_conf_.init( + reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x, + jcp_.mb * jcp_.nb_reduce, max_buffer_size), + job_size / nb_oc_blocking, nb_oc_blocking, ic_block, + nb_ic * ic_block * oc_block, nb_oc); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd); + + ~jit_avx2_1x1_convolution_bwd_weights_t() { + delete kernel_; + delete rtus_driver_; + delete reducer_weights_; + delete reducer_bias_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + cpu_reducer_2d_t *reducer_weights_; + cpu_reducer_t *reducer_bias_; + rtus_driver_t *rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp new file mode 100644 index 0000000000..e24770a2da --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp @@ -0,0 +1,1501 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_avx2_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); + int jj_end = ur_w + - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w)); + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t inp_off; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) + inp_off = sizeof(float)*((size_t)ifm2*id*ih*iw + + (ki*dilate_w + jj*stride_w - pad_l)); + else + inp_off = sizeof(float)*((ki*dilate_w + jj*stride_w + - pad_l)*ic_blk + ifm2); + vbroadcastss(Ymm(oc_blocks * ur_w + jj), + make_safe_addr(aux_reg_input, inp_off, reg_long_offt)); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + + ki * ic_blk * oc_blk + ifm2 * oc_blk; + vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * ker_off]); + for (int jj = jj_start; jj < jj_end; jj++) + if (mayiuse(avx2)) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(oc_blocks * ur_w + jj), ymm15); + else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support + vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp); + } + } + } + } +} + +void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(int ur_w, + int pad_l, int pad_r, char pad_tag, + int oc_blocks, char oc_blocks_tag) +{ + Label kw_loop; + + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + xor_(ki_iter, ki_iter); + L(kw_loop); + { + int jj_start = 0; + int jj_end = ur_w; + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t inp_off; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) + inp_off = sizeof(float)*((size_t)ifm2 * id * ih * iw + + (jj * stride_w - pad_l)); + else + inp_off = sizeof(float)*((jj * stride_w - pad_l) * ic_blk + + ifm2); + vbroadcastss(Ymm(oc_blocks * ur_w + jj), + make_safe_addr(aux_reg_input, inp_off, reg_long_offt)); + } + for (int ii = 0; ii < oc_blocks; ii++) { + int aux_kernel_offset = + ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ifm2 * oc_blk; + vmovups(ymm15, ptr[aux_reg_kernel + + sizeof(float) * aux_kernel_offset]); + for (int jj = jj_start; jj < jj_end; jj++) + if (mayiuse(avx2)) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(oc_blocks * ur_w + jj), ymm15); + else { // Intel AVX support + vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp); + } + } + } + add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? dilate_w : ic_blk * dilate_w)); + + inc(ki_iter); + cmp(ki_iter, kw); + jl(kw_loop, T_NEAR); + } +} + +void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w, + int pad_l, int pad_r, char pad_tag, + int oc_blocks, char oc_blocks_tag) +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ow = jcp.ow; + int oh = jcp.oh; + int od = jcp.od; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : ic_blk; + const int inp_off = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? dilate_w : ic_blk * dilate_w; + + Label init_done, init_first; + + if (!jcp.with_sum) { + test(reg_ci_flag, FLAG_IC_FIRST); + jne(init_first, T_NEAR); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk; + vmovups(Ymm(ur_w * ii + jj), + make_safe_addr(reg_output, offt, reg_long_offt)); + } + } + + if (jcp.with_sum && jcp.with_bias) { + test(reg_ci_flag, FLAG_IC_FIRST); + je(init_done, T_NEAR); + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + yword[reg_bias + sizeof(float) * ii * oc_blk]); + } + + jmp(init_done); + + L(init_first); + if (this->jcp.with_bias) { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + vmovups(Ymm(ur_w * ii + jj), + yword[reg_bias + sizeof(float) * ii * oc_blk]); + } else { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj)); + } + + L(init_done); + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + } + + Label skip_kh_loop, skip_kd_loop, kd_loop; + if (jcp.ndims == 5) { + push(reg_output); + push(oi_iter); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_input); + + if ((jcp.dilate_d >= jcp.id) + || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) { + cmp(reg_ki, 0); + je(skip_kd_loop, T_NEAR); + } + L(kd_loop); + mov(kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_input, aux_reg_inp_d); + mov(aux_reg_kernel, aux_reg_ker_d); + } + + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + cmp(kj, 0); + je(skip_kh_loop, T_NEAR); + } + Label kh_loop; + L(kh_loop); + { + if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) { + oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks, + oc_blocks_tag); + sub(aux_reg_input, sizeof(float) * kw * inp_off); + add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult); + } else { + oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks); + add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult); + } + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + L(skip_kh_loop); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + sizeof(float) * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mult); + add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_loop, T_NEAR); + L(skip_kd_loop); + + pop(oi_iter); + pop(reg_output); + } + + Label regular_store; + + if (jcp.with_eltwise) { + test(reg_ci_flag, FLAG_IC_LAST); + je(regular_store, T_NEAR); + + eltwise_injector_->compute_vector_range(0, oc_blocks * ur_w); + + L(regular_store); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + const size_t o_off + = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk; + Ymm reg_out = Ymm(ur_w * ii + jj); + vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out); + } + } +} + +inline void jit_avx2_conv_fwd_kernel_f32::solve_common( + int oc_blocks, char oc_blocks_tag) +{ + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int n_oi = jcp.ow / ur_w; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + int dilate_w = jcp.dilate_w + 1; + int str_w = jcp.stride_w; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_blk; + + int l_pad = jcp.l_pad; + int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1)); + int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) + width_blk_step(ur_w, l_pad, r_pad1, + 'l', oc_blocks, oc_blocks_tag); // "lrpad" + else + width_blk_step(ur_w, l_pad, 0, + 'l', oc_blocks, oc_blocks_tag); // "lpad" + add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + Label ow_loop; + xor_(oi_iter, oi_iter); + + if (n_oi > 0) { + L(ow_loop); + + width_blk_step(ur_w, 0, 0, + 'm', oc_blocks, oc_blocks_tag); // "middle" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + + if (r_pad1 > 0 && n_oi >=0) { + width_blk_step(ur_w, 0, r_pad1, + 'r', oc_blocks, oc_blocks_tag); // "rpad" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + if (ur_w_tail != 0) + width_blk_step(ur_w_tail, 0, r_pad, + 't', oc_blocks, oc_blocks_tag); // "tail" +} + +void jit_avx2_conv_fwd_kernel_f32::generate() +{ + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); + + int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; + Label tail, exit; + + if (jcp.nb_oc > jcp.nb_oc_blocking) { + cmp(reg_oc_blocks, jcp.nb_oc_blocking); + jne(nb_oc_tail ? tail : exit, T_NEAR); + + solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking); + jmp(exit, T_NEAR); + + if (nb_oc_tail) { + L(tail); + cmp(reg_oc_blocks, nb_oc_tail); + jne(exit, T_NEAR); + solve_common(nb_oc_tail, '0' + nb_oc_tail); + } + + L(exit); + } else if (jcp.nb_oc == jcp.nb_oc_blocking) { + solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking); + } else { + solve_common(nb_oc_tail, '0' + nb_oc_tail); + } + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(avx)) return status::unimplemented; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 :dst_d.dims()[ndims-2]; + jcp.ow = dst_d.dims()[ndims-1]; + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; + jcp.kw = weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 :cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCdhw8c); + } + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu) + return status::unimplemented; + } + + const int simd_w = 8; + const bool flat = jcp.ic < simd_w; + const bool mimo = !flat; + + + /* Grouped channel offset to support 'non-blocked data' format for + * convolution sizes with '(input_channel / ngroups) < simd' */ + jcp.nonblk_group_off = + one_of(jcp.src_tag, ncw, nchw, ncdhw) && jcp.ngroups > 1 ? jcp.ic : 1; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + if (mimo) + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + bool args_ok = true + && IMPLICATION(flat, true + && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o, + gOdhwi8o)) + && IMPLICATION(mimo, true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, + OIdhw8i8o, gOIdhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c); + if (!args_ok) return status::unimplemented; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.ur_w = 3; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */ + + // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively + // Thus, we can only assign 14 or 15 YMMs for data storage + const int num_avail_regs = mayiuse(avx2) ? 15 : 14; + if (!mayiuse(avx2)) { + if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) { + // current register assignment requires more YMMs than available + // adjust one of nb_oc_block, ur_w preserving to ur_w >= l_pad + if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1) + jcp.ur_w -= 1; + else + for (int b = 3; b > 1; b--) + if (jcp.nb_oc % b == 0) { + jcp.nb_oc_blocking = b; + break; + } + } + } + + if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + args_ok = true + && jcp.oc % simd_w == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) + || (jcp.stride_w == 1 && jcp.stride_h == 1)) + && IMPLICATION(mimo, jcp.ic % simd_w == 0); + if (!args_ok) return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + + if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) { + /* recalculate ur_w, nb_oc_blocking and ur_w_tail */ + jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail, + nstl::min(jcp.ow, num_avail_regs / 2)); + jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + /* check again ... */ + r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail)) + return status::unimplemented; + } + assert(jcp.nb_oc_blocking > 0); + assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs); + + jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.nb_ic_blocking = 12; + jcp.nb_ic_blocking_max = 16; + } else { + jcp.nb_ic_blocking = 1; + jcp.nb_ic_blocking_max = jcp.nb_ic_blocking; + } + + return status::success; +} + +void jit_avx2_conv_fwd_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow, + int r_overflow) +{ + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int ow = jcp.ow; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_ic_block = jcp.nb_ic_blocking; + int stride_w = jcp.stride_w; + int stride_h = jcp.stride_h; + + Label kd_loop, skip_kd_loop; + Label oc_loop, skip_oc_loop; + + for (int ii = 0; ii < nb_ic_block; ii++) + for (int jj = 0; jj < ur_w; jj++) { + uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + Ymm(ur_w * ii + jj)); + } + + if (one_of(jcp.ndims, 3, 4)) { + cmp(reg_channel_work, 0); + jle(skip_oc_loop, T_NEAR); + xor_(reg_channel, reg_channel); + + mov(aux_reg_ddst_oc_loop, reg_ddst); + mov(aux_reg_kernel_oc_loop, reg_kernel); + + L(oc_loop); + mov(aux_reg_ddst, aux_reg_ddst_oc_loop); + mov(aux_reg_kernel, aux_reg_kernel_oc_loop); + } + + if (jcp.ndims == 5) { + assert(jcp.nb_oc_blocking == 1); + push(oi_iter); + + mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_ddst); + mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]); + + L(kd_loop); + mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]); + } else { + mov(kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_ddst, aux_reg_dst_d); + mov(aux_reg_kernel, aux_reg_ker_d); + } + + Label kh_loop, skip_kh_loop; + cmp(kj, 0); + jle(skip_kh_loop, T_NEAR); + L(kh_loop); { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_iw_start(ki, l_overflow); // 0; + int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) { + + for (int jj = jj_start ; jj < jj_end; jj += stride_w) { + int aux_output_offset + = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2; + vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w), + ptr[aux_reg_ddst + + sizeof(float) * aux_output_offset]); + } + + for (int ii = 0; ii < nb_ic_block; ii++) { + int aux_kernel_offset + = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block + + ki * jcp.ic_block * jcp.oc_block + + ofm2 * jcp.ic_block; + vmovups(ymm15, + ptr[aux_reg_kernel + + sizeof(float) * aux_kernel_offset]); + for (int jj = jj_start; jj < jj_end; jj += stride_w) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15); + } + } + } + add(aux_reg_kernel, sizeof(float) * stride_h * kw * oc_block + * ic_block); + sub(aux_reg_ddst, sizeof(float) * ow * oc_block); + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + L(skip_kh_loop); + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + sizeof(float) * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, + sizeof(float) * jcp.kw * jcp.kh * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_loop, T_NEAR); + L(skip_kd_loop); + + pop(oi_iter); + } + + if (one_of(jcp.ndims, 3, 4)) { + int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow + * jcp.oc_block; + int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw + * jcp.ic * jcp.oc_block; + + add(aux_reg_ddst_oc_loop, ddst_oc_shift); + add(aux_reg_kernel_oc_loop, kernel_oc_shift); + + inc(reg_channel); + cmp(reg_channel, reg_channel_work); + jl(oc_loop, T_NEAR); + + L(skip_oc_loop); + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + } + + Label no_update_label; + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + for (int ii = 0; ii < nb_ic_block; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block; + vmovups(Ymm(15), + make_safe_addr(reg_dsrc, offt, reg_long_offt)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + Ymm(15)); + + } + } + L(no_update_label); + + for (int ii = 0; ii < nb_ic_block; ii++) + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block; + vmovups(make_safe_addr(reg_dsrc, offt, reg_long_offt), + Ymm(ur_w * ii + jj)); + } +} + +void jit_avx2_conv_bwd_data_kernel_f32::generate() { + preamble(); + + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); + mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]); + + int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block; + int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block; + + int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w); + int r_overflow = nstl::max(0, (jcp.kw - 1 + - nstl::max(0, jcp.r_pad)) / jcp.stride_w); + int r_overflow1 = nstl::max(0, (jcp.kw - 1 + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + + int n_oi = jcp.iw / jcp.ur_w; + if (r_overflow1 > 0) + n_oi--; + + if (jcp.ur_w == jcp.iw) { + compute_loop(jcp.ur_w, l_overflow, r_overflow); + } else if (n_oi == 0) { + compute_loop(jcp.ur_w, l_overflow, r_overflow1); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + if (jcp.ur_w_tail != 0) + compute_loop(jcp.ur_w_tail, 0, r_overflow); + } else { + xor_(oi_iter, oi_iter); + if (l_overflow > 0) { + compute_loop(jcp.ur_w, l_overflow, 0); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + inc(oi_iter); + } + + if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) { + Label ow_loop; + L(ow_loop); { + compute_loop(jcp.ur_w, 0, 0); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + inc(oi_iter); + cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR); + } + } + + if (r_overflow1 > 0 ) { + compute_loop(jcp.ur_w, 0, r_overflow1); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + } + + if (jcp.ur_w_tail != 0) + compute_loop(jcp.ur_w_tail, 0, r_overflow); + } + + this->postamble(); +} + +status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + if (!mayiuse(avx2)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + + int ndims = diff_src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2]; + jcp.iw = diff_src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + const int simd_w = 8; + + /* derivatives */ + jcp.idp = jcp.id + 2 * jcp.f_pad; + jcp.ihp = jcp.ih + 2 * jcp.t_pad; + jcp.iwp = jcp.iw + 2 * jcp.l_pad; + jcp.ohp = jcp.oh; /* do we really need */ + jcp.owp = jcp.ow; /* padded output ??? */ + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + /* gemm-based convolution performs better in these cases */ + if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1) + return status::unimplemented; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + jcp.ic_block = (jcp.ic % simd_w) ? 1 : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + if (jcp.oc % jcp.oc_block) return status::unimplemented; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.nb_ic_blocking = 1; + jcp.nb_oc_blocking = 1; + jcp.ur_w = 1; + + if(one_of(ndims, 3, 4) && jcp.ow < 40) + jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2; + + if (ndims == 3) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIw8i8o, gOIw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIhw8o8i, gOIhw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nCdhw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIdhw8o8i, gOIdhw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c); + } + + bool args_ok = true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, gOIw8o8i, OIw8i8o, gOIhw8o8i, OIhw8o8i, + gOIdhw8o8i, OIdhw8o8i) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c) + && jcp.stride_w == jcp.stride_h + && jcp.stride_d == 1 + && jcp.dilate_d == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0 + && jcp.ic % simd_w == 0 + && jcp.oc % simd_w == 0 + && jcp.od == (jcp.idp - jcp.kd) / jcp.stride_d + 1 + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1; + if (!args_ok) return status::unimplemented; + jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad; + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad; + int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w); + + const int max_regs = 15; /* Maximun number of registers available for + result accumulation and delta dst data. + One additional register is reserved for weights + data. */ + + /* Find the best blocking with maximum number of fma instructions + per ur_w * nb_ic_blocking compute loops. Number of required registers + is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs. + ur_w must be divisible by stride_w */ + if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers + distribution exceeds max_regs */ + return status::unimplemented; + + int best_nfmas = 0; + for (int b = 1; b <= 4; b++) + { + if (jcp.nb_ic % b != 0) + continue; + + for (int u = jcp.stride_w; + u * b + u / jcp.stride_w <= max_regs && u < jcp.iw + jcp.stride_w; + u += jcp.stride_w) + { + int ur_w = nstl::min(u, jcp.iw); + /* maximum 1 step with l_overflow so far */ + if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw) + continue; + int nfmas = utils::div_up(ur_w, jcp.stride_w) * b; + if (nfmas > best_nfmas + || (nfmas == best_nfmas && jcp.ur_w < ur_w)) { + jcp.ur_w = ur_w; + jcp.nb_ic_blocking = b; + best_nfmas = nfmas; + } + } + } + if (best_nfmas == 0) /* can't find appropriate blocking */ + return status::unimplemented; + + jcp.ur_w_tail = jcp.iw % jcp.ur_w; + + int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1 - jcp.ur_w_tail + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + /* maximum 1 ur_w block with r_overflow so far */ + if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + + if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0)) + return status::unimplemented; + + return status::success; +} + +void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +void jit_avx2_conv_bwd_weights_kernel_f32::generate() { + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + compute_oh_loop_common(); + this->postamble(); +} + +status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d) { + if (!mayiuse(avx2)) return status::unimplemented; + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2]; + jcp.kw = diff_weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c); + } + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + + const bool flat = jcp.ic == 3; + const bool mimo = !flat; + + const int simd_w = 8; + + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + + int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id + - jcp.f_pad); + if (ndims == 5) + if (jcp.f_pad != 0 || back_pad != 0) + return status::unimplemented; + + const int max_h_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1); + const int max_w_pad = ((jcp.kw - 1) * (jcp.dilate_w + 1) + 1); + const bool boundaries_ok = true + && jcp.t_pad < max_h_pad && jcp.b_pad < max_h_pad + && jcp.l_pad < max_w_pad && jcp.r_pad < max_w_pad; + if (!boundaries_ok) + return status::unimplemented; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + if (mimo) + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + bool args_ok = true + && IMPLICATION(flat, true + && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o, + gOdhwi8o)) + && IMPLICATION(mimo, true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, + OIdhw8i8o, gOIdhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c) + && IMPLICATION(mimo, jcp.ic % simd_w == 0) + && jcp.oc % simd_w == 0 + && jcp.kw < 14 + && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */ + && jcp.kh <= jcp.ih /* [bwd_w:r2] */ + && jcp.kd <= jcp.f_pad + jcp.id + && jcp.kd <= jcp.id + && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */ + && jcp.dilate_d == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0; + if (!args_ok) return status::unimplemented; + + jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + + return status::success; +} + +void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() +{ + Label kd_comeback_loop; + mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0 + L(kd_comeback_loop); { + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : jcp.ic_block; + sub(aux_reg_input, sizeof(float) * jcp.iw * jcp.ih * inp_mult); + sub(aux_reg_kernel, sizeof(float) * jcp.kw * jcp.kh * jcp.ic_block + * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kd_comeback_loop, T_NEAR); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() +{ + mov(kj, reg_kh); + Label kh_comeback_loop; + L(kh_comeback_loop); { + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : jcp.ic_block; + sub(reg_input, sizeof(float) * jcp.iw * inp_mult); + sub(reg_kernel, sizeof(float) * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_comeback_loop, T_NEAR); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step( + int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset, + int kernel_offset, int output_offset) +{ + const int kw = jcp.kw; + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t off + = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block + + kernel_offset; + vmovups(Ymm(i_kw * ic_block_step + i_ic), yword[reg_kernel + off]); + } + + for (int i_ur = 0; i_ur < ur_w; i_ur++) { + vmovups(Ymm(kw * ic_block_step + 0), + yword[reg_output + + sizeof(float) * i_ur * oc_block + output_offset]); + + for (int i_kw = 0; i_kw < kw; i_kw++) { + int i_iw = i_ur * jcp.stride_w + i_kw; + if (i_iw - pad_l < 0 + || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r) + continue; + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t i_off = (size_t)input_offset + sizeof(float)*( + one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? (i_iw - pad_l) + i_ic + * ((size_t)jcp.id * jcp.ih * jcp.iw) + : (i_iw - pad_l) * ic_block + i_ic); + vbroadcastss(Ymm(kw * ic_block_step + 1), + make_safe_addr(reg_input, i_off, reg_long_offt)); + vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic), + Ymm(kw * ic_block_step + 0), + Ymm(kw * ic_block_step + 1)); + } + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t off + = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block + + kernel_offset; + vmovups(yword[reg_kernel + off], + Ymm(i_kw * ic_block_step + i_ic)); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp() +{ + int ic_block_step; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block; + } else { + ic_block_step = jcp.kw > 7 ? 1 + : jcp.kw > 3 ? 2 + : jcp.kw > 1 ? 4 : 8; + } + + const int max_ur_w = jcp.ow > 56 ? 14 : 28; + + if (jcp.ow <= max_ur_w) + compute_oh_step_unroll_ow(ic_block_step, max_ur_w); + else + compute_oh_step_common(ic_block_step, max_ur_w); + + if (jcp.ndims == 5) { + od_step_comeback_pointers(); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } else { + oh_step_comeback_pointers(); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow( + int ic_block_step, int max_ur_w) +{ + UNUSED(max_ur_w); + + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block; + Label kd_loop; + + const int r_pad + = nstl::max(0, + (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + + if (jcp.ndims == 5) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + mov(ki, jcp.kd); + L(kd_loop); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + Label kh_loop; + L(kh_loop); { + xor_(b_ic, b_ic); + Label ic_block_loop; + L(ic_block_loop); { + compute_ic_block_step(jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0, + 0, 0); + size_t inp_icblk_stride = sizeof(float) * ic_block_step + * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? jcp.id*jcp.ih*jcp.iw : 1); + safe_add(reg_input, inp_icblk_stride, reg_long_offt); + add(reg_kernel, sizeof(float) * ic_block_step * oc_block); + add(b_ic, ic_block_step); + cmp(b_ic, ic_block); + jl(ic_block_loop, T_NEAR); + } + if(one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, offt, reg_long_offt); + add(reg_input, sizeof(float) * jcp.iw); + } else { + add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block); + } + add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_loop, T_NEAR); + } + +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common( + int ic_block_step, int max_ur_w) +{ + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + const int stride_w = jcp.stride_w; + int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block; + Label kd_loop; + + const int r_pad = jcp.r_pad; + + int ur_w = nstl::min(jcp.ow, max_ur_w); + int ur_w_trips = jcp.ow / ur_w; + int ur_w_tail = jcp.ow % ur_w; + if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) { + if (ur_w_trips > 1) { + ur_w_tail += ur_w; + ur_w_trips--; + } else { + ur_w_tail += (ur_w - ur_w / 2); + ur_w = ur_w / 2; + } + } + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_block; + + int input_comeback = (ur_w_trips * ur_w * stride_w - jcp.l_pad) * inp_mult; + int output_comeback = ur_w_trips * ur_w * oc_block; + + if (jcp.ndims == 5) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + mov(ki, jcp.kd); + L(kd_loop); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + Label kh_loop; + L(kh_loop); { + xor_(b_ic, b_ic); + Label ic_block_loop; + L(ic_block_loop); { + if (jcp.l_pad != 0) { + ur_w_trips--; + compute_ic_block_step(ur_w, + jcp.l_pad, 0, ic_block_step, 0, 0, 0); + add(reg_input, sizeof(float) + * (ur_w * stride_w - jcp.l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_block); + } + + if (ur_w_trips > 0) { + xor_(reg_ur_w_trips, reg_ur_w_trips); + Label ow_block_loop; + L(ow_block_loop); { + compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0); + add(reg_input, sizeof(float) * ur_w * stride_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_block); + + inc(reg_ur_w_trips); + cmp(reg_ur_w_trips, ur_w_trips); + jl(ow_block_loop, T_NEAR); + } + } + + if (ur_w_tail > 0) + compute_ic_block_step(ur_w_tail, + 0, r_pad, ic_block_step, 0, 0, 0); + + sub(reg_input, sizeof(float) * input_comeback); + sub(reg_output, sizeof(float) * output_comeback); + + size_t inp_icblk_stride = sizeof(float) * ic_block_step + * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? jcp.id*jcp.ih*jcp.iw : 1); + safe_add(reg_input, inp_icblk_stride, reg_long_offt); + add(reg_kernel, sizeof(float) * ic_block_step * oc_block); + + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_loop, T_NEAR); + } + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, offt, reg_long_offt); + add(reg_input, sizeof(float) * jcp.iw); + } else { + add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block); + } + add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_loop, T_NEAR); + } + +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common() +{ + const int icoc_block = jcp.ic_block * jcp.oc_block; + const int t_pad = jcp.t_pad; + const int stride_h = jcp.stride_h; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : jcp.ic_block; + int b_pad = jcp.b_pad; + + Label oh_tpad_loop, oh_loop, oh_loop_end; + + mov(reg_kh, jcp.kh); + xor_(reg_ih_count, reg_ih_count); + xor_(reg_oj, reg_oj); + if (t_pad > 0) { + assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */ + mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih); + add(reg_kernel, sizeof(float) * t_pad * jcp.kw * icoc_block); + + L(oh_tpad_loop); { + compute_oh_step_disp(); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + sub(reg_kernel, sizeof(float) * stride_h * jcp.kw * icoc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + add(reg_kh, stride_h); + + /* the overlap between input and kernel may not reach kernel size. + * so far we do not support that (until we put constant here) */ + const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */ + cmp(reg_kh, final_inp_ker_overlap); + jl(oh_tpad_loop, T_NEAR); + } + + if (t_pad % stride_h != 0) { + int inp_corr = stride_h - t_pad % stride_h; + add(reg_kernel, sizeof(float) * inp_corr * jcp.kw * icoc_block); + add(reg_input, sizeof(float) * inp_corr * jcp.iw * inp_mult); + } + } + cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); + jge(oh_loop_end, T_NEAR); + cmp(reg_oj, jcp.oh); + jge(oh_loop, T_NEAR); + + mov(reg_kh, jcp.kh); + L(oh_loop); { + compute_oh_step_disp(); + add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); + jge(oh_loop_end, T_NEAR); + + cmp(reg_oj, jcp.oh); + jl(oh_loop, T_NEAR); + } + L(oh_loop_end); + if (b_pad > 0) { + Label oh_bpad_loop, oh_bpad_loop_end; + cmp(reg_oj, jcp.oh); + jge(oh_bpad_loop_end, T_NEAR); + + mov(reg_kh, jcp.ih + t_pad); + sub(reg_kh, reg_ih_count); + L(oh_bpad_loop); { + compute_oh_step_disp(); + add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + + sub(reg_kh, stride_h); + cmp(reg_kh, 0); + jle(oh_bpad_loop_end, T_NEAR); + + inc(reg_oj); + cmp(reg_oj, jcp.oh); + jl(oh_bpad_loop, T_NEAR); + } + L(oh_bpad_loop_end); + } +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp new file mode 100644 index 0000000000..412c50c9ee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp @@ -0,0 +1,225 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX2_CONV_KERNEL_F32_HPP +#define JIT_AVX2_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_conv_fwd_kernel_f32: public jit_generator { + jit_avx2_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_avx2_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_fwd_kernel_f32) + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t aux_reg_input = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r9; + reg64_t reg_output = rsi; + reg64_t reg_bias = rbx; + + reg64_t aux_reg_inp_d = r11; + reg64_t aux_reg_ker_d = abi_not_param1; + + reg64_t reg_ki = rsi; + reg64_t kj = r10; + reg64_t oi_iter = r11; + reg64_t ki_iter = r12; + reg64_t reg_kh = abi_not_param1; + reg64_t reg_oc_blocks = r14; + reg64_t imm_addr64 = r15; + reg64_t reg_long_offt = r15; + Xbyak::Reg32 reg_ci_flag = r13d; + + Xbyak::Ymm ytmp = Xbyak::Ymm(14); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, + int oc_blocks); + inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, + char pad_label, int oc_blocks, char oc_blocks_label); + inline void width_blk_step(int ur_w, int pad_l, int pad_r, + char pad_label, int oc_blocks, char oc_blocks_label); + inline void solve_common(int oc_blocks, char oc_blocks_label); + + void generate(); +}; + +struct jit_avx2_conv_bwd_data_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_data_kernel_f32) + + jit_avx2_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + + reg64_t reg_ddst = rax; + reg64_t aux_reg_ddst = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r10; + reg64_t reg_dsrc = rsi; + reg64_t aux_reg_ddst_oc_loop = rbx; // used in ndims < 5 case only + reg64_t aux_reg_kernel_oc_loop = abi_not_param1; /* used in ndims < 5 + case only */ + + reg64_t aux_reg_dst_d = r12; // used in ndims == 5 case only + reg64_t aux_reg_ker_d = r14; // used in ndims == 5 case only + + reg64_t reg_ki = abi_not_param1; // used in ndims == 5 case only + reg64_t kj = r11; + reg64_t oi_iter = r12; + reg64_t reg_kh = r14; + reg64_t reg_channel = r13; // used in ndims < 5 case only + reg64_t reg_channel_work = r9; // used in ndims < 5 case only + reg64_t reg_long_offt = r15; + + inline void compute_loop(int ur_w, int l_overflow, int r_overflow); + + void generate(); + + inline int get_iw_start(int ki, int l_overflow) + { + int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return res; + } + + inline int get_iw_end(int ur_w, int ki, int r_overflow) + { + if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return ur_w - res; + } +}; + +struct jit_avx2_conv_bwd_weights_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_weights_kernel_f32) + + jit_avx2_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t reg_kernel = rdx; + reg64_t reg_output = rsi; + reg64_t b_ic = abi_not_param1; + reg64_t kj = r8; + reg64_t reg_kh = r9; + reg64_t reg_ur_w_trips = r10; + reg64_t reg_tmp = r11; + reg64_t reg_oj = r15; + reg64_t reg_ih_count = rbx; + reg64_t aux_reg_input = r12; + reg64_t aux_reg_kernel = r13; + reg64_t ki = r14; + reg64_t reg_long_offt = r11; + + inline void od_step_comeback_pointers(); + inline void oh_step_comeback_pointers(); + inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset); + inline void compute_oh_step_disp(); + inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); + inline void compute_oh_step_common(int ic_block_step, int max_ur_w); + inline void compute_oh_loop_common(); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp new file mode 100644 index 0000000000..13f61e84fe --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp @@ -0,0 +1,410 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx2_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define src_blk_off(f, n, c, d, h, w) \ + (pd()->ndims() == 3) \ + ? (f).blk_off(n, c, w) \ + : (pd()->ndims() == 4) \ + ? (f).blk_off(n, c, h, w) \ + : (f).blk_off(n, c, d, h, w) + +#define wht_blk_off_(f, g, ...) \ + pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__) +#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \ + (pd()->ndims() == 3) \ + ? wht_blk_off_(f, g, oc, ic, kw) \ + : (pd()->ndims() == 4) \ + ? wht_blk_off_(f, g, oc, ic, kh, kw) \ + : wht_blk_off_(f, g, oc, ic, kd, kh, kw) + +void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); + const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.od + * jcp.oh; + + auto ker = [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int icbb = 0; + while (icbb < jcp.nb_ic) { + int icb_step = jcp.nb_ic_blocking; + int icb_step_rem = jcp.nb_ic - icbb; + if (icb_step_rem < jcp.nb_ic_blocking_max) + icb_step = icb_step_rem; + + size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + od, jcp.od, oh, jcp.oh); + for (size_t iwork = start; iwork < end; ++iwork) { + int ocb = ocbb * jcp.nb_oc_blocking; + int ocb_num = jcp.nb_oc_blocking; + + for (int icb = icbb; icb < icbb + icb_step; ++icb) { + auto par_conv = jit_conv_call_s(); + + const int ij = oh * jcp.stride_h; + const int i_t_overflow = nstl::max(0, jcp.t_pad - ij); + const int i_b_overflow = nstl::max(jcp.ih, ij + + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih; + + const int dj = od * jcp.stride_d; + const int d_t_overflow = nstl::max(0, jcp.f_pad - dj); + const int d_b_overflow = nstl::max(jcp.id, dj + + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id; + + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb; + + const int ih = nstl::max(ij - jcp.t_pad + + div_up(i_t_overflow, + (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0); + + const int id = nstl::max(dj - jcp.f_pad + + div_up(d_t_overflow, + (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0); + + par_conv.src = &src[src_blk_off(src_d, n, + jcp.ic == 3 ? 0 : _ic, id, ih, 0)]; + + par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)]; + + const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); + const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1)); + par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb, + jcp.ic == 3 ? 0 : icb, wd, wh, 0)]; + + if (icb == 0) { + if (bias) + par_conv.bias = + &bias[bias_d.blk_off(_oc * jcp.oc_block)]; + par_conv.flags |= FLAG_IC_FIRST; + } + + if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) { + par_conv.flags |= FLAG_IC_LAST; + } + + par_conv.oc_blocks = + nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb; + + par_conv.kw_padding = 0; + const int kh_padding = jcp.kh + - div_up(i_t_overflow, (jcp.dilate_h + 1)) + - div_up(i_b_overflow, (jcp.dilate_h + 1)); + par_conv.kh_padding = nstl::max(0, kh_padding); + + const int kd_padding = jcp.kd + - div_up(d_t_overflow, (jcp.dilate_d + 1)) + - div_up(d_b_overflow, (jcp.dilate_d + 1)); + par_conv.kd_padding = nstl::max(0, kd_padding); + + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + od, jcp.od, oh, jcp.oh); + } + icbb += icb_step; + } + }; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad(ctx).get(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, ker); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +void jit_avx2_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + int icb_work = jcp.nb_ic / jcp.nb_ic_blocking; + int ih_block_size = jcp.ih; + int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size); + size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks; + if (work_amount < (size_t)2 * mkldnn_get_max_threads()) { + ih_block_size = 1; + num_ih_blocks = utils::div_up(jcp.ih, ih_block_size); + work_amount *= num_ih_blocks; + } + + auto ker = [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + size_t n{0}, g{0}, icbb{0}, ihb{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work, + ihb, num_ih_blocks); + for (size_t iwork = start; iwork < end; ++iwork) { + for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking) + for (int id = 0; id < jcp.id; ++id) { + auto par_conv = jit_conv_call_s(); + + const int idp = jcp.id + 2 * jcp.f_pad; + const int d_t_overflow = nstl::max(0, + jcp.kd - 1 - id - jcp.f_pad); + const int back_pad = idp - jcp.id - jcp.f_pad; + const int d_b_overflow = nstl::max(0, + jcp.kd - 1 - (jcp.id - 1 - id) - back_pad); + const int od = id + jcp.f_pad - d_b_overflow; + + int ih_start = ihb * ih_block_size; + int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size); + for (int ih = ih_start; ih < ih_end; ++ih) { + + const int i_t_overflow = nstl::max(0, (jcp.kh - 1 + - ih - jcp.t_pad) / jcp.stride_h); + const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih + + ih - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ih) % jcp.stride_h); + int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h; + + par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow; + par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow - i_b_overflow; + par_conv.kw_padding = 0; + + const int k_lo = overflow_kh_lo + + i_b_overflow * jcp.stride_h; + const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h; + + par_conv.src = &diff_src[src_blk_off(diff_src_d, n, + /*jcp.ic == 3 ? 0 :*/ + g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)]; + par_conv.dst = &diff_dst[src_blk_off(diff_dst_d, + n, g * jcp.nb_oc + oc, od, oh, 0)]; + par_conv.filt = &weights[wht_blk_off(weights_d, g, oc, + jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb, + d_b_overflow, k_lo, 0)]; + + par_conv.src_prf = nullptr; + par_conv.dst_prf = nullptr; + par_conv.filt_prf = nullptr; + par_conv.channel = oc; + par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc, + jcp.nb_oc_blocking); + + kernel_->jit_ker(&par_conv); + } + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb, + num_ih_blocks); + } + }; + + parallel(0, ker); +} + +void jit_avx2_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto scratchpad = this->scratchpad(ctx); + + data_t *diff_bias = pd()->wants_padded_bias() + ? scratchpad.get(key_conv_padded_bias) : diff_bias_in; + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + + auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_wei); + auto rw = this->reducer_weights_; + rw->init(reducer_wei_scratchpad); + + auto ker = [&](int ithr, int nthr) { + assert(nthr == rw->balancer().nthr_); + + const int w_job_start = rw->balancer().ithr_job_off(ithr); + const int w_njobs = rw->balancer().ithr_njobs(ithr); + + if (w_njobs == 0) return; + + /* reduction dimension */ + int img_od_start{0}, img_od_end{0}, img{0}, od_s{0}; + balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_, + rw->balancer().id_in_group(ithr), img_od_start, img_od_end); + + int img_start = img_od_start, img_end = img_od_end; + nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od); + const int img_first = img; + + /* jobs */ + int g_start{0}, ocb_start{0}, icb_start{0}; + nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start, + jcp.nb_oc, icb_start, jcp.nb_ic); + + while (img_start < img_end) { + int g = g_start, ocb = ocb_start, icb = icb_start; + + const int work_rem = img_end - img_start; + const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem; + const int id_s = od_s * jcp.stride_d; + const int idp = jcp.id + jcp.f_pad + jcp.back_pad; + + if (id_s < idp - jcp.back_pad - jcp.kd + 1) + for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic + icb; + + /* TODO: put dw <-- 0 in kernel */ + if (img == img_first) + array_set(rw->get_local_ptr(ithr, diff_weights, + reducer_wei_scratchpad) + + w_job_loc * rw->balancer().job_size_, 0, + rw->balancer().job_size_); + + for (int od = od_s; od < od_e; ++od) { + const int id = od * jcp.stride_d; + if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break; + + auto par_conv = jit_conv_call_s(); + par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)]; + par_conv.dst = + &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)]; + par_conv.filt = rw->get_local_ptr(ithr, diff_weights, + reducer_wei_scratchpad) + + w_job_loc * rw->balancer().job_size_; + + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb, + jcp.nb_ic); + } + nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od); + } + rw->reduce(ithr, diff_weights, reducer_wei_scratchpad); + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, + jcp.nb_oc); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, + reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 8; ++o) + d_bias[o] = 0.; + + for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 8; ++o) + d_bias[o] += d_dst[o]; + d_dst += 8; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(0, [&](const int ithr, const int nthr) { + ker(ithr, nthr); + if (pd()->with_bias()) + ker_bias(ithr, nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + for (int oc = 0; oc < jcp.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp new file mode 100644 index 0000000000..bb65bce79c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp @@ -0,0 +1,302 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_CONVOLUTION_HPP +#define CPU_JIT_AVX2_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx2_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_fwd_kernel_f32::init_conf(jcp_, + *desc(), src_md(), weights_md(), dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + const bool flat = IC() < 8; + auto src_tag = flat + ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + }; + + jit_avx2_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_avx2_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); } + ~jit_avx2_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_fwd_kernel_f32 *kernel_; +}; + +struct jit_avx2_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf( + jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(scratchpad, + jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i) + : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_avx2_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_); } + ~jit_avx2_convolution_bwd_data_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_bwd_data_kernel_f32 *kernel_; +}; + +struct jit_avx2_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_bwd_weights_kernel_f32::init_conf( + jcp_, *desc(), *src_md(), *diff_weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(scratchpad, + jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_wei); + reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); + + return status::success; + } + + jit_conv_conf_t jcp_; + cpu_reducer_t::conf_t reducer_bia_conf_; + cpu_reducer_t::conf_t reducer_wei_conf_; + + protected: + bool set_default_formats() { + using namespace format_tag; + const bool flat = IC() == 3; + + auto src_tag = flat + ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + + private: + void init_balancers() { + const int max_threads = mkldnn_get_max_threads(); + const size_t max_buffer_size = 1<<21; /* just a heuristic */ + + if(with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(max_threads, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb, + max_buffer_size)); + } + + reducer_wei_conf_.init(reduce_balancer_t(max_threads, + jcp_.kd * jcp_.kh * jcp_.kw + * jcp_.ic_block * jcp_.oc_block, + jcp_.ngroups * jcp_.nb_ic * jcp_.nb_oc, + jcp_.mb * jcp_.od, max_buffer_size)); + } + }; + + jit_avx2_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , reducer_weights_(nullptr) + , reducer_bias_(nullptr) + { + kernel_ = new jit_avx2_conv_bwd_weights_kernel_f32(pd()->jcp_); + reducer_bias_ = + new cpu_reducer_t(pd()->reducer_bia_conf_); + reducer_weights_ = + new cpu_reducer_t(pd()->reducer_wei_conf_); + } + + ~jit_avx2_convolution_bwd_weights_t() { + delete kernel_; + delete reducer_weights_; + delete reducer_bias_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_bwd_weights_kernel_f32 *kernel_; + cpu_reducer_t *reducer_weights_, *reducer_bias_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp new file mode 100644 index 0000000000..635b83b2bf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp @@ -0,0 +1,1255 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" +#include "cpu_barrier.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_common_1x1_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_bcast_data, reg_bcast_data); + + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt)); + + if (jcp.ver == ver_4fma) + { + Label bcast_loop; + Label bcast_loop_wraparound; + Label bcast_loop_out; + Label bcast_loop_ur_full; + + cmp(bcast_loop_iter, jcp.ur); + jle(bcast_loop_wraparound, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jg(bcast_loop, T_NEAR); + } + + L(bcast_loop_wraparound); + if (jcp.ur_tail) { + je(bcast_loop_ur_full, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + jmp(bcast_loop_out, T_NEAR); + } + L(bcast_loop_ur_full); + reduce_loop(load_loop_blk, jcp.ur, 0, true); + L(bcast_loop_out); + } + else + { + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + L(bcast_loop_tail_out); + } + } +} + +void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk, + int ur, int substep, bool wraparound) +{ + auto vreg_load = [=](int i_load, int i_fma) { + return Zmm(utils::rnd_up(ur * load_loop_blk, jcp.fma_step) + + jcp.fma_step * i_load + i_fma); + }; + + auto vreg_accum = [=](int i_load, int i_ur) { + return Zmm(i_ur * load_loop_blk + i_load); + }; + + auto bias_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_bias_data, + jcp.typesize_out * jcp.oc_block * i_load); + }; + + auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) { + assert(i_ur < jcp.ur); + assert(i_reduce <= jcp.reduce_loop_unroll); + int offt; + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) { + assert(jcp.reduce_loop_unroll == jcp.reduce_block); + offt = (i_reduce == jcp.reduce_loop_unroll) + ? (jcp.bcast_dim + i_ur) * jcp.reduce_loop_unroll + : i_ur * jcp.reduce_loop_unroll + i_reduce; + } else { + if (jcp.transpose_src) { + const int reduce_group = i_reduce / 4; + const int reduce_shift = i_reduce % 4; + offt = 4 * (reduce_group * jcp.ic_block + i_ur) + reduce_shift; + } + else + offt = i_reduce * jcp.ic_block + i_ur; + } + return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt, + bcast); + }; + + auto load_ptr = [=](int i_reduce, int i_load) { + int offt; + int u0 = i_reduce % jcp.reduce_loop_unroll; + int u1 = i_reduce / jcp.reduce_loop_unroll; + offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block; + return EVEX_compress_addr(aux_reg_load_data, + u1 * jcp.reduce_loop_load_step + + jcp.typesize_in * offt); + }; + + auto output_ptr = [=](int i_load, int i_ur) { + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) + return EVEX_compress_addr(aux_reg_output_data, + (i_load * jcp.bcast_dim + i_ur) * jcp.load_block + * jcp.typesize_out); + else + return ptr[aux_reg_output_data + + (i_load + ? reg_output_stride * i_load + : 0) // TODO: Xbyak should allow 0 scale + + jcp.typesize_out * jcp.load_block * i_ur]; + }; + + auto init = [=]() { + Label init_done; + Label init_zero; + + if (jcp.with_sum) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + for (int i_ur = 0; i_ur < ur; ++i_ur) { + mic_prefetcht1(output_ptr(i_load, i_ur)); + } + } + } + + if (jcp.with_bias + && one_of(jcp.prop_kind, forward_training, forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero, T_NEAR); + + for (int i_load = 0; i_load < load_loop_blk; i_load++) + for (int i_ur = 0; i_ur < ur; ++i_ur) + vmovups(vreg_accum(i_load, i_ur), bias_ptr(i_load)); + jmp(init_done, T_NEAR); + } + + L(init_zero); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vpxord(r, r, r); + } + L(init_done); + }; + + auto store = [=]() { + Label store_noadd; + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + auto r = vreg_accum(i_load, i_ur); + vaddps(r, r, output_ptr(i_load, i_ur)); + } + + L(store_noadd); + if (jcp.with_eltwise) { + Label store_noeltwise; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_noeltwise, T_NEAR); + + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + L(store_noeltwise); + } + + auto store_output = [=](bool output_is_aligned) { + for (int i_ur = 0; i_ur < ur; ++i_ur) + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + if (output_is_aligned && jcp.use_vmovntps) + vmovntps(output_ptr(i_load, i_ur), + vreg_accum(i_load, i_ur)); + else + vmovups(output_ptr(i_load, i_ur), + vreg_accum(i_load, i_ur)); + }; + + Label unaligned_store, end_store; + test(aux_reg_output_data, cpu_isa_traits::vlen - 1); + jnz(unaligned_store, T_NEAR); + store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + }; + + auto prefetch_callback = [=](int ur, int i_reduce, int i_ur, int i_load, + bool last_block, bool wraparound, int reduce_step) + { + bool pf_ker_l1 = true; + bool pf_ker_l2 = wraparound; + int n_ops = (jcp.reduce_loop_unroll / reduce_step) * ur * load_loop_blk; + int i_op = (i_reduce / reduce_step) * ur * load_loop_blk + + i_ur * load_loop_blk + i_load; + + int n_pf_ker_l1 = pf_ker_l1 ? jcp.reduce_block : 0; + int n_pf_ker_l2 = pf_ker_l2 && wraparound ? jcp.reduce_block : 0; + int n_pf_out_l1 = jcp.use_vmovntps ? 0 : ur; + + int pf_inp_ops = n_ops / 2; // # of operations during which to pf input + int pf_inp_trigger; + if (jcp.prop_kind == backward_weights) + pf_inp_trigger = nstl::max(1, pf_inp_ops / jcp.reduce_block); + else + pf_inp_trigger = nstl::max(1, pf_inp_ops / ur); + + int n_other_pf = + load_loop_blk * (n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1); + int n_other_pf_ops = n_ops - pf_inp_ops; + int other_pf_trigger + = n_other_pf ? nstl::max(1, n_other_pf_ops / n_other_pf) : 0; + + if (i_op < pf_inp_ops && i_op % pf_inp_trigger == 0) { + // input prefetches have the highest priority b/c the + // first iteration of the kernel block touches all the + // cache lines + int i_pf = i_op / pf_inp_trigger; + auto pf_reg = wraparound && last_block + ? reg_bcast_data + : (last_block ? aux1_reg_bcast_data + : aux_reg_bcast_data); + int offt = i_pf; + if (jcp.prop_kind == backward_weights) { + offt += wraparound && last_block + ? 0 + : (last_block ? jcp.is : jcp.reduce_block); + offt *= jcp.bcast_block; + } else { + offt += wraparound && last_block + ? 0 + : (last_block ? jcp.ur : jcp.bcast_dim); + offt *= jcp.reduce_block; + } + mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]); + } else if (i_op >= pf_inp_ops && n_other_pf) { + // remaining prefetches are spread among the rest of the + // operations; prefetches for output take priority + // TODO: spread L2 prefetches among L1 prefetches + i_op -= pf_inp_ops; + if (i_op % other_pf_trigger == 0) { + int i_pf = i_op / (load_loop_blk * other_pf_trigger); + if (i_pf < n_pf_ker_l2) { + int offt = (i_pf + (i_load + 1) * jcp.reduce_dim) + * jcp.load_block; + mic_prefetcht1(ptr[aux_reg_load_data + + offt * jcp.typesize_in]); + } else if (i_pf < n_pf_ker_l2 + n_pf_ker_l1) { + i_pf -= n_pf_ker_l2; + auto pf_reg = last_block ? reg_load_data + : aux_reg_load_data; + int offt = (i_pf + i_load * jcp.reduce_dim + + (last_block + ? (wraparound ? jcp.reduce_dim : 0) + : jcp.reduce_block)) + * jcp.load_block; + mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]); + } else if (i_pf < n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1) { + i_pf -= n_pf_ker_l1 + n_pf_ker_l2; + int offt = i_pf * jcp.load_block; + mic_prefetcht0(ptr[aux_reg_output_data + + offt * jcp.typesize_out]); + } + } + } + }; + + auto fma_block = [=](bool last_block) { + assert(jcp.reduce_loop_unroll % jcp.fma_step == 0); + + int reduce_step = jcp.fma_step; + + for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll; + i_reduce += reduce_step) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + // if transposed input data used and if spatial size is + // not divided by transpose step (4) then for last reduce step + // we should load only needed load_registers data + // and clear remaining + if (jcp.transpose_src && jcp.is % jcp.fma_step && last_block + && i_reduce == jcp.reduce_loop_unroll - reduce_step) { + Label load_all; + Label load_finish; + test(reg_reduce_pos_flag, FLAG_SP_LAST); + jz(load_all, T_NEAR); + + const int n_loads = jcp.is % jcp.fma_step; + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + if (i_fma < n_loads) + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + else + vpxord(vreg_load(i_load, i_fma), + vreg_load(i_load, i_fma), + vreg_load(i_load, i_fma)); + } + jmp(load_finish); + + L(load_all); + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + } + L(load_finish); + } else { + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + } + } + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) { + if (jcp.ver == ver_avx512_core && jcp.expl_bcast + && load_loop_blk > 1) + vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false)); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + if (jcp.ver == ver_4fma) + v4fmaddps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), + bcast_ptr(i_reduce, i_ur, false)); + else if (jcp.ver == ver_avx512_core && jcp.expl_bcast + && load_loop_blk > 1) + vfmadd231ps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), vreg_bcast); + else + vfmadd231ps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), + bcast_ptr(i_reduce, i_ur, true)); + prefetch_callback(ur, i_reduce, i_ur, i_load, + last_block, wraparound, reduce_step); + } + } + } + }; + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} + +void jit_avx512_common_1x1_conv_kernel::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + + sub(rsp, stack_space_needed); + + if (jcp.with_bias) + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + if (one_of(jcp.prop_kind, forward_training, forward_inference)) + mov(reg_relu_ns, reinterpret_cast(&jcp.eltwise.alpha)); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto load_loop_body = [=](int load_loop_blk) { + bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, + load_loop_blk * jcp.load_block * jcp.typesize_out); + add(reg_output_data, + load_loop_blk * jcp.bcast_dim * jcp.load_block * + jcp.typesize_out); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.bcast_dim * jcp.load_block * + jcp.typesize_out); + break; + case backward_weights: + for (int i_load = 0; i_load < load_loop_blk; i_load++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + const int simd_w = 16; + + Label load_loop_blk[7]; + + static const int ur_cases_fma_embd_bcast[] = { 2, 4, 5, 8, 14, 32 }; + static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 }; + static const int ur_cases_4fma[] = { 2, 4, 6, 12, 32 }; + + const int size_ur_cases_fma + = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ? + sizeof(ur_cases_fma_expl_bcast) : + sizeof(ur_cases_fma_embd_bcast); + const int size_ur_cases_4fma = sizeof(ur_cases_4fma); + + const int *ur_cases_fma = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ? + ur_cases_fma_expl_bcast : + ur_cases_fma_embd_bcast; + const int *ur_cases = jcp.ver == ver_4fma ? ur_cases_4fma : ur_cases_fma; + const int num_ur_cases = + (jcp.ver == ver_4fma ? size_ur_cases_4fma : size_ur_cases_fma) + / sizeof(*ur_cases); + + for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { + int label_idx = num_ur_cases - ur_idx - 1; + if (jcp.ur <= ur_cases[ur_idx]) { + cmp(reg_load_loop_work, simd_w * (label_idx + 1)); + jle(load_loop_blk[label_idx], T_NEAR); + } + } + + for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { + if (jcp.ur <= ur_cases[ur_idx]) { + int label_idx = num_ur_cases - ur_idx - 1; + L(load_loop_blk[label_idx]); + { + if (label_idx == 0) { + cmp(reg_load_loop_work, 0); + je(load_loop_blk[num_ur_cases], T_NEAR); + } + load_loop_body(label_idx + 1); + if (label_idx - 1 > 0) { + cmp(reg_load_loop_work, 2 * label_idx * simd_w); + je(load_loop_blk[label_idx - 1], T_NEAR); + } + cmp(reg_load_loop_work, (label_idx + 1) * simd_w); + jge(load_loop_blk[label_idx]); + } + for (int idx = label_idx - 1; idx > 0; --idx) { + cmp(reg_load_loop_work, simd_w * (idx + 1)); + je(load_loop_blk[idx], T_NEAR); + } + if (ur_idx < num_ur_cases - 2) { + cmp(reg_load_loop_work, simd_w); + jle(load_loop_blk[0], T_NEAR); + } + } + } + L(load_loop_blk[num_ur_cases]); + + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_common_1x1_conv_kernel::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr, int nthreads, bool reduce_src) { + if (!mayiuse(avx512_common)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind, + format_kind::undef, cd.diff_bias_desc.format_kind) + != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + jcp.tr_is = rnd_up(jcp.is, 4); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (dst_d.data_type() == data_type::s32) return status::unimplemented; + } + + auto dat_tag = pick(ndims - 3, nCw16c, nChw16c); + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + args_ok = true + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.ic_block = jcp.oc_block = simd_w; + jcp.transpose_src = false; + + if (everyone_is(data_type::f32, src_d.data_type(), + weights_d.data_type(), dst_d.data_type())) + { + const int is_bwd_d = jcp.prop_kind == backward_data; + format_tag_t wei_tag = with_groups + ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i, + gOIhw16i16o, gIOhw16o16i) + : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i, + OIhw16i16o, IOhw16o16i); + + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (jcp.prop_kind != backward_weights && mayiuse(avx512_mic_4ops) && + ((jcp.prop_kind == backward_data) ? jcp.oc_block : jcp.ic_block) % 4 + == 0) { + jcp.ver = ver_4fma; + jcp.fma_step = 4; + } else if (jcp.prop_kind == backward_weights && mayiuse(avx512_mic_4ops) + && !reduce_src + /* Heuristic condition for relation of src size to oc. Otherwise + the src transposition overhead exceed the benefit from 4fma + */ + && ((jcp.is * jcp.ic) / jcp.oc <= 2048) + && mkldnn_thr_syncable() + ) + { + jcp.transpose_src = true; + jcp.ver = ver_4fma; + jcp.fma_step = 4; + } else { + jcp.ver = (mayiuse(avx512_core)) ? ver_avx512_core : ver_fma; + jcp.fma_step = 1; + } + jcp.typesize_in = sizeof(prec_traits::type); + jcp.typesize_out = sizeof(prec_traits::type); + } else { + return status::unimplemented; + } + + /* once all the formats are set, check the padding consistency */ + args_ok = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + const int SMALL_SPATIAL = 10; + const int BIG_SPATIAL = 28; + const int BIG_REDUCE_DIM = 1024; + const int BIG_LOAD_DIM = 256; + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + int reduce_blocking_max{ 0 }; + + jcp.load_grp_count = 1; + + const int L1_capacity = get_cache_size(1, true) / sizeof(float); + const int L2_size = get_cache_size(2, true) / sizeof(float); + const int L2_capacity = (L2_size * 3) / 4; + + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) { + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + } else { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.ic_block; + + jcp.bcast_dim = jcp.os; + } + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in; + + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; + jcp.load_loop_load_step + = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; + + // adjusting registry blocking + int max_regs, min_regs, size_treshold, ur_step; + const int spatial + = (one_of(jcp.prop_kind, forward_training, forward_inference)) ? + jcp.oh : + jcp.ih; + if (jcp.ver == ver_avx512_core && (8 * jcp.mb) / nthreads >= 1) { + max_regs = 9; + min_regs = 6; + size_treshold = 14; + ur_step = 1; + jcp.expl_bcast = true; + + if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM + && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) { + max_regs = 6; + min_regs = 5; + } + } else { + max_regs = jcp.ver == ver_4fma ? 28 : 30; + min_regs = 9; + size_treshold = jcp.ver == ver_4fma ? 28 : 14; + ur_step = jcp.ver == ver_4fma ? 4 : 1; + jcp.expl_bcast = false; + jcp.use_vmovntps = true; + } + jcp.ur = 1; + for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) { + if ((spatial >= size_treshold && spatial % ur_w == 0) + || (spatial < size_treshold && jcp.os % ur_w == 0)) { + jcp.ur = ur_w; + break; + } + } + if (jcp.ur == 1) { + jcp.ur = nstl::min(max_regs, jcp.os); + int os_tail = jcp.os % max_regs; + for (int i = max_regs; i >= min_regs; i -= ur_step) { + int i_tail = jcp.os % i; + if (i_tail > os_tail || i_tail == 0) { + jcp.ur = i; + os_tail = i_tail; + if (i_tail == 0) + break; + } + } + } + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in; + + jcp.bcast_block = jcp.ur; + + jcp.bcast_loop_output_step = jcp.ur * jcp.load_block * jcp.typesize_out; + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_block * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_iter_step = jcp.load_block; + + if (jcp.prop_kind == backward_data) + jcp.loop_order = loop_lbr; + else + jcp.loop_order = reduce_src ? loop_blr : loop_lbr; + + int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + int nb_load = div_up(jcp.load_dim, jcp.load_block); + + if (jcp.ver == ver_avx512_core && jcp.expl_bcast) { + if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL + && spatial < BIG_SPATIAL) + reduce_blocking = nstl::min(jcp.reduce_dim, 80); + else if (spatial > SMALL_SPATIAL) + reduce_blocking = nstl::min(jcp.reduce_dim, 512); + else + reduce_blocking = nstl::min(jcp.reduce_dim, 256); + + if ((jcp.mb > 28 && spatial >= 28) + || (jcp.mb > 112 && spatial >= 17)) + jcp.use_vmovntps = true; + else + jcp.use_vmovntps = false; + } else { + + reduce_blocking = nb_reduce; + if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 16; + else if (spatial > SMALL_SPATIAL + && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 8; + reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + } + + // Check input data cache aliasing. + // For other ISA constants may be updated. + // 64 * 1024 is chosen due to 1MB L2 16-way cache. + // 7 is empirical value. It is about half of 16. + // So we leave about half of the set for other data - weights, dst + int way_size = (64 * 1024) / jcp.typesize_in; + int max_hits = 7; + if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) { + int nrb = reduce_blocking / simd_w; + int sp = jcp.bcast_dim; + int wl = way_size / simd_w; + for (int start_off = 0; start_off < jcp.ur; start_off++) { + for (int off = start_off, hits = 0; off < sp * nrb; off += wl) { + if (off % sp >= jcp.ur || ++hits < max_hits) + continue; + int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp); + reduce_blocking + = nstl::min(reduce_blocking, max_r_blocking); + break; + } + } + } + + if (reduce_blocking < jcp.reduce_dim) { + jcp.use_vmovntps = false; + if (jcp.prop_kind == backward_data) + jcp.loop_order = reduce_src ? loop_lbr : loop_rlb; + else + jcp.loop_order = reduce_src ? loop_rbl : loop_rlb; + } + load_blocking = jcp.load_dim; + + int load_size = jcp.load_dim * jcp.reduce_dim; + int bcast_size = jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim; + + if (jcp.ver == ver_avx512_core && nthreads <= 28 && jcp.mb < nthreads + && nb_load * nb_bcast > nthreads) { + // Some heuristic here + float calc_koef = 0.01, best_cost = FLT_MAX; + int n_lgc = nthreads; + float ratio = (float)load_size / (float)bcast_size; + int best_lgc = ratio > 1 ? n_lgc : 1; + auto calc_job_cost = [&](int lb, int tg, float mem_k) { + int bb_size = jcp.mb * div_up(nb_bcast, tg); + float calc_size = (float)(bb_size * jcp.ur) + * (lb * jcp.load_block) * jcp.reduce_dim; + float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block) + * jcp.reduce_dim; + return calc_koef * calc_size + mem_k * mem_size; + }; + for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) { + lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1; + int min_lb = nb_load / lgc; + int max_lb = div_up(nb_load, lgc); + int min_tg = nthreads / lgc; + int max_tg = div_up(nthreads, lgc); + // Some heuristic here + float mem_koef = (max_tg == 1) ? 1.f : 1.3f; + float job_cost = 0.; + if (nthreads % lgc < nb_load % lgc) { + job_cost = calc_job_cost(max_lb, min_tg, mem_koef); + } else { + auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef); + auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef); + job_cost = nstl::max(job_cost1, job_cost2); + } + + if (job_cost < best_cost) { + best_lgc = lgc; + best_cost = job_cost; + } + } + jcp.load_grp_count = best_lgc; + load_blocking = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; + } else { + jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast); + jcp.load_grp_count = best_divider( + nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false); + } + + if (jcp.ver == ver_avx512_core && jcp.expl_bcast && jcp.bcast_dim <= 64 + && load_size >= L2_size) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); + } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads + && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); + load_blocking = jcp.load_block; + } + + if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim + && jcp.oh * jcp.ow > 64 + && IMPLICATION(reduce_src, jcp.load_dim < 1024)) { + /* Looking for best loading dimension blocking + * to get the best thread and data read/write efficiency + * by finding the optimal 'load_chunk' value + * Example: + * for 72 threads and convolution with mb=1, ih=iw=7, oc = 512 + * the 'best' load_chunk value should be 1 + * TODO: remove heuristic constants in above condition + * TODO: check this blocking for other ISA + */ + float best_eff = -1.f; + int best_lgc = 1; + + for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) { + int lgc = div_up(nb_load, load_chunk); + if (lgc > nthreads) + continue; + int thr_per_grp = div_up(nthreads, lgc); + int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp) + * jcp.bcast_block; + int load_per_thr = load_chunk * simd_w; + float data_norm = (bcast_per_thr + load_per_thr) / 2.f; + float data_eff = (bcast_per_thr * load_per_thr) + / (data_norm * data_norm); + float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc) + / div_up(nthreads, lgc); + float thr_eff_in_grp = ((float)jcp.mb * nb_bcast) + / rnd_up(jcp.mb * nb_bcast, thr_per_grp); + float thr_eff = thr_eff_over_grp * thr_eff_in_grp; + float load_eff = (float)nb_load / rnd_up(nb_load, lgc); + float overall_eff = data_eff + thr_eff + load_eff; + if (overall_eff > best_eff) { + best_eff = overall_eff; + best_lgc = lgc; + } + } + jcp.load_grp_count = best_lgc; + load_blocking + = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; + } + bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, + div_up(nthreads, jcp.load_grp_count)) + * jcp.bcast_block; + bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); + bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); + + int space_for_bcast + = (L2_capacity - /* kernel_size - */ + 2 * jcp.load_block * reduce_blocking + - jcp.ur * reduce_blocking - 3 * 1024); + if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) + space_for_bcast /= 2; + + int bcast_in_cache + = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); + bcast_blocking = nstl::min( + bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); + + load_blocking_max = load_blocking; + bcast_blocking_max = bcast_blocking * 3 / 2; + reduce_blocking_max = reduce_blocking; + + } else if (jcp.prop_kind == backward_weights) { + + jcp.use_vmovntps = false; + if (jcp.is > SMALL_SPATIAL * SMALL_SPATIAL && jcp.ver == ver_4fma) + jcp.use_vmovntps = true; + + if (jcp.transpose_src) + jcp.reduce_dim = jcp.tr_is; + else + jcp.reduce_dim = jcp.is; + + if (jcp.ver == ver_4fma) { + // reduce_block should be divided by fma_step + jcp.reduce_block = best_divider(jcp.reduce_dim, 4, 16, true, 4); + } else { + jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true); + if (jcp.reduce_dim % jcp.reduce_block != 0) + jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false); + if (jcp.reduce_block > 256) { + jcp.reduce_block = 1; + } + + } + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + if (jcp.ver == ver_avx512_core && jcp.reduce_block <= 19) { + // if reduce_block is big then generated JIT code may be big + // for small values of ur because reduce_loop_unroll = reduce_block + jcp.ur = jcp.bcast_block / 2; + jcp.expl_bcast = true; + } else { + jcp.ur = jcp.bcast_block; + jcp.expl_bcast = false; + } + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * jcp.typesize_in; + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * jcp.typesize_in; + + jcp.bcast_loop_output_step = + jcp.oc_block * jcp.ic_block * jcp.typesize_out; + jcp.bcast_loop_output_substep = + jcp.oc_block * jcp.ur * jcp.typesize_out; + jcp.bcast_loop_bcast_step = + jcp.ic_block * jcp.reduce_dim * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in; + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * jcp.typesize_in; + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + balance(jcp, nthreads); + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + load_blocking = best_divider(load_blocking, 16, load_blocking, false); + load_blocking *= jcp.load_block; + + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + int min_bcast_blocking = 5; + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + bcast_blocking = best_divider( + bcast_blocking, min_bcast_blocking, max_bcast_blocking, false); + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + // for reduction balance + if (jcp.ver == ver_avx512_core) { + int max_reduce_blocking + = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim); + int min_reduce_blocking = nstl::min( + L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih)); + reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking, + max_reduce_blocking, true); + reduce_blocking + = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block), + jcp.reduce_block); + } else { + int max_reduce_blocking = L2_capacity + / ((bcast_blocking + load_blocking) * jcp.reduce_block); + max_reduce_blocking = nstl::min(max_reduce_blocking, + (L1_capacity / (jcp.bcast_block)) / jcp.reduce_block); + + int num_jobs = div_up(jcp.load_dim, load_blocking) + * div_up(jcp.bcast_dim, bcast_blocking); + int threads_per_job = nstl::max(1, nthreads / num_jobs); + reduce_blocking = div_up(jcp.mb * jcp.reduce_dim, jcp.reduce_block); + reduce_blocking = div_up(reduce_blocking, threads_per_job); + + reduce_blocking = best_divider(reduce_blocking, + max_reduce_blocking - 2, max_reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + } + + reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block); + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + assert(reduce_blocking_max); + assert(load_blocking % jcp.load_block == 0); + assert(reduce_blocking % jcp.reduce_block == 0); + assert(load_blocking_max % jcp.load_block == 0); + assert(reduce_blocking_max % jcp.reduce_block == 0); + if (jcp.ver == ver_4fma) { + assert(jcp.reduce_loop_unroll % jcp.fma_step == 0); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + } + + assert(jcp.bcast_block % jcp.ur == 0); + assert(jcp.reduce_dim % jcp.reduce_block == 0); + + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +void jit_avx512_common_1x1_conv_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.prop_kind != backward_data && jcp.with_bias + && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); + + if (jcp.prop_kind == backward_weights) { + const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic; + scratchpad.book(key_conv_wei_reduction, + jcp.typesize_out * wei_size * (jcp.nthr_mb - 1)); + } + + if (jcp.transpose_src) { + const size_t tr_src_size = + (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is; + scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size); + scratchpad.book(key_conv_tr_src_bctx, + sizeof(simple_barrier::ctx_t) * jcp.nthr); + } +} + +void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp, + int nthreads) +{ + // initialize jcp reduction threading properties + jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1; + if (nthreads < jcp.ngroups) { + /* simplification... fortunately it doesn't hurt much */ + return; + } + const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + const int nb_load = div_up(jcp.load_dim, jcp.load_block); + const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + jcp.nthr_g = jcp.ngroups; + const int nthr = nthreads / jcp.nthr_g; + + auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + /* calculate per thread memory cost (read/write). high level + * optimizer tries to minimize memory consumption. few notes: (n1) + * unclear why, but that essentially helps first convolution... + * (n2) assuming the reduction over minibatch is always there: + * - instead of 8 it should be 5 here (write ~= 2 read): + * kernel: temporal workspace 1 write + * reduction: 1 read from workspace and 1 write to the diff_wei + * - but experiments showed 8 works better than 5 or 6... */ + int bcast_koeff = 1; + int load_koeff = 1; + int output_koeff = 12; + if (jcp.transpose_src) { + bcast_koeff = 5; + load_koeff = 1; + output_koeff = 8; + } + return 0 + + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) + * div_up(jcp.ngroups, jcp.nthr_g) + * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.reduce_block + / jcp.stride_h / jcp.stride_w /* (n1) */ + + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) + * div_up(jcp.ngroups, jcp.nthr_g) + * div_up(nb_load, nthr_oc_b) * jcp.oc_block * jcp.reduce_block + + (size_t)output_koeff /* (n2) */ + * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b) + * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block + * jcp.oc_block; + }; + + int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1; + auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + + /* step 1: find the best thread distribution with lowest memory cost */ + const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce); + for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, nb_load); + for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast); + auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + if (mem_cost <= best_mem_cost) { + best_mem_cost = mem_cost; + jcp.nthr_mb = nthr_mb; + jcp.nthr_oc_b = nthr_oc_b; + jcp.nthr_ic_b = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads) + jcp.nthr_mb = nstl::min(jcp.mb, nthreads); + + jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b; + assert(jcp.nthr <= nthreads); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp new file mode 100644 index 0000000000..d2ae017943 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp @@ -0,0 +1,108 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP +#define JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_common_1x1_conv_kernel : public jit_generator { + jit_avx512_common_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode(); + } + + ~jit_avx512_common_1x1_conv_kernel() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_1x1_conv_kernel) + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr, + int nthreads, bool reduce_src); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + + private: + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + + reg64_t reg_bcast_data = r8; + reg64_t reg_load_data = r10; + reg64_t reg_output_data = r9; + reg64_t aux_reg_bcast_data = r14; + reg64_t aux1_reg_bcast_data = rbx; + reg64_t aux_reg_load_data = r15; + reg64_t imm_addr64 = aux_reg_load_data; + reg64_t aux_reg_output_data = abi_not_param1; + reg64_t reg_load_loop_work = rsi; + reg64_t reg_reduce_loop_work = r11; + reg64_t bcast_loop_iter = rdx; + reg64_t reduce_loop_iter = abi_param1; + reg64_t reg_reduce_pos_flag = rax; + reg64_t reg_output_stride = r13; + reg64_t reg_bias_data = r12; + reg64_t reg_relu_ns = r13; + reg64_t reg_bcast_loop_work = aux1_reg_bcast_data; + + Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + int bcast_loop_work_offt = 0; + int stack_space_needed = 16; + + void bcast_loop(int load_loop_blk); + void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); + + void generate(); + static void balance(jit_1x1_conv_conf_t &jcp, int nthreads); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp new file mode 100644 index 0000000000..54d58c8a39 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp @@ -0,0 +1,816 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx512_common_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + + +namespace { +template +void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, + T nx, T &nx_start, T &nx_end, T nx_divider) +{ + const int grp_count = nstl::min(nx_divider, nthr); + const int grp_size_big = nthr / grp_count + 1; + const int grp_size_small = nthr / grp_count; + const int n_grp_big = nthr % grp_count; + const int threads_in_big_groups = n_grp_big * grp_size_big; + + const int ithr_bound_distance = ithr - threads_in_big_groups; + T grp, grp_ithr, grp_nthr; + if (ithr_bound_distance < 0) { // ithr in first groups + grp = ithr / grp_size_big; + grp_ithr = ithr % grp_size_big; + grp_nthr = grp_size_big; + } else { // ithr in last groups + grp = n_grp_big + ithr_bound_distance / grp_size_small; + grp_ithr = ithr_bound_distance % grp_size_small; + grp_nthr = grp_size_small; + } + + balance211(nx, grp_count, grp, nx_start, nx_end); + balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end); +} +} +/* convolution forward */ + +template +void jit_avx512_common_1x1_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + const auto &jcp = kernel_->jcp; + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.template get( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad); + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +template +void jit_avx512_common_1x1_convolution_fwd_t:: +execute_forward_thr(const int ithr, const int nthr, const src_data_t *src, + const wei_data_t *weights, const dst_data_t *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get(key_conv_rtus_space); + + const int ndims = src_d.ndims(); + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto p = jit_1x1_conv_call_s(); + + auto rp = rtus_driver_t::call_params_t(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count); + + auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, + int &oh, int &ow, int &ih, int &iw) + { + int osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + oh = os / jcp.ow; + ow = os % jcp.ow; + + ih = nstl::max(oh * stride_h - pad_t, 0); + iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + }; + + auto init_load = [&](int ocb, int &load_step) + { + load_step = step(jcp.nb_load_blocking, ocb_end - ocb, + jcp.nb_load_blocking_max); + p.load_dim = this_block_size(ocb * jcp.oc_block, + ocb_end * jcp.oc_block, load_step * jcp.oc_block); + }; + + auto init_reduce = [&](int icb) + { + const int nb_ic_blocking_step = + nstl::min(icb + nb_ic_blocking, nb_ic) - icb; + p.first_last_flag = 0 + | (icb == 0 ? FLAG_REDUCE_FIRST : 0) + | (icb + nb_ic_blocking_step >= nb_ic + ? FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size(icb * jcp.ic_block, + jcp.ic, nb_ic_blocking_step * jcp.ic_block); + rp.icb = p.reduce_dim / jcp.reduce_block; + }; + + auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow, + int ih, int iw) + { + + const int _ocb = g * nb_oc + ocb; + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + + p.output_data = &dst[dst_off]; + p.bias_data = &bias[_ocb * jcp.oc_block]; + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + const int _icb = g * nb_ic + icb; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + if (ocb == ocb_start) { + rp.src = src + data_blk_off(src_d, n, _icb, ih, iw); + rtus_driver_->ker_(&rp); + } + p.bcast_data = rp.ws; + } else + p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw); + + kernel_->jit_ker(&p); + }; + + if (jcp.loop_order == loop_rlb) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } + } else if (jcp.loop_order == loop_lbr) { + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + } + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_rbl) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } + } else if (jcp.loop_order == loop_blr) { + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + } + ocb += load_step; + } + iwork += bcast_step; + } + } else { + assert(!"unsupported loop order"); + } +} + + +template struct jit_avx512_common_1x1_convolution_fwd_t; +/* convolution backward wtr data */ + +template +void jit_avx512_common_1x1_convolution_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).template get( + key_conv_rtus_space); + + const int ndims = diff_src_d.ndims(); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int nb_ic = jcp.nb_load; + const int nb_oc = jcp.nb_reduce; + const int os_block = jcp.bcast_block; + const int nb_oc_blocking = jcp.nb_reduce_blocking; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + parallel(0, [&](const int ithr, const int nthr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + int bcast_start{0}, bcast_end{0}, icb_start{0}, icb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load, icb_start, icb_end, jcp.load_grp_count); + + bool reduce_outer = (jcp.loop_order == loop_rbl + || jcp.loop_order == loop_rlb); + int nboc_outer = reduce_outer ? nb_oc : 1; + int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1; + + int nboc_inner = reduce_outer ? 1 : nb_oc; + int ocb_inner_step = reduce_outer ? 1 : nb_oc_blocking; + + for (int ocb_outer = 0; ocb_outer < nboc_outer; + ocb_outer += ocb_outer_step) { + size_t cur_ocb_outer = + nstl::min(ocb_outer + ocb_outer_step, nboc_outer) - ocb_outer; + + int load_step = 0; + for (int icb = icb_start; icb < icb_end; icb += load_step) { + load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb, + jcp.nb_load_blocking_max); + + p.load_dim = this_block_size(icb * jcp.ic_block, + icb_end * jcp.ic_block, load_step * jcp.ic_block); + rp.icb = p.load_dim / jcp.ic_block; + + int bcast_step; + for (int iwork = bcast_start; iwork < bcast_end; + iwork += bcast_step) + { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + const int _icb = g * nb_ic + icb; + rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw); + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_; + p.output_data = rp.ws; + } else + p.output_data = rp.src; + + for (int ocb_inner = 0; ocb_inner < nboc_inner; + ocb_inner += ocb_inner_step) { + int cur_ocb_inner = + nstl::min(ocb_inner + ocb_inner_step, nboc_inner) - + ocb_inner; + + int ocb = reduce_outer ? ocb_outer : ocb_inner; + int nb_oc_blocking_step = reduce_outer + ? cur_ocb_outer : cur_ocb_inner; + const int _ocb = g * nb_oc + ocb; + size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, ow); + p.bcast_data = &diff_dst[diff_dst_off]; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; + + p.reduce_dim = this_block_size(ocb * jcp.oc_block, + jcp.oc, nb_oc_blocking_step * jcp.oc_block); + + kernel_->jit_ker(&p); + } + if (pd()->rtus_.reduce_src_) + rtus_driver_->ker_(&rp); + } + } + } + }); +} + +template struct jit_avx512_common_1x1_convolution_bwd_data_t; + +/* convolution backward wtr weights */ + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? (d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +jit_avx512_common_1x1_convolution_bwd_weights_t :: + jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr) + , trans_kernel_(nullptr), rtus_driver_(nullptr) +{ + kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr()); + acc_ker_ = new cpu_accumulator_1d_t(); + reducer_bias_ = new cpu_reducer_t(pd()->reducer_bia_conf_); + init_rtus_driver(this); + + const auto &jcp = kernel_->jcp; + + if (jcp.transpose_src) { + auto tp = jit_transpose4x16_src_t(); + tp.src_pf0_distance = 4; + tp.tr_src_pf0_distance = 0; + tp.src_pf1 = true; + tp.tr_src_pf1 = false; + trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp); + } +} + +void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + + const auto scratchpad = this->scratchpad(ctx); + + auto rtus_space = scratchpad.get(key_conv_rtus_space); + data_t *diff_bias = pd()->wants_padded_bias() + ? scratchpad.get(key_conv_padded_bias) : diff_bias_in; + auto wei_reduction = scratchpad.get(key_conv_wei_reduction); + + /* prepare src transposition barriers */ + auto tr_src = scratchpad.get(key_conv_tr_src); + auto tr_src_bctx = scratchpad.get( + key_conv_tr_src_bctx); + if (jcp.transpose_src) { + for (int i = 0; i < jcp.nthr; ++i) + simple_barrier::ctx_init(&tr_src_bctx[i]); + } + + const int ndims = src_d.ndims(); + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic; + + simple_barrier::ctx_t reduction_barrier; + simple_barrier::ctx_init(&reduction_barrier); + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int nb_ic = jcp.nb_bcast; + const int nb_ic_blocking = jcp.nb_bcast_blocking; + + const int nb_oc = jcp.nb_load; + const int nb_oc_blocking = jcp.nb_load_blocking; + + const int sp_nb = jcp.nb_reduce; + const int mb_sp_work = jcp.mb * sp_nb; + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + // TODO: use memory descriptor with the same fmt as src + // (or use a macro :)) + auto tr_src_off = [&](int img, int icb, int is) { + const size_t tr_chn_size = jcp.tr_is * jcp.ic_block; + const size_t tr_img_size = tr_chn_size * nb_ic * jcp.ngroups; + return img * tr_img_size + icb * tr_chn_size + is * jcp.ic_block; + }; + + auto uker_trans = [&](int ithr_mb, int img, int sp_b_start, int sp_size, + int g_start, int g_work, int ic_b_start, int ic_b_work, + int ithr, int nthr, int first_ic_b) + { + const int work_amount = g_work * ic_b_work; + + int start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + int g{ 0 }, ic_b{ 0 }; + nd_iterator_init(start, g, g_work, ic_b, ic_b_work); + g += g_start; + const int ic_b_tr = g * nb_ic + first_ic_b + ic_b; + ic_b += ic_b_start; + + const int _ic = g * nb_ic + ic_b; + + const int is = sp_b_start * jcp.reduce_block; + const int ih = is / jcp.iw; + const int iw = is % jcp.iw; + + const int src1_off = data_blk_off(src_d, img, _ic, ih, iw); + data_t *src1 = (data_t *)&src[src1_off]; + data_t *tr_src1 = &tr_src[tr_src_off(ithr_mb, ic_b_tr, is)]; + + assert(jcp.ic_block == 16); + const int src_stride = jcp.is * jcp.ic_block; + const int tr_src_stride = jcp.tr_is * jcp.ic_block; + + const int my_work = end - start; + for (int iwork = 0; iwork < my_work; iwork++) { + auto par_trans = jit_src_transpose_s(); + assert(sp_size % 4 == 0 || sp_size % 4 == jcp.is % 4); + par_trans.size = sp_size; + par_trans.src = src1; + par_trans.tr_src = tr_src1; + par_trans.src_prf = src1 + 64 * 16; + par_trans.tr_src_prf = tr_src1 + 80 * 16; + trans_kernel_->jit_ker(&par_trans); + + src1 += src_stride; + tr_src1 += tr_src_stride; + } + }; + + auto ker = [&](const int ithr, const int nthr) { + assert(nthr == jcp.nthr); + assert(IMPLICATION(!mkldnn_thr_syncable(), jcp.nthr_mb == 1)); + + const int ithr_ic_b = ithr % jcp.nthr_ic_b; + const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; + const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; + const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / + jcp.nthr_g; + + const int ithr_but_oc + = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_ic_b + ithr_ic_b; + + /* reduction dimension */ + int mb_sp_b_start{ 0 }, mb_sp_b_end{ 0 }; + if (jcp.transpose_src && jcp.nthr_mb < jcp.mb / 2) { + // it's preferable to parallelize by mb if possible + int img_start{ 0 }, img_end{ 0 }; + balance211(jcp.mb, jcp.nthr_mb, ithr_mb, img_start, img_end); + mb_sp_b_start = img_start * sp_nb; + mb_sp_b_end = img_end * sp_nb; + } + else { + balance211(mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start, + mb_sp_b_end); + } + + /* independent dimensions */ + int g_start{ 0 }, oc_b_start{ 0 }, ic_b_start{ 0 }; + int g_end{ 0 }, oc_b_end{ 0 }, ic_b_end{ 0 }; + + balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); + balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, + oc_b_end); + balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, + ic_b_end); + + const int g_work = g_end - g_start; + const int oc_b_work = oc_b_end - oc_b_start; + const int ic_b_work = ic_b_end - ic_b_start; + + data_t *diff_wei = ithr_mb == 0 + ? diff_weights : wei_reduction + (ithr_mb - 1) * wei_size; + + int sp_b_step = 0; + for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end; + mb_sp_b += sp_b_step) { + int img{ 0 }, sp_b{ 0 }; + nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb); + sp_b_step = step(jcp.nb_reduce_blocking, + nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b), + jcp.nb_reduce_blocking_max); + + for (int g = g_start; g < g_end; ++g) { + int load_step = 0; + int bcast_step = 0; + for (int ic_b = ic_b_start; ic_b < ic_b_end; + ic_b += bcast_step) { + bcast_step = step(nb_ic_blocking, ic_b_end - ic_b, + jcp.nb_bcast_blocking_max); + if (jcp.transpose_src) { + if (jcp.nthr_oc_b > 1) + simple_barrier::barrier( + &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); + const int sp_size + = nstl::min(sp_b_step * jcp.reduce_block, + jcp.is - sp_b * jcp.reduce_block); + uker_trans(ithr_mb, img, sp_b, sp_size, g, 1, ic_b, + bcast_step, ithr_oc_b, jcp.nthr_oc_b, ic_b_start); + if (jcp.nthr_oc_b > 1) + simple_barrier::barrier( + &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); + } + + for (int oc_b = oc_b_start; oc_b < oc_b_end; + oc_b += load_step) { + load_step = step(nb_oc_blocking, oc_b_end - oc_b, + jcp.nb_load_blocking_max); + const int _ic_b = g * nb_ic + ic_b; + const int _ic_b_tr = g * nb_ic + ic_b_start; + const int _oc_b = g * nb_oc + oc_b; + + data_t *store_to; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b); + store_to = diff_wei + off; + + const data_t *diff_src = jcp.transpose_src ? + &tr_src[tr_src_off(ithr_mb, _ic_b_tr, 0)] : + &src[src_d.blk_off(img, _ic_b)]; + + int sp_b_end = sp_b + sp_b_step; + const data_t *pdiff_dst + = &diff_dst[diff_dst_d.blk_off(img, _oc_b)]; + const data_t *local_src = diff_src; + + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + p.output_stride + = jcp.ic * jcp.oc_block * jcp.typesize_out; + + p.load_dim = load_step * jcp.oc_block; + + p.bcast_dim = bcast_step * jcp.ic_block; + rp.icb = bcast_step; + p.output_data = store_to; + + p.reduce_dim = sp_b_step * jcp.reduce_block; + rp.os = p.reduce_dim; + + p.first_last_flag = 0 + | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST : 0) + | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0); + + int sp = sp_b * jcp.reduce_block; + p.load_data = pdiff_dst + sp * jcp.oc_block; + + if (pd()->rtus_.reduce_src_) { + const int oh = sp / jcp.ow; + const int ow = sp % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + sp * jcp.ic_block; + + if (ndims == 3) + rp.src = local_src + iw + * src_d.blocking_desc().strides[2]; + else + rp.src = local_src + ih + * src_d.blocking_desc().strides[2] + + iw * src_d.blocking_desc().strides[3]; + rtus_driver_->ker_(&rp); + + p.bcast_data = rp.ws; + } else + p.bcast_data = local_src + sp * jcp.ic_block; + + kernel_->jit_ker(&p); + } + } + } + } + + /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */ + if (jcp.nthr_mb > 1) { + simple_barrier::barrier(&reduction_barrier, jcp.nthr); + const int work = g_work * oc_b_work * ic_b_work; + int start{ 0 }, end{ 0 }; + balance211(work, jcp.nthr_mb, ithr_mb, start, end); + if (start == end) + return; + + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + int w = start; + int sub_g_start{ 0 }, sub_oc_b_start{ 0 }, + sub_ic_b_start{ 0 }; + nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, + oc_b_work, sub_ic_b_start, ic_b_work); + while (w < end) { + const int g = g_start + sub_g_start; + const int oc_b = oc_b_start + sub_oc_b_start; + const int ic_b = ic_b_start + sub_ic_b_start; + + const int acc_size + = nstl::min(end - w, ic_b_work - sub_ic_b_start) + * jcp.ic_block * jcp.oc_block; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b); + data_t *d = diff_weights + off; + data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off; + + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, g_work, + sub_oc_b_start, oc_b_work, sub_ic_b_start, + ic_b_work); + } + } + } + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) + return; + + /* reduction dimension */ + int img_start{ 0 }, img_end{ 0 }; + + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{ 0 }, ocb_start{ 0 }; + nd_iterator_init( + b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_load + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, + reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 16; ++o) + d_bias[o] = 0.; + + for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 16; ++o) + d_bias[o] += d_dst[o]; + d_dst += 16; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + ker(ithr, jcp.nthr); + if (pd()->with_bias()) + ker_bias(ithr, jcp.nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding); + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp new file mode 100644 index 0000000000..2e9fda76d6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp @@ -0,0 +1,344 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP +#define CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx512_common_1x1_conv_kernel.hpp" +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_transpose_src_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_common_1x1_convolution_fwd_t : public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_fwd_t); + + status_t init() { + using namespace utils; + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, dst_type, dst_type, + data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr(), + mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = + new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx512_common_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + + private: + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src, const wei_data_t *weights, + const dst_data_t *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + rtus_driver_t *rtus_driver_; +}; + +using jit_avx512_common_1x1_convolution_fwd_f32_t + = jit_avx512_common_1x1_convolution_fwd_t; + +template +struct jit_avx512_common_1x1_convolution_bwd_data_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, data_type::undef, + diff_dst_type, data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *diff_src_d = diff_src_md(); + rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), + *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + // TODO (Roma): structs conf header cleanup + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, + *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx512_common_1x1_convolution_bwd_data_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_src_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + + private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + rtus_driver_t *rtus_driver_; +}; + +using jit_avx512_common_1x1_convolution_bwd_data_f32_t + = jit_avx512_common_1x1_convolution_bwd_data_t; + +struct jit_avx512_common_1x1_convolution_bwd_weights_t : public cpu_primitive_t +{ + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, diff_dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), + *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + // TODO (Roma): structs conf header cleanup + jit_1x1_conv_conf_t jcp_; + cpu_reducer_t::conf_t reducer_bia_conf_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + + private: + void init_balancers() { + const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, + jcp_.mb, max_buffer_size)); + } + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd); + + ~jit_avx512_common_1x1_convolution_bwd_weights_t() { + delete kernel_; + delete acc_ker_; + delete reducer_bias_; + delete rtus_driver_; + delete trans_kernel_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + + private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + cpu_accumulator_1d_t *acc_ker_; + cpu_reducer_t *reducer_bias_; + jit_transpose4x16_src *trans_kernel_; + rtus_driver_t *rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp new file mode 100644 index 0000000000..235fb02fef --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp @@ -0,0 +1,4539 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" + +#include "jit_avx512_common_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) +#define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { + +constexpr auto small_spatial = 14; +unsigned int L1_cache_size = get_cache_size(1, true); + +inline void pick_loop_order(jit_conv_conf_t &jcp) { + using namespace prop_kind; + assert(one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)); + auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow; + auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh; + + // ow-threading is currently implemented for forward only + // TODO: single code for fwd and bwd after ow-thr for bwd + // meaningless switch was removed + if (jcp.prop_kind == backward_data) { + jcp.loop_order = (w <= small_spatial && h <= small_spatial) + ? loop_cgn : loop_gnc; + } else { + jcp.loop_order = (w <= small_spatial && h <= small_spatial) + ? loop_cwgn : loop_gncw; + } +} + +inline bool is_1stconv(const jit_conv_conf_t &jcp) { + if (mayiuse(avx512_core)) + return (jcp.ic < 16 && jcp.ngroups == 1); + else + return one_of(jcp.ic, 1, 3); +} + +inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) { + return (jcp.nb_ow > 1); +} + +inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) { + return (jcp.ver == ver_4fma && is_ow_threading_on(jcp)); +} + +} + +template +void _jit_avx512_common_conv_fwd_kernel::prepare_output(int ur_w) +{ + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vpxord(vmm, vmm, vmm); + if (!is_owb_prefetching(jcp)) { + size_t aux_output_offset = get_output_offset(j, k); + mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf, + aux_output_offset, reg_out_long_offt)); + } + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::store_output(int ur_w) +{ + Label no_update_label, store_label, eltwise_label; + + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + if (jcp.with_bias) { + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + } + + if (!jcp.with_sum) { + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + } + + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + size_t aux_output_offset = get_output_offset(j, k); + vaddps(vmm, + make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt)); + } + + if (!jcp.with_sum) { + jmp(eltwise_label, T_NEAR); + } else { + cmp(reg_channel, 0); + jne(eltwise_label, T_NEAR); + } + + L(no_update_label); + if (jcp.with_bias) { + for (int k = 0; k < jcp.nb_oc_blocking; k++) { + int bias_offset = jcp.typesize_out * k * jcp.oc_block; + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vaddps(vmm, EVEX_compress_addr(reg_bias, bias_offset)); + } + mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64)); + } + } + + L(eltwise_label); + if (jcp.with_eltwise) { + cmp(reg_channel, jcp.nb_ic - 1); + jl(store_label, T_NEAR); + + if (ur_w == jcp.ur_w) { + eltwise_injector_->compute_vector_range(0, + jcp.nb_oc_blocking * jcp.ur_w); + } else { + for (int k = 0; k < jcp.nb_oc_blocking; k++) + eltwise_injector_->compute_vector_range(k * jcp.ur_w, + k * jcp.ur_w + ur_w); + } + } + + L(store_label); + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + size_t aux_output_offset = (size_t)typesize * + ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block; + vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset, + reg_out_long_offt), vmm); + if (!is_owb_prefetching(jcp)) + mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf, + aux_output_offset, reg_out_long_offt)); + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, + int pad_l, int pad_r) +{ +} + +template<> +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, + int pad_l, int pad_r) +{ + assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0); + + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + Label kh_label, kd_label; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_inp_prf, reg_inp_prf); + } + + size_t max_input_offset = (size_t)jcp.typesize_in + * ((size_t)(kw + ur_w * stride_w - pad_l) + + (size_t)ic_block * iw * ih * jcp.id); + assert(reg_inp_prf == reg_long_offt); + if (max_input_offset > INT_MAX) push(reg_inp_prf); + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + + L(kd_label); + } + mov(reg_kj, reg_kh); + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + L(kh_label); + for (int ki = 0; ki < kw; ki += 4) { + for (int ic = 0; ic < ic_block; ic++) { + for (int i = 0; i < 4; i++) { + int aux_ker_offset + = jcp.typesize_in + * ((ki + i) * oc_block + + ic * kw * jcp.kh * jcp.kd * oc_block); + if (ki + i < kw) + vmovups(vmm_ker(i), + EVEX_compress_addr(aux_reg_ker, aux_ker_offset)); + else + vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i)); + } + + int j_start = get_ow_start(ki, pad_l); + int j_end = get_ow_end(ur_w, ki, pad_r); + + for (int j = j_start, prf_count=0; j < j_end; j++) { + size_t aux_input_offset = (size_t)jcp.typesize_in + * ((size_t)(ki + j * stride_w + - pad_l) + (size_t)ic * iw * ih * jcp.id); + v4fmaddps(vmm_out(j, 0), vmm_ker(0), + EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset, + reg_long_offt)); + if (ki + prf_count < kw && prf_count < 4 + && ((ki < 2 && j % 4) || j % 2)) { + int aux_ker_offset = jcp.typesize_in + * ((ki + prf_count) * oc_block + + ic * kw * jcp.kh * jcp.kd * oc_block + kw * oc_block); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_ker_offset)); + prf_count++; + } + if (ki == 0 + && j % (64 / (stride_w * jcp.typesize_in)) == 0) { + mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf, + aux_input_offset, reg_long_offt)); + } + if (ki == 1 + && j % (64 / (stride_w * jcp.typesize_in)) == 0) { + mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset+jcp.typesize_in * iw, reg_long_offt)); + } + } + } + } + add(aux_reg_ker, jcp.typesize_in * kw * oc_block); + add(aux_reg_inp, jcp.typesize_in * iw); + add(aux_reg_inp_prf, jcp.typesize_in * iw); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block); + add(aux_reg_inp_d_prf, typesize * jcp.ih * jcp.iw); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } + + if (max_input_offset > INT_MAX) pop(reg_inp_prf); +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, + int pad_l, int pad_r) +{ +} + +template<> +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, + int pad_l, int pad_r) +{ + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + Label kh_label, last_iter_label, loop_end_label, kd_label; + int ker_load_number = 4; + int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block; + int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block; + + bool check_last_kh = (jcp.kh > 3); + bool pref_current_inp = (jcp.iw < 14 || jcp.iw > 28); + + int oi_ipref_t0 = get_ow_start(0, pad_l); + int ow_end_ipref = get_ow_end(ur_w, 0, pad_r); + + assert(jcp.oc % jcp.nb_oc_blocking == 0); + + auto kernel_offset = [=](int ocb, int ic, int ki) { + int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int ic_offset = ic * jcp.oc_block; + return typesize * (blk_offset + ic_offset); + }; + auto kernel_loads = [=](int ki, int ic, int kk) { + for (int ii = 0; ii < ker_load_number; ii++) { + int aux_kernel_offset = kernel_offset(kk, ic + ii, ki); + vmovups(vmm_ker(ii), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + }; + auto prefetch_inp_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) { + if (cnt1 >= ker_load_number && cnt0 >= ker_load_number + && ki >= ki_start && oi_ipref_t0 < ow_end_ipref) { + int aux_inp_offset + = typesize + * ((oi_ipref_t0 * stride_w - pad_l) * ic_block + + (jcp.dilate_h + 1) * jcp.iw * ic_block); + prefetcht0(EVEX_compress_addr(aux_reg_inp, + aux_inp_offset)); + oi_ipref_t0++; + } + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_ker_prf, reg_ker_prf); + mov(aux_reg_inp_prf, reg_inp_prf); + } + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + align(16); + L(kh_label); + int kw = jcp.kw; + if (check_last_kh) { + for (int ki = 0; ki < kw; ki++) + for (int ic = 0; ic < ic_block; ic += 4) + for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) { + bool last_kernel_loads = (kk == jcp.nb_oc_blocking - 1 + && ki == kw - 1 && (ic + 4) == ic_block); + + if (last_kernel_loads) { + cmp(reg_kj, 1); + je(last_iter_label, T_NEAR); + } + + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0, + prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + + if (oi % 2) { + if (prf_count_t0 < 4) { + int aux_kernel_prf; + if (last_kernel_loads) + aux_kernel_prf= kernel_offset(0, + prf_count_t0 + ic + 4 + - ic_block, 0) + typesize * kw + * oc_block * ic_block; + else + aux_kernel_prf = kernel_offset(kk, ic + 4 + + prf_count_t0, ki); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_kernel_prf)); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, ic + + prf_count_t1, ki))); + prf_count_t1++; + } + } else + prefetch_inp_next_kh(ki, 2, prf_count_t0, + prf_count_t1); + } + + if (last_kernel_loads) { + jmp(loop_end_label, T_NEAR); + + L(last_iter_label); + + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0, + prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + if (oi % 2) { + if (prf_count_t0 < 4) { + mic_prefetcht0(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(0, + prf_count_t0, 0))); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + ic + prf_count_t1, ki))); + prf_count_t1++; + } + } + } + L(loop_end_label); + } + } + } else { + for (int ki = 0; ki < kw; ki++) + for (int ic = 0; ic < ic_block; ic += 4) + for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) { + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), + prf_count_t1 = 0, prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + + if (!is_owb_prefetching(jcp)) { + if ((oi % 2) && (prf_count_t1 < 4)) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + ic + prf_count_t1, ki))); + prf_count_t1++; + } + } else { + if (!(ki == 0 && ic == 0) + && !(ki == kw-1 && ic == 0) && + (oi % 2) && (prf_count_t1 < 4) + ) { + mic_prefetcht0(EVEX_compress_addr( + aux_reg_ker, kernel_offset(kk, + ic + 4 + prf_count_t0, ki))); + prf_count_t0++; + } + } + if (!is_owb_prefetching(jcp)) { + if (pref_current_inp) { + if (ki == 0 && ic == 0 && kk == 0) + mic_prefetcht0(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + shift_input_ptr)); + } else { + if (ki == 1 && ic == 0 && kk == 0) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_inp_prf, aux_input_offset)); + } + } else { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int inp_shift + = jcp.typesize_in * ur_w * stride_w * inp_mult; + bool kk_pref_slot = kk ? oi % 2 : !(oi % 2); + if (ki == 0 && ic == 0 && kk_pref_slot) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + inp_shift)); + + if (ki == kw - 1 && ic == 0 && kk_pref_slot) + mic_prefetcht0(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + inp_shift)); + } + } + } + } + + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + add(aux_reg_ker_prf, shift_kernel_ptr); + add(aux_reg_inp_prf, shift_input_ptr); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + add(aux_reg_inp_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w, + int pad_l, int pad_r) +{ + bool prf_ker = true; + bool prf_inp = true; + int ih = jcp.ih; + int stride_w = jcp.stride_w; + int id = jcp.id; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_oc_block = jcp.nb_oc_blocking; + Label kh_label, kd_label; + + int ker_pipeline_depth = 4; + assert(ker_reg_base_idx + ker_pipeline_depth <= 32); + assert(oc_block >= ker_pipeline_depth); + + int num_ker_loads = ic_block * nb_oc_block * kw; + int num_ker_prfs = prf_ker ? num_ker_loads : 0; + int num_inp_prfs = prf_inp ? + ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) : + 0; + if (jcp.is_1stconv && prf_inp) { + num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block; + } + int num_prfs = num_ker_prfs + num_inp_prfs; + int num_fmas = num_ker_loads * ur_w; + int prf_inst_spacing + = (prf_ker || prf_inp) ? nstl::max(1, num_fmas / num_prfs) : 1; + int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2; + int inp_mul = !jcp.is_1stconv ? ic_block : 1; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_inp_prf, reg_inp_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id; + assert(reg_inp_prf == reg_long_offt); + if (max_input_offset > INT_MAX) push(reg_inp_prf); + + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + align(16); + L(kh_label); + { + int step = 0; + int ker_prfs = 0; + for (int ki = 0; ki < kw; ki++) { + for (int ic = 0; ic < ic_block; ic++) { + int aux_kernel_offset = 0; + if (step == 0) { + for (int i = 0; i < ker_pipeline_depth; i++) { + aux_kernel_offset = get_kernel_offset(ki, ic, 0, i); + vmovups(vmm_ker(i), EVEX_compress_addr( + aux_reg_ker, aux_kernel_offset)); + } + } else if (step < num_ker_loads - ker_pipeline_depth + 1) { + int load_offset = ker_pipeline_depth - 1; + int ker_load_reg_idx + = (step + load_offset) % ker_pipeline_depth; + aux_kernel_offset + = get_kernel_offset(ki, ic, 0, load_offset); + vmovups(vmm_ker(ker_load_reg_idx), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + + bool ker_prf_inserted = false; + Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth); + int j_start = get_ow_start(ki, pad_l); + int j_end = get_ow_end(ur_w, ki, pad_r); + for (int j = j_start; j < j_end; j++) { + size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l); + auto addr = EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt, true); + vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr); + int fma_idx = step * ur_w + j; + int prf_slot_idx = fma_idx / prf_inst_spacing; + if (fma_idx % prf_inst_spacing == prf_inst_trigger) { + if (prf_ker && !ker_prf_inserted + && ker_prfs < num_ker_prfs) { + int ker_prf_offset + = jcp.typesize_in * ker_prfs * jcp.oc_block; + mic_prefetcht2(EVEX_compress_addr( + aux_reg_ker_prf, ker_prf_offset)); + ker_prf_inserted = true; + ker_prfs++; + } else if (prf_inp) { + int inp_prf_idx = prf_slot_idx - ker_prfs; + if (inp_prf_idx < num_inp_prfs) { + size_t inp_prf_stride = nstl::max(kw, stride_w); + size_t inp_prf_offset; + if (!jcp.is_1stconv) { + inp_prf_offset + = ic_block * jcp.typesize_in + * ((inp_prf_idx / kw) + * inp_prf_stride + + (inp_prf_idx % kw)); + } else { + size_t ic_prf_stride = + (size_t)jcp.typesize_in * iw * ih * id; + size_t iw_prf_stride + = jcp.typesize_in * jcp.simd_w; + inp_prf_offset = ((inp_prf_idx / ic_block) + * iw_prf_stride + + (inp_prf_idx % ic_block) + * ic_prf_stride); + } + mic_prefetcht0(EVEX_compress_addr_safe( + aux_reg_inp_prf, inp_prf_offset, + reg_long_offt)); + } + } + } + } + step++; + } + } + add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block); + if (prf_ker) + add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block); + add(aux_reg_inp, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + if (prf_inp) + add(aux_reg_inp_prf, + jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + add(aux_reg_inp_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } + if (max_input_offset > INT_MAX) pop(reg_inp_prf); +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w, + int pad_l, int pad_r) +{ + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_oc_block = jcp.nb_oc_blocking; + Label kh_label, kd_label; + int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block + * jcp.ic_block; + int inp_mul = !jcp.is_1stconv ? ic_block : 1; + int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * inp_mul; + + + auto input_offset = [=](int oi, int ic, int ki) { + return (size_t)jcp.typesize_in + * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) + * inp_mul + (size_t)ic + * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id)); + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + } + + if (jcp.ndims == 5) { + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + } + + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_ow_start(ki, pad_l); + int jj_end = get_ow_end(ur_w, ki, pad_r); + for (int ic = 0; ic < ic_block; ic++) { + if (jcp.kernel_kind == expl_bcast) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t aux_input_offset = input_offset(jj, ic, ki); + vbroadcastss(vmm_inp(jj, nb_oc_block), + EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt)); + } + } + for (int ii = 0; ii < nb_oc_block; ii++) { + int aux_kernel_offset = jcp.typesize_in + * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block + * oc_block + ki * ic_block * oc_block + ic * oc_block); + if (jj_end - jj_start > 0) + vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker, + aux_kernel_offset)); + for (int jj = jj_start; jj < jj_end; jj++) + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(vmm_out(jj, ii), + vmm_inp(jj, nb_oc_block), vmm_wei); + else { + size_t aux_input_offset = input_offset(jj, ic, ki); + vfmadd231ps(vmm_out(jj, ii), vmm_wei, + EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt, true)); + } + } + } + } + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop(int ur_w, + int pad_l, int pad_r) +{ + if (jcp.ndims == 5) push(reg_oi); + + prepare_output(ur_w); + + Label skip_compute_loop; + if (jcp.ndims == 5) { + if ((jcp.dilate_d >= jcp.id) + || (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) { + mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + } + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + + if (jcp.ver == ver_4fma) + if(jcp.is_1stconv) + compute_loop_4fma_1st(ur_w, pad_l, pad_r); + else + compute_loop_4fma(ur_w, pad_l, pad_r); + else if (jcp.ver == ver_fma) + if ((jcp.is_1stconv && jcp.kernel_kind != expl_bcast) + || mayiuse(avx512_mic)) + compute_loop_fma(ur_w, pad_l, pad_r); + else + if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1) + compute_loop_fma(ur_w, pad_l, pad_r); + else + compute_loop_fma_core(ur_w, pad_l, pad_r); + else + assert(!"unknown convolution version"); + + L(skip_compute_loop); + store_output(ur_w); + if (jcp.ndims == 5) pop(reg_oi); +} + +template +void _jit_avx512_common_conv_fwd_kernel::generate() +{ + int iw = jcp.iw; + int ow = jcp.ow; + int ow_block = jcp.ow_block; + int nb_ow = jcp.nb_ow; + int kw = jcp.kw; + int l_pad = jcp.l_pad; + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult; + int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult; + int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult; + int out_shift = jcp.typesize_out * ur_w * jcp.oc_block; + + preamble(); + mov(reg_inp, ptr[param1 + GET_OFF(src)]); + mov(reg_out, ptr[param1 + GET_OFF(dst)]); + mov(reg_ker, ptr[param1 + GET_OFF(filt)]); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); + + int r_pad = nstl::max( + 0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1)); + int n_oi = ow / ur_w; + int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + + if (!is_ow_threading_on(jcp)) { + // ow is being processed as a whole - with left and right paddings + if (r_pad1 > 0) n_oi--; + + if (ow == ur_w) { + mov(reg_inp_prf, ptr[param1 + GET_OFF(src_prf)]); + mov(reg_out_prf, ptr[param1 + GET_OFF(dst_prf)]); + compute_loop(ur_w, l_pad, r_pad); + } else { + mov(reg_inp_prf, reg_inp); + mov(reg_out_prf, reg_out); + if (n_oi == 0) { + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, r_pad1); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w_tail, 0, r_pad); + } + } else { + xor_(reg_oi, reg_oi); + if (l_pad > 0) { + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + inc(reg_oi); + } + if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) { + Label ow_loop_label; + L(ow_loop_label); + { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_pad1 > 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, r_pad1); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + } + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w_tail, 0, r_pad); + } + } + } + } else { + // ow block is only processed. + // Number of block is passed as parameter owb, + // and padding processing depends on this number. + + Label end_label, last_oi_label, middle_ow_blocks_label, tail_label; + Label oi_loop_label, oi_loop_start_label, oi_loop_end_label; + + assert(ow_block % ur_w == 0); + int n_oi_not_last_ow_block = ow_block / ur_w; + // to simplify code (and general regs usage), + // size of ow block must be >= 2 * ur_w + assert(n_oi_not_last_ow_block > 1); + int n_oi_next_last_ow_block = n_oi_not_last_ow_block; + int n_oi_first_ow_block = n_oi_not_last_ow_block; + + int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w; + + // prepare right padding + bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0; + bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2; + bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0; + + if (last_ow_block_padded) n_oi_last_ow_block--; + else if (first_ow_block_padded) n_oi_first_ow_block--; + else if (next_last_ow_block_padded) n_oi_next_last_ow_block--; + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // is that the first ow-block ? + jg(middle_ow_blocks_label, T_NEAR); + + // the first ow block, compute left padding + + mov(reg_oi, n_oi_first_ow_block); + mov(reg_inp_prf, reg_inp); + mov(reg_out_prf, reg_out); + + if (l_pad > 0) { + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + dec(reg_oi); + } + jmp(oi_loop_label, T_NEAR); + + // middle or last ow block entry + + L(middle_ow_blocks_label); + + if (l_pad > 0) { + // just to consider left padding, not compute + add(reg_inp, inp_shift_pad_second_block); + add(reg_inp_prf, inp_shift_pad_second_block); + } + + // set number of iteration for oi-loop + cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ? + mov(reg_oi, n_oi_last_ow_block); + je(oi_loop_label, T_NEAR); + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + mov(reg_oi, n_oi_next_last_ow_block); + je(oi_loop_label, T_NEAR); + mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks + + // oi loop w/o padding + L(oi_loop_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + L(oi_loop_start_label); + cmp(reg_oi, 0); + jle(oi_loop_end_label, T_NEAR); + + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + dec(reg_oi); + jmp(oi_loop_start_label, T_NEAR); + L(oi_loop_end_label); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + + cmp(reg_owb, 0); // first ow-block ? + if (first_ow_block_padded) { + je(last_oi_label, T_NEAR); + } else { + je(end_label, T_NEAR); + } + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + jl(end_label, T_NEAR); + if (next_last_ow_block_padded) { + je(last_oi_label, T_NEAR); + } else { + je(end_label, T_NEAR); + } + // that is last block + if (!last_ow_block_padded) { + jmp(tail_label, T_NEAR); + } + + // last oi block with right padding + L(last_oi_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, r_pad1); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? + jl(end_label, T_NEAR); + + L(tail_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w_tail, 0, r_pad); + } + L(end_label); + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_common_conv_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx512_common_conv_fwd_kernel::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, int nthreads) +{ + using namespace prop_kind; + + if (!mayiuse(avx512_common)) + return status::unimplemented; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const int regs = 28; + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + + jcp = zero(); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2]; + jcp.ow = dst_d.dims()[ndims-1]; + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; + jcp.kw = weights_d.dims()[with_groups + ndims-1]; + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + jcp.back_pad = (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); + + jcp.is_1stconv = is_1stconv(jcp); + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + + const int full_simd_w = cpu_isa_traits::vlen / sizeof(float); + jcp.simd_w = full_simd_w; + bool ok_to_try_xmm = true + && mayiuse(avx512_core) + && src_d.data_type() == data_type::f32 + && !jcp.is_1stconv + && !ok_to_pad_channels + && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0) + && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0); + if (ok_to_try_xmm) + jcp.simd_w = 4; + + jcp.oc_block = jcp.simd_w; + jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w; + jcp.aligned_threads = 0; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0; + if (!args_ok) + return status::unimplemented; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (dst_d.data_type() == data_type::s32) return status::unimplemented; + } + + auto src_tag = jcp.is_1stconv + ? pick(ndims - 3, ncw, nchw, ncdhw) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) + : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c)); + auto dst_tag = (jcp.simd_w == 4) + ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) + : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o) + : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o) + : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o)); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + } + if (jcp.src_tag != src_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dst_tag)); + jcp.dst_tag = dst_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dst_tag); + } + if (jcp.dst_tag != dst_tag) + return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, x)); + } + + if (mayiuse(avx512_common) && + src_d.data_type() == data_type::f32 + && weights_d.data_type() == data_type::f32 + && dst_d.data_type() == data_type::f32) { + jcp.ver = ver_fma; + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + + if (jcp.is_1stconv) { + // TODO: fix & remove constraints below + bool not_for_4fma + = IMPLICATION(everyone_is(0, jcp.l_pad, jcp.t_pad), + nstl::max(jcp.kw, jcp.kh) < 7); + bool is_dilated + = !everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w); + if (one_of(true, not_for_4fma, is_dilated)) + jcp.ver = ver_fma; + if (jcp.ver == ver_4fma) { + wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o) + : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o) + : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o)); + } else { + wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o) + : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o) + : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o)); + } + } + } else { + return status::unimplemented; + } + + if (weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (jcp.is_1stconv) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } else { + // avx512_core guard - just to avoid possible regression for other archs + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } else { + for (int ur_w = regs; ur_w > 0; --ur_w) { + if (jcp.ow % ur_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + } + if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } + } + // TODO (Tanya): currently applied to Segnet convolutions only. + // Need to try for other topologies + if (jcp.ow > 150 && jcp.ur_w < regs/2) + jcp.ur_w = regs; + + int n_oi = (jcp.ow / jcp.ur_w); + int r_pad = (jcp.ur_w * n_oi - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); + if (jcp.l_pad > 0 && r_pad > 0) + n_oi--; + + bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0 + && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)); + if (large_code_size) { + const int max_code_size = 24 * 1024; + const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw; + int mult = 1; + if (jcp.l_pad > 0) mult += 1; + if (r_pad > 0) mult += 1; + for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) { + if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) { + jcp.ur_w = ur_w; + break; + } + } + } + + /* Grouped channel offset to support 'non-blocked data' format for + * convolution sizes with '(input_channel / ngroups) < simd' */ + jcp.nonblk_group_off + = (jcp.ngroups > 1 && one_of(jcp.src_tag, ncw, nchw, ncdhw)) ? + jcp.ic : + 1; + + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + + auto is_ow_threading_applicable = [=]() { + return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4) + && IMPLICATION(mayiuse(avx512_mic), + jcp.ver == ver_4fma + && IMPLICATION(jcp.mb != 1, + jcp.ih == 1 && jcp.kh == 1))); + }; + + if (jcp.ver == ver_4fma && !jcp.is_1stconv) { + if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8 + && jcp.oh <= 8 && jcp.ow == jcp.oh) + || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) { + if (jcp.nb_oc % 2 == 0) { + jcp.nb_oc_blocking = 2; + jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); + } + } else { + for (int i = jcp.nb_oc; i > 0; i--) + if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) { + jcp.nb_oc_blocking = i; + break; + } + } + if (jcp.ver == ver_4fma && is_ow_threading_applicable()) { + if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow + && jcp.ow != 2 * jcp.ur_w) { + jcp.nb_oc_blocking = 2; + jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); + } + } + } + + jcp.ow_block = jcp.ow; + + auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) { + int nb_ow = div_up(jcp.ow, ow_block); + int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking); + int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow; + float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block); + float thr_eff = disbalance * (float)work_amount + / rnd_up(work_amount, nthreads); + return thr_eff; + }; + + auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) { + int res_ow_block = jcp.ow; + eff = get_thr_eff(nb_oc_blocking, res_ow_block); + if (!is_ow_threading_applicable()) + return res_ow_block; + + int L2_part = (get_cache_size(2) * 7 / 8) / typesize; + if (jcp.ver == ver_4fma) + L2_part /= 2; + int size_src_chunk = jcp.ic_block * ur_w * jcp.kh; + int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w; + int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block + * jcp.kw * jcp.kh; + int nurw_cache = (L2_part - 2 * size_wei_chunk) + / (2 * size_dst_chunk + 2 * size_src_chunk); + // current design of generate() requires ow_block >= 2 * ur_w + int ow_block_cache = ur_w * nstl::max(2, nurw_cache); + + int ow_block_thr = ow_block_cache; + eff = get_thr_eff(nb_oc_blocking, ow_block_thr); + + int max_nb_ow = div_up(jcp.ow, 2 * ur_w); + int start_nb_ow = div_up(jcp.ow, ow_block_thr); + for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) { + int ow_block + = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow); + float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f; + if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold) + break; + if (div_up(jcp.ow, ow_block) != nb_ow) + continue; + float thr_eff = get_thr_eff(nb_oc_blocking, ow_block); + float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f; + if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) { + ow_block_thr = ow_block; + eff = thr_eff; + } + eff_threshold = (jcp.ver == ver_4fma) ? 0.9f : 0.98f; + if (eff > eff_threshold) + break; + } + res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr)); + eff = get_thr_eff(nb_oc_blocking, res_ow_block); + return res_ow_block; + }; + + + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + int try_nb_oc_blocking = 2; + unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w) + * jcp.ic_block * jcp.kh * jcp.kd; + unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block + * try_nb_oc_blocking; + unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block + * jcp.oc_block * try_nb_oc_blocking * jcp.kd; + unsigned int ker_total_size = ker_inp_size + ker_out_size + + ker_wei_size; + + bool embd_bcast_condition = true + && (jcp.kw == 3 && jcp.ow <= 28 && ker_total_size < L1_cache_size) + && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192) + && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512); + + if (jcp.mb == 1) { + unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h) + * div_up(jcp.iw, jcp.stride_w) * jcp.ic; + unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw; + + // Estimate whether we need to limit the number of threads + // and calculate this number. Includes some heuristic. + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh; + int job_size_min = work_amount / nthreads; + int job_size_max = div_up(work_amount, nthreads); + int ch_max = rnd_up(jcp.oh, job_size_max); + int ch_min = (job_size_min == 0) + ? jcp.oh + : rnd_up(jcp.oh, job_size_min); + bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2 + && (jcp.oh != 8 || ch_max / jcp.oh > 1); + bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2 + && (jcp.oh != 8 || ch_min / jcp.oh > 1); + bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1) + || nthreads > oc_chunks; + if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1 + && wei_size / inp_size > 24 + && (not_aligned_max || not_aligned_min) + && eligible_case) { + // Try to find nthreads > mkldnn_get_max_threads() / 2 such + // that oc_chunks is a multiple of nthreads, or nthreads is a + // multiple of oc_chunks. Otherwise, keep default value. + // TODO: implement a task-based alternative without throttling. + jcp.aligned_threads = nthreads; + for (int i = nthreads; i > nthreads / 2; i--) { + if (oc_chunks % i == 0 || i % oc_chunks == 0) { + jcp.aligned_threads = i; + break; + } + } + } + } + + if (jcp.kw > 3 + || (jcp.stride_w == 1 && jcp.stride_h == 1 + && embd_bcast_condition) + || ((jcp.stride_w != 1 || jcp.stride_h != 1) + && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10) + && embd_bcast_condition))) + || (jcp.mb == 1 + && (jcp.ur_w >= jcp.ow || jcp.is_1stconv + || (jcp.ow <= 147 && jcp.oc <= 96)))) { + jcp.kernel_kind = embd_bcast; + jcp.ur_w = nstl::min(jcp.ow, regs); + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3 + && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0 + && IMPLICATION(jcp.is_1stconv, jcp.mb == 1) + && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) { + jcp.nb_oc_blocking = try_nb_oc_blocking; + jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); + } + } else { + jcp.kernel_kind = expl_bcast; + jcp.nb_ic_blocking = 1; + if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) { + float best_thr_eff = 0.f; + int best_nb_oc_blocking = 1; + for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) { + if (jcp.nb_oc % i == 0) { + float thr_eff; + int ur_w = nstl::min(jcp.ow, 31 / (i + 1)); + get_ow_block(i, ur_w, thr_eff); + if (thr_eff > 1.05f * best_thr_eff) { + best_nb_oc_blocking = i; + best_thr_eff = thr_eff; + } + } + } + jcp.nb_oc_blocking = best_nb_oc_blocking; + jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); + } + } + } + + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + args_ok = true + && jcp.l_pad <= jcp.ur_w + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) + return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + if (r_pad_no_tail > jcp.ur_w) + return status::unimplemented; + + pick_loop_order(jcp); + + jcp.nb_ic_L2 = jcp.nb_ic; + + float thr_eff; + jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff); + jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); + + const int L2_size = get_cache_size(2, true) / sizeof(float); + // Source and output data needs to fit in L2, + // leaving some space for weights and prefetching. + int h_L2 = int(((0.6f * L2_size) / jcp.simd_w + - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw) + / (jcp.stride_h * jcp.iw + jcp.ow)); + jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2)); + + if (jcp.ver == ver_4fma) { + if (!is_ow_threading_on(jcp)) { + for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic; + divf++) { + size_t l2_src + = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id; + size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking + * jcp.oh * jcp.od; + size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block + * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd; + if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { + if (jcp.kh == 3 && jcp.oh == 7) { + jcp.nb_ic_L2 = 1; + break; + } + temp_nb = (jcp.nb_ic_L2 % divf == 0 ? jcp.nb_ic_L2 / divf + : jcp.nb_ic_L2); + } else { + jcp.nb_ic_L2 = temp_nb; + break; + } + } + } else if (jcp.ic > 64) { + jcp.nb_ic_L2 = 2; /* according to performance data*/ + } + } + + return status::success; +} + +void jit_avx512_common_conv_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w) +{ + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + vpxord(zmm, zmm, zmm); + size_t aux_src_offset + = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) + * jcp.ic_block; + mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, + reg_long_offt)); + } + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w) +{ + Label no_update_label; + + mov(reg_channel, ptr[param + GET_OFF(channel)]); + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + size_t aux_src_offset = (size_t)typesize + * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; + vaddps(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset, + reg_long_offt)); + } + } + + L(no_update_label); + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + size_t aux_src_offset = (size_t)typesize + * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; + vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset, + reg_long_offt), zmm); + mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, + reg_long_offt)); + } + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma( + int ur_w, int l_overflow, int r_overflow) +{ + int ow = jcp.ow; + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + Label kh_label, last_iter_label, loop_end_label, kd_label; + int ker_load_number = 4; + int shift_ker_ptr = typesize * kw * oc_block * ic_block; + int shift_dst_ptr = typesize * ow * oc_block; + int ii_dpref_t0 = get_iw_start(0, l_overflow); + int iw_end_ipref = get_iw_end(ur_w, 0, r_overflow); + + bool check_last_kh = (jcp.kh > 3); + auto kernel_offset = [=](int icb, int oc, int ki) { + int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int oc_offset = oc * jcp.oc_block; + return typesize * (blk_offset + oc_offset); + }; + auto kernel_loads = [=](int ki, int oc, int kk) { + for (int ii = 0; ii < ker_load_number; ii++) { + int aux_kernel_offset = kernel_offset(kk, oc + ii, ki); + vmovups(zmm_ker(ii), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + }; + auto prefetch_dst_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) { + if (cnt1 >= ker_load_number && cnt0 >= ker_load_number + && ki >= ki_start && ii_dpref_t0 < iw_end_ipref) { + int aux_dst_offset = typesize * ((ii_dpref_t0 + + jcp.l_pad) * oc_block + jcp.ow * oc_block); + prefetcht0(EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + ii_dpref_t0++; + } + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_dst_prf, reg_dst_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + mov(aux_reg_dst_d_prf, reg_dst_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_dst_prf, aux_reg_dst_d_prf); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + } + + align(16); + L(kh_label); + if (check_last_kh) { + for (int ki = 0; ki < kw; ki++) + for (int oc = 0; oc < oc_block; oc += 4) + for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { + bool last_kernel_loads = (kk == jcp.nb_ic_blocking - 1 + && ki == kw - 1 && (oc + 4) == oc_block); + + if (last_kernel_loads) { + cmp(reg_kj, 1); + je(last_iter_label, T_NEAR); + } + + kernel_loads(ki, oc, kk); + for (int ii = get_iw_start(ki, l_overflow), + prf_count_t0 = 0, prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + + if (ii % 2) { + if (prf_count_t0 < 4) { + int aux_kernel_prf; + if (last_kernel_loads) + aux_kernel_prf= kernel_offset(0, prf_count_t0 + + oc + 4 - oc_block, 0) + typesize * kw + * oc_block * ic_block; + else + aux_kernel_prf = kernel_offset(kk, oc + 4 + + prf_count_t0, ki); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_kernel_prf)); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(kk, oc + prf_count_t1, ki))); + prf_count_t1++; + } + } else + prefetch_dst_next_kh(ki, 2, prf_count_t0, prf_count_t1); + } + if (last_kernel_loads) { + jmp(loop_end_label, T_NEAR); + + L(last_iter_label); + + kernel_loads(ki, oc, kk); + for (int ii = get_iw_start(ki, l_overflow), + prf_count_t0 = 0, prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + if (ii % 2) { + if (prf_count_t0 < 4) { + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(0, prf_count_t0, 0))); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(kk, oc + prf_count_t1, ki))); + prf_count_t1++; + } + } + } + L(loop_end_label); + } + } + } else { + for (int ki = 0; ki < kw; ki++) + for (int oc = 0; oc < oc_block; oc += 4) + for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { + kernel_loads(ki, oc, kk); + + for (int ii = get_iw_start(ki, l_overflow), prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + if ((ii % 2) && (prf_count_t1 < 4)) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + oc + prf_count_t1, ki))); + prf_count_t1++; + } + if ( ki == 1 && oc == 0 && kk == 0) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_dst_prf, aux_dst_offset)); + } + } + } + + add(aux_reg_ker, shift_ker_ptr); + sub(aux_reg_dst, shift_dst_ptr); + add(aux_reg_ker_prf, shift_ker_ptr); + sub(aux_reg_dst_prf, shift_dst_ptr); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, typesize * (jcp.oh * ow) * ic_block); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block); + sub(aux_reg_dst_d_prf, typesize * (jcp.oh * ow) * ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh *oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_src); + pop(reg_src_prf); + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma( + int ur_w, int l_overflow, int r_overflow) +{ + Label kh_label, kd_label; + int kw = jcp.kw; + int ow = jcp.ow; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int l_pad = jcp.l_pad; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + int stride_h = jcp.stride_h; + + int ker_pipeline_depth = 4; + assert(ker_reg_base_idx + ker_pipeline_depth <= 32); + assert(oc_block >= ker_pipeline_depth); + + int num_ker_loads = oc_block * kw; + int num_inp_prfs = ur_w * nstl::min(kw, stride_w) + + nstl::max(0, kw - stride_w); + int num_prfs = num_ker_loads + num_inp_prfs; + int num_fmas = num_ker_loads * ur_w / stride_w; + int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs); + int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + + mov(aux_reg_dst_prf, reg_dst_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + mov(aux_reg_dst_d_prf, reg_dst_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_dst_prf, aux_reg_dst_d_prf); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + } + + L(kh_label); { + int step = 0; + int ker_prfs = 0; + for (int ki = 0; ki < kw; ki++) { + for (int oc = 0; oc < oc_block; oc++) { + if (step == 0) { + for (int i = 0; i < ker_pipeline_depth; i++) { + int aux_kernel_offset = typesize * ((oc + i) * oc_block + + ki * ic_block * oc_block); + vmovups(zmm_ker(i), EVEX_compress_addr( + aux_reg_ker, aux_kernel_offset)); + } + } else if (step < num_ker_loads - ker_pipeline_depth + 1) { + int load_offset = ker_pipeline_depth - 1; + int ker_load_reg_idx + = (step + load_offset) % ker_pipeline_depth; + int aux_kernel_offset = typesize * ((oc + load_offset) + * oc_block + ki * ic_block * oc_block); + vmovups(zmm_ker(ker_load_reg_idx), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + + bool ker_prf_inserted = false; + auto zmm_kernel = zmm_ker(step % ker_pipeline_depth); + + int jj_start = get_iw_start(ki, l_overflow); + int jj_end = get_iw_end(ur_w, ki, r_overflow); + assert(stride_w != 1 + || jj_start == nstl::max(0, + l_overflow - (kw - 1 - ki) * dilate_w)); + assert(stride_w != 1 + || jj_end == ur_w - nstl::max(0, + r_overflow - ki * dilate_w)); + + for (int jj = jj_start; jj < jj_end; jj += stride_w) { + assert((jj + l_pad - ki * dilate_w) % stride_w == 0); + int aux_dst_offset = typesize * + (((jj + l_pad - ki * dilate_w) + / stride_w) * jcp.oc_block + oc); + vfmadd231ps(zmm_out(jj, 0), zmm_kernel, + EVEX_compress_addr(aux_reg_dst, aux_dst_offset, true)); + + int fma_idx = (step * ur_w + jj) / stride_w; + int prf_slot_idx = fma_idx / prf_inst_spacing; + if (fma_idx % prf_inst_spacing == prf_inst_trigger) { + if (!ker_prf_inserted && ker_prfs < num_ker_loads) { + int ker_prf_offset = typesize + * ker_prfs * jcp.oc_block; + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, ker_prf_offset)); + ker_prf_inserted = true; + ker_prfs++; + } else { + int inp_prf_idx = prf_slot_idx - ker_prfs; + if (inp_prf_idx < num_inp_prfs) { + int inp_prf_offset + = ic_block * typesize + * ((inp_prf_idx / kw) * kw + + (inp_prf_idx % kw)); + mic_prefetcht0(EVEX_compress_addr( + aux_reg_dst_prf, inp_prf_offset)); + } + } + } + } + step++; + } + } + + add(aux_reg_ker, typesize * stride_h * kw * oc_block * ic_block); + sub(aux_reg_dst, typesize * (jcp.dilate_h + 1) * ow * oc_block); + add(aux_reg_ker_prf, typesize * stride_h * kw * oc_block * ic_block); + sub(aux_reg_dst_prf, typesize * (jcp.dilate_h + 1) * ow * oc_block); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, typesize * jcp.stride_d * jcp.kw * jcp.kh + * oc_block * ic_block); + sub(aux_reg_dst_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.stride_d * jcp.kw * jcp.kh + * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + } + + if (jcp.ndims == 5) + { + pop(reg_src); + pop(reg_src_prf); + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( + int ur_w, int l_overflow, int r_overflow) +{ + int kw = jcp.kw; + int ow = jcp.ow; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_ic_block = jcp.nb_ic_blocking; + Label kh_label, kd_label; + + int shift_ker_ptr = typesize * kw * oc_block * ic_block; + int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block; + + auto output_offset = [=](int oi, int oc, int ki) { + return typesize * + (((oi + jcp.l_pad - ki * dilate_w) / stride_w) * oc_block + oc); + }; + auto kernel_offset = [=](int icb, int oc, int ki) { + int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int oc_offset = oc * jcp.oc_block; + return typesize * (blk_offset + oc_offset); + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + } + + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_iw_start(ki, l_overflow); + int jj_end = get_iw_end(ur_w, ki, r_overflow); + for (int oc = 0; oc < oc_block; oc++) { + if (jcp.kernel_kind == expl_bcast) { + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_output_offset = output_offset(jj, oc, ki); + vbroadcastss(zmm_inp(jj, nb_ic_block), + ptr[aux_reg_dst + aux_output_offset]); + } + } + for (int ii = 0; ii < nb_ic_block; ii++) { + int aux_kernel_offset = kernel_offset(ii, oc, ki); + if (jj_end - jj_start > 0) + vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker, + aux_kernel_offset)); + for (int jj = jj_start; jj < jj_end; jj += stride_w) + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(zmm_out(jj, ii), + zmm_inp(jj, nb_ic_block), zmm_wei); + else + vfmadd231ps(zmm_out(jj, ii), zmm_wei, + EVEX_compress_addr(aux_reg_dst, + output_offset(jj, oc, ki), true)); + } + } + } + add(aux_reg_ker, shift_ker_ptr); + sub(aux_reg_dst, shift_dst_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_src); + pop(reg_src_prf); + } +} + +inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop( + int ur_w, int l_overflow, int r_overflow) +{ + if (jcp.ndims == 5) push(reg_oi); + + prepare_output(ur_w); + + Label skip_compute_loop; + if (jcp.ndims == 5) { + mov(reg_kj, ptr[param + GET_OFF(kd_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + + if (jcp.ver == ver_4fma) + compute_loop_4fma(ur_w, l_overflow, r_overflow); + else if (jcp.ver == ver_fma) + if (mayiuse(avx512_mic)) + compute_loop_fma(ur_w, l_overflow, r_overflow); + else + if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1) + compute_loop_fma(ur_w, l_overflow, r_overflow); + else + compute_loop_fma_core(ur_w, l_overflow, r_overflow); + else + assert("!unknown convolution version"); + + L(skip_compute_loop); + store_output(ur_w); + if (jcp.ndims == 5) pop(reg_oi); +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::generate() +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ur_w = jcp.ur_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int ur_w_tail = jcp.ur_w_tail; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + int dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block; + int src_shift = jcp.typesize_out * ur_w * oc_block; + + preamble(); + + mov(reg_src, ptr[param + GET_OFF(src)]); + mov(reg_dst, ptr[param + GET_OFF(dst)]); + mov(reg_ker, ptr[param + GET_OFF(filt)]); + + mov(reg_kh, ptr[param + GET_OFF(kh_padding)]); + mov(reg_src_prf, ptr[param + GET_OFF(src_prf)]); + mov(reg_dst_prf, ptr[param + GET_OFF(dst_prf)]); + mov(reg_ker_prf, ptr[param + GET_OFF(filt_prf)]); + + int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w); + int r_overflow = nstl::max(0, ((kw - 1) * dilate_w + - nstl::max(0, jcp.r_pad)) / stride_w); + int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w + - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w); + + int n_oi = iw / ur_w; + if (r_overflow1 > 0) n_oi--; + + if (ur_w == iw) { + compute_loop(ur_w, l_overflow, r_overflow); + } else if (n_oi == 0) { + compute_loop(ur_w, l_overflow, r_overflow1); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + if (ur_w_tail != 0) + compute_loop(ur_w_tail, 0, r_overflow); + } else { + xor_(reg_oi, reg_oi); + if (l_overflow > 0) { + compute_loop(ur_w, l_overflow, 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + + inc(reg_oi); + } + if ((l_overflow <= 0 && n_oi > 0) + || (l_overflow > 0 && n_oi > 1)) { + Label ow_loop_label; + L(ow_loop_label); { + compute_loop(ur_w, 0, 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_overflow1 > 0) { + compute_loop(ur_w, 0, r_overflow1); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + } + if (ur_w_tail != 0) { + compute_loop(ur_w_tail, 0, r_overflow); + } + } + + postamble(); +} + +status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf( + jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + if (!mayiuse(avx512_common)) return status::unimplemented; + + jcp = zero(); + + jcp.simd_w = cpu_isa_traits::vlen / sizeof(float); + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + int ndims = diff_src_d.ndims(); + + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2]; + jcp.iw = diff_src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + if ((jcp.dilate_w != 0 && jcp.stride_w != 1) + || (jcp.dilate_d != 0 && jcp.stride_d != 1) + || (jcp.dilate_h != 0 && jcp.stride_h != 1)) + return status::unimplemented; + + jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1); + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + jcp.back_pad = (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); + + jcp.aligned_threads = 0; + + jcp.is_1stconv = false; + + jcp.oc_block = jcp.simd_w; + jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && diff_src_d.data_type() == data_type::f32; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } + + auto dat_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? pick(ndims - 3, gOIw16o16i, gOIhw16o16i, gOIdhw16o16i) + : pick(ndims - 3, OIw16o16i, OIhw16o16i, OIdhw16o16i); + jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) + return status::unimplemented; + + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.ur_w = jcp.stride_w; + + int regs = 28; + if (jcp.iw <= regs) + jcp.ur_w = jcp.iw; + else { + for (int ur_w = regs; ur_w > 0; --ur_w) + if (ur_w % jcp.stride_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - jcp.l_pad) / jcp.stride_w); + int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w); + int n_oi = jcp.iw / jcp.ur_w; + if (r_overflow1 > 0) n_oi--; + + if (mayiuse(avx512_common) + && diff_dst_d.data_type() == data_type::f32 + && weights_d.data_type() == data_type::f32 + && diff_src_d.data_type() == data_type::f32) { + jcp.ver = ver_fma; + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + if (mayiuse(avx512_mic_4ops) + && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1) { + jcp.ver = ver_4fma; + } + } else { + return status::unimplemented; + } + + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (!utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) + && jcp.ver != ver_fma) + return status::unimplemented; + + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (jcp.ver == ver_4fma) { + if (jcp.kw == 3 && jcp.kh == 3 && jcp.iw == 7 && jcp.ih == 7) { + jcp.nb_ic_blocking = 2; + } else { + for (int i = jcp.nb_ic; i > 0; i--) + if (i * jcp.ur_w <= regs && jcp.nb_ic % i == 0) { + jcp.nb_ic_blocking = i; + break; + } + } + } + + jcp.loop_order = loop_gnc; + + bool large_code_size = (jcp.ur_w != jcp.ow) + && ((l_overflow <= 0 && n_oi > 0) ||(l_overflow > 0 && n_oi > 1)) + && (r_overflow1 > 0) && (l_overflow > 0); + if (large_code_size) { + const int max_code_size = 24 * 1024; + const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw; + int mult = 1; + if (l_overflow > 0) mult += 1; + if (r_overflow1 > 0) mult += 1; + for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) { + if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2 + < max_code_size) { + if (ur_w % jcp.stride_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + } + } + + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + int try_nb_ic_blocking = 2; + unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block + * try_nb_ic_blocking * jcp.kh; + unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block; + unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block + * jcp.oc_block * try_nb_ic_blocking; + unsigned int ker_total_size = ker_inp_size + ker_out_size + + ker_wei_size; + if (!(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8) + || (jcp.kw < 5 && ((jcp.iw <= 5 || (jcp.iw > 8 && jcp.iw <= 13)) + || ker_total_size > L1_cache_size ))) + || jcp.stride_h > 1 || jcp.stride_d > 1) { + jcp.kernel_kind = embd_bcast; + jcp.ur_w = nstl::min(jcp.iw, regs); + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (!(jcp.kw > 3 || (jcp.kw == 3 && ker_total_size < L1_cache_size + && jcp.ow > 8)) && jcp.stride_h == 1) + if (jcp.nb_ic % try_nb_ic_blocking == 0) { + jcp.nb_ic_blocking = try_nb_ic_blocking; + jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1); + if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw; + } + } else { + jcp.kernel_kind = expl_bcast; + jcp.nb_oc_blocking = 1; + jcp.nb_ic_blocking = 4; + if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic; + if (jcp.nb_ic % jcp.nb_ic_blocking != 0) + for (int i = jcp.nb_ic_blocking; i > 0; i--) + if (jcp.nb_ic % i == 0) { + jcp.nb_ic_blocking = i; + break; + } + jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1); + if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw; + } + } + jcp.ur_w_tail = jcp.iw % jcp.ur_w; + + if (l_overflow * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0)) + return status::unimplemented; + + pick_loop_order(jcp); + + jcp.nb_oc_L2 = jcp.nb_oc; + if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) { + for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc; + divf++) { + size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih + * jcp.id; + size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od; + size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh + * jcp.kd * jcp.nb_ic_blocking * temp_nb; + if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { + if (jcp.kh == 3 && jcp.ih == 7) { + jcp.nb_oc_L2 = 1; + break; + } + temp_nb = (jcp.nb_oc_L2 % divf == 0 ? jcp.nb_oc_L2 / divf + : jcp.nb_oc_L2); + } else { + jcp.nb_oc_L2 = temp_nb; + break; + } + } + } + + args_ok = true + && jcp.ic <= diff_src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + return status::success; +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28; + +void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() +{ + Label kd_comeback_label; + + /* 'depth' loop count bound by 'kd_work_size' */ + mov(kj, reg_kd_count); + L(kd_comeback_label); { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + sub(reg_input, + jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult); + sub(reg_kernel, + jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kd_comeback_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() +{ + Label kh_comeback_label, kd_comeback_label; + mov(kj, reg_kh); + L(kh_comeback_label); { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult); + sub(reg_kernel, + jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_comeback_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) + vmovups(Zmm(i_kw * ic_block_step + i_ic), + EVEX_compress_addr(reg_kernel, typesize * (i_kw * ic_block + + i_ic) * jcp.oc_block + kernel_offset)); + + for (int i_ur = 0; i_ur < ur_w; i_ur++) { + if (i_ur == 0) { + vmovups(Zmm(kw * ic_block_step + (i_ur + 0) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 0) + * oc_block + output_offset)); + if (ur_w > 1) vmovups(Zmm(kw * ic_block_step + (i_ur + 1) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 1) * oc_block + + output_offset)); + if (ur_w > 2) vmovups(Zmm(kw * ic_block_step + (i_ur + 2) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 2) * oc_block + + output_offset)); + if (ur_w > 3) vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block + + output_offset)); + } else if (i_ur + 3 < ur_w) + vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block + + output_offset)); + + for (int i_kw = 0; i_kw < kw; i_kw++) { + int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1); + if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w + + (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue; + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + const size_t i_offset = (size_t)input_offset + + (size_t)typesize * (jcp.ver == ver_4fma + ? (i_iw - pad_l + i_ic * jcp.tr_iw) + : (jcp.is_1stconv + ? (i_iw - pad_l) + (size_t)i_ic + * ((size_t)jcp.ih*jcp.iw*jcp.id) + : (i_iw - pad_l) * ic_block + i_ic)); + vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic), + Zmm(kw * ic_block_step + i_ur % 4), + EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt, + true)); + } + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) + vmovups(EVEX_compress_addr(reg_kernel, typesize + * (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset), + Zmm(i_kw * ic_block_step + i_ic)); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_4fma( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + // TODO: add prefetches to fma version as well + + assert(jcp.ver == ver_4fma); + + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + auto zmm_ker = [=](int i_kw, int i_ic) { + return Zmm(i_kw * ic_block_step + i_ic); + }; + + auto ker_addr = [=](int i_kw, int i_ic) { + size_t local_offset + = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block; + return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset); + }; + + auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0) { + int stride = jcp.tr_iw * (jcp.is_1stconv ? jcp.ih : 1); + int local_offset = jcp.typesize_in * (i_iw + i_ic * stride); + return EVEX_compress_addr(reg_input, + local_offset + input_offset + extra_offset); + }; + + auto zmm_out = [=](int i_iw) { + // TODO: move reg calc to global member funcs + const int out_zmm_base_idx = 28; + return Zmm(out_zmm_base_idx + i_iw % 4); + }; + + auto out_addr = [=](int i_ur) { + return EVEX_compress_addr(reg_output, + jcp.typesize_in * i_ur * oc_block + output_offset); + }; + + auto pf_callback = [=](int i_ur, int i_kw, int i_ic) { + assert(i_ur % 4 == 0); + if (i_ur == 0) + prefetcht1(ker_addr(i_kw, i_ic)); + if (i_ur + 4 >= ur_w) + prefetcht0(ker_addr(i_kw, i_ic)); + + const ptrdiff_t next_input_block_offset + = jcp.typesize_in * ic_block_step * jcp.tr_iw; + if (i_ur % 16 == 4 && i_kw == 0) { + if (i_ur + 16 < ur_w) + prefetcht0(inp_addr(i_ur + 16, i_ic)); + else + prefetcht0(inp_addr(0, i_ic, next_input_block_offset)); + } + if (i_ur % 16 == 4 && i_kw == 1) { + if (input_wraparound) + prefetcht1(inp_addr(i_ur, i_ic, -input_offset)); + else + prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset)); + } + }; + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + auto zmm = zmm_ker(i_kw, i_ic); + vpxord(zmm, zmm, zmm); + } + + for (int i_ur = 0; i_ur < ur_w; i_ur += 4) { + + for (int i = 0; i < 4; i++) { + auto zmm = zmm_out(i_ur + i); + if (i_ur + i < ur_w) + vmovups(zmm, out_addr(i_ur + i)); + else + vpxord(zmm, zmm, zmm); + prefetcht0(out_addr(i_ur + i + 4)); + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + int i_iw = i_ur + i_kw; + v4fmaddps(zmm_ker(i_kw, i_ic), + zmm_out(i_ur), inp_addr(i_iw, i_ic)); + pf_callback(i_ur, i_kw, i_ic); + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + auto addr = ker_addr(i_kw, i_ic); + auto zmm = zmm_ker(i_kw, i_ic); + vaddps(zmm, zmm, addr); + vmovups(addr, zmm); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + if (jcp.ver == ver_4fma) + compute_ic_block_step_4fma(ur_w, pad_l, pad_r, + ic_block_step, input_offset, kernel_offset, output_offset, + input_wraparound); + else if (jcp.ver == ver_fma) + compute_ic_block_step_fma(ur_w, pad_l, pad_r, + ic_block_step, input_offset, kernel_offset, output_offset, + input_wraparound); + else + assert(!"unknown convolution version"); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_unroll_ow_icblock( + int ic_block_step, int max_ur_w) +{ + UNUSED(max_ur_w); + + Label kh_label, kd_label; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int inp_mul = !jcp.is_1stconv ? ic_block : 1; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + int ow = jcp.ow; + + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.l_pad; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); + { + for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) { + const int input_offset = jcp.typesize_in + * (jcp.ver == ver_4fma ? i_b_ic * iw : i_b_ic); + compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step, + input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0, + i_b_ic + ic_block_step >= jcp.ic_block); + } + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, + jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_unroll_ow( + int ic_block_step, int max_ur_w) +{ + Label kh_label, ic_block_label, kd_label; + + UNUSED(max_ur_w); + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + int ow = jcp.ow; + + int r_pad = nstl::max(0, + (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.l_pad; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); + { + xor_(b_ic, b_ic); + L(ic_block_label); { + compute_ic_block_step(ow, l_pad, r_pad, ic_block_step, + 0, 0, 0); + size_t inp_icblk_stride = jcp.is_1stconv + ? (size_t)jcp.ih * jcp.iw * jcp.id + : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); + size_t input_offset + = inp_icblk_stride * jcp.typesize_in * ic_block_step; + safe_add(reg_input, input_offset, reg_long_offt); + add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_label, T_NEAR); + } + + if (jcp.is_1stconv) { + size_t input_offset + = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, input_offset, reg_long_offt); + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); + } else if (jcp.ver != ver_4fma) { + add(reg_input, jcp.typesize_in + * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block); + } + add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih + * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_common( + int ic_block_step, int max_ur_w) +{ + Label kh_label, ic_block_label, ow_block_label, kd_label; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + int ow = jcp.ow; + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.ver == ver_4fma ? 0 : jcp.l_pad; + + int ur_w = nstl::min(ow, max_ur_w); + int ur_w_trips = ow / ur_w; + int ur_w_tail = ow % ur_w; + if ((ur_w_tail == 0 && r_pad != 0) + || r_pad >= ur_w_tail) { + if (ur_w_trips > 1) { + ur_w_tail += ur_w; + ur_w_trips--; + } else { + ur_w_tail += (ur_w - ur_w / 2); + ur_w = ur_w / 2; + } + } + + int inp_mult = (jcp.is_1stconv || jcp.ver == ver_4fma) ? 1 : ic_block; + int input_comeback = (ur_w_trips * ur_w * jcp.stride_w - l_pad) * inp_mult; + int output_comeback = ur_w_trips * ur_w * oc_block; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); { + xor_(b_ic, b_ic); + L(ic_block_label); { + if (l_pad != 0) { + ur_w_trips--; + compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0); + add(reg_input, jcp.typesize_in * (ur_w * jcp.stride_w - l_pad) + * inp_mult); + add(reg_output, jcp.typesize_in * ur_w * oc_block); + } + + if (ur_w_trips > 0) { + xor_(reg_ur_w_trips, reg_ur_w_trips); + L(ow_block_label); { + compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0); + add(reg_input, jcp.typesize_in * ur_w * jcp.stride_w + * inp_mult); + add(reg_output, jcp.typesize_in * ur_w * oc_block); + + inc(reg_ur_w_trips); + cmp(reg_ur_w_trips, ur_w_trips); + jl(ow_block_label, T_NEAR); + } + } + + if (ur_w_tail > 0) compute_ic_block_step(ur_w_tail, 0, r_pad, + ic_block_step, 0, 0, 0); + + sub(reg_input, jcp.typesize_in * input_comeback); + sub(reg_output, jcp.typesize_in * output_comeback); + int inp_icblk_stride = jcp.is_1stconv + ? jcp.ih * jcp.iw * jcp.id + : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); + size_t input_offset + = inp_icblk_stride * jcp.typesize_in * ic_block_step; + safe_add(reg_input, input_offset, reg_long_offt); + add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); + + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_label, T_NEAR); + } + if (jcp.is_1stconv) { + size_t input_offset + = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, input_offset, reg_long_offt); + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); + } else if (jcp.ver != ver_4fma) { + add(reg_input, jcp.typesize_in + * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block); + } + add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih + * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_disp() +{ + int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2); + if (jcp.is_1stconv) { + bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0); + ic_block_step + = (jcp.kw * jcp.ic_block <= 28 && !large_code) ? jcp.ic_block : 1; + } + + bool too_large_to_unroll + = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1) + && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1); + + int ow = jcp.ow; + if (jcp.ndims == 5) { + /* NOTE: reg_kd_count = aux_reg_input = r12. The following order of + * 'movs' must be guaranteed. */ + mov(ki, reg_kd_count); + push(reg_kd_count); + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + } + + if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) + compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w); + else if (ow <= max_ur_w) + compute_oh_step_unroll_ow(ic_block_step, max_ur_w); + else + compute_oh_step_common(ic_block_step, max_ur_w); + + if (jcp.ndims == 5) { + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + pop(reg_kd_count); + od_step_comeback_pointers(); + } else { + oh_step_comeback_pointers(); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel() +{ + Label skip_zeroing, zeroing_loop; + + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + cmp(reg_tmp, 0); + jz(skip_zeroing, T_NEAR); + + Zmm zero = Zmm(0); + vpxord(zero, zero, zero); + xor_(reg_tmp, reg_tmp); + L(zeroing_loop); { + assert(jcp.oc_block * jcp.typesize_out + == cpu_isa_traits::vlen); + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block + * jcp.typesize_out], zero); + add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out); + cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh * jcp.kd + * jcp.typesize_out); + jnz(zeroing_loop); + } + + L(skip_zeroing); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel() +{ + Label skip_bias, bias_loop, skip_load_bias; + + mov(reg_tmp, ptr[param + GET_OFF(flags)]); + test(reg_tmp,reg_tmp); + jne(skip_bias, T_NEAR); + + mov(reg_bias, ptr[param + GET_OFF(bias)]); + mov(reg_output, ptr[param + GET_OFF(dst)]); + vpxord(Zmm(1), Zmm(1), Zmm(1)); + + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + cmp(reg_tmp, 0); + jne(skip_load_bias, T_NEAR); + vmovups(Zmm(1), ptr[reg_bias]); + + L(skip_load_bias); + + mov(reg_oi, ptr[param + GET_OFF(d_worksize)]); + sub(reg_oi, ptr[param + GET_OFF(d_index)]); + mov(reg_tmp, jcp.oc_block * jcp.ow * jcp.oh * jcp.typesize_out); + imul(reg_oi, reg_tmp); + + xor_(reg_tmp, reg_tmp); + L(bias_loop); { + vmovups(Zmm(0), ptr[reg_output + reg_tmp]); + vaddps(Zmm(1), Zmm(1), Zmm(0)); + add(reg_tmp, jcp.oc_block * jcp.typesize_out); + cmp(reg_tmp, reg_oi); + jl(bias_loop); + } + vmovups(EVEX_compress_addr(reg_bias,0), Zmm(1)); + + L(skip_bias); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_loop_common() +{ + int b_pad = jcp.b_pad; + int t_pad = jcp.t_pad; + bool is_dilated = jcp.dilate_h != 0; + int dilate_h = jcp.dilate_h + 1; + int stride_h = jcp.stride_h; + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label, + oh_bpad_label, oh_bpad_label_end, od_label, od_label_end, + oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end; + + int ow = jcp.ow; + + mov(reg_kh, jcp.kh); + xor_(reg_ih_count, reg_ih_count); + xor_(reg_oj, reg_oj); + /* Compute 'top' edge */ + if (t_pad > 0) { + const int kh_range = 1 + (jcp.kh - 1) * dilate_h; + const int overflow + = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h)); + const int underflow = div_up(t_pad, dilate_h); + const int initial_inp_ker_overlap = jcp.kh - overflow - underflow; + mov(reg_kh, initial_inp_ker_overlap); + add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block + * jcp.oc_block); + // generate loop to process kernel while it remains within t_pad + ih + if (kh_range < t_pad + jcp.ih) { + if (is_dilated) { + const int tail = t_pad % dilate_h; + const int shift = tail == 0 ? 0 : dilate_h - tail; + mov(reg_tmp, shift); + if (tail != 0) + add(reg_input, jcp.typesize_in * shift * iw * inp_mult); + } + L(oh_tpad_label); { + compute_oh_step_disp(); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + if (is_dilated) { + inc(reg_tmp); + cmp(reg_tmp, dilate_h); + jl(oh_dilate_label_shift, T_NEAR); + // unshift input as new kernel element enters + sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult); + xor_(reg_tmp, reg_tmp); + } + // kernel overlap only changes when (t_pad + oj) % dilate_h == 0 + sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw + * jcp.ic_block * jcp.oc_block); + add(reg_kh, stride_h); + if (is_dilated) { + jmp(oh_dilate_label_noshift, T_NEAR); + L(oh_dilate_label_shift); + // shift input as old kernel element progresses + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + L(oh_dilate_label_noshift); + } + inc(reg_oj); + add(reg_ih_count, stride_h); + + // final number of kernel elements that overlap with input + const int final_inp_ker_overlap + = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h)); + cmp(reg_kh, final_inp_ker_overlap); + jl(oh_tpad_label, T_NEAR); + } + } + // need second loop to process kernel if it is larger than the input + // (does not apply to dilations as they must have unit stride) + if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h : + t_pad % stride_h)) { + assert(!is_dilated); + mov(reg_kh, jcp.ih); + L(oh_tpad_tail_label); { + compute_oh_step_disp(); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw + * jcp.ic_block * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h)); + jl(oh_tpad_tail_label, T_NEAR); + } + } + // correct any excess shifts to kernel and input + // (does not apply to dilations as they must have unit stride, + // kernel must fit inside input, and padding is smaller than input) + if (t_pad <= jcp.oh * stride_h) { + // kernel has moved beyond padding (adjust for stride effects) + if (t_pad % stride_h != 0) { + assert(!is_dilated); + int inp_corr = stride_h - t_pad % stride_h; + add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw + * jcp.ic_block * jcp.oc_block); + add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult); + } + } else { + // kernel still overlaps padding (complete reset) + assert(!is_dilated); + sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h) + * jcp.kw * jcp.ic_block * jcp.oc_block); + } + } + + cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); + jge(oh_label_end, T_NEAR); + cmp(reg_oj, jcp.oh); + jge(oh_label, T_NEAR); + + /* Compute middle block(s) */ + mov(reg_kh, jcp.kh); + L(oh_label); { + compute_oh_step_disp(); + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); + jge(oh_label_end, T_NEAR); + + cmp(reg_oj, jcp.oh); + jl(oh_label, T_NEAR); + } + L(oh_label_end); + + /* Compute bottom edge */ + if (b_pad > 0) { + cmp(reg_oj, jcp.oh); + jge(oh_bpad_label_end, T_NEAR); + + if (is_dilated) { + mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations + mov(reg_tmp, 0); + } else { + mov(reg_kh, jcp.ihp - b_pad); + sub(reg_kh, reg_ih_count); + } + L(oh_bpad_label); + { + compute_oh_step_disp(); + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + if (is_dilated) { + inc(reg_tmp); + cmp(reg_tmp, dilate_h); + jl(oh_dilate_label_end, T_NEAR); + xor_(reg_tmp, reg_tmp); + } + sub(reg_kh, stride_h); + cmp(reg_kh, 0); + jle(oh_bpad_label_end, T_NEAR); + if (is_dilated) + L(oh_dilate_label_end); + + inc(reg_oj); + cmp(reg_oj, jcp.oh); + jl(oh_bpad_label, T_NEAR); + } + L(oh_bpad_label_end); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_d_loop_common() { + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + int ow = jcp.ow; + const int input_backpad_overlap + = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d); + + const size_t filter_shift + = jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block; + const size_t input_shift = jcp.typesize_in * jcp.ih * iw * inp_mult; + const size_t output_shift = jcp.typesize_in * jcp.oh * ow * jcp.oc_block; + + Label d_loop_label, loop_end_label, common_block_label, fpad_end_label, + backpad_end_label, backpad_label; + + if (jcp.with_bias) bias_kernel(); + + /* initially offset 'kd' by f_pad */ + add(reg_kernel, ptr[param + GET_OFF(kd_offset)]); + + mov(reg_input_d, ptr[param + GET_OFF(src)]); + mov(reg_output_d, ptr[param + GET_OFF(dst)]); + mov(reg_d_index, ptr[param + GET_OFF(d_index)]); + mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]); + + cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]); + jge(loop_end_label, T_NEAR); + + L(d_loop_label); + + mov(reg_input, reg_input_d); + mov(reg_output, reg_output_d); + + push(reg_input_d); + push(reg_output_d); + push(reg_d_index); + + compute_oh_loop_common(); + + pop(reg_d_index); + pop(reg_output_d); + pop(reg_input_d); + + /* Compute 'front' edge */ + if (jcp.f_pad > 0) { + + /* Check if within fpad region */ + cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d)); + jge(fpad_end_label, T_NEAR); + + /* Fpad steps */ + sub(reg_kernel, filter_shift * jcp.stride_d); + add(reg_kd_count, jcp.stride_d); + + /* Final number of kernel elements that overlap with input */ + const int inp_ker_overlap = nstl::min(jcp.kd, jcp.id); + cmp(reg_kd_count, inp_ker_overlap); + jl(common_block_label, T_NEAR); + + /* Correct any excess shifts to kernel and input */ + if (jcp.f_pad <= jcp.od * jcp.stride_d) { + /* Filter has moved beyond padding (adjust for stride effects) */ + if (jcp.f_pad % jcp.stride_d != 0) { + int inp_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d; + add(reg_kernel, filter_shift * inp_corr); + add(reg_input_d, input_shift * inp_corr); + } + } else { + /* Filter still overlaps padding (complete reset) */ + sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift); + } + + /* Apply correction */ + mov(reg_kd_count, jcp.kd); + jmp(common_block_label); + + L(fpad_end_label); + } + + /* Compute bottom edge */ + if (jcp.back_pad > 0) { + + /* Check if within back_pad region */ + cmp(reg_d_index, input_backpad_overlap - 1); + jl(backpad_end_label, T_NEAR); + jg(backpad_label, T_NEAR); + + /* Execute overlap correction between the filter and the initial + * back_pad region. */ + mov(reg_kd_count, + jcp.id + jcp.f_pad - input_backpad_overlap * jcp.stride_d); + jmp(backpad_end_label, T_NEAR); + + L(backpad_label); + sub(reg_kd_count, jcp.stride_d); + cmp(reg_kd_count, 0); + jle(loop_end_label, T_NEAR); + + L(backpad_end_label); + } + + /* Compute middle block */ + add(reg_input_d, input_shift * jcp.stride_d); + + /* Execute common block and loop */ + L(common_block_label); + add(reg_output_d, output_shift); + inc(reg_d_index); + cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]); + jl(d_loop_label, T_NEAR); + + L(loop_end_label); +} + +bool jit_avx512_common_conv_bwd_weights_kernel_f32::compute_full_spat_loop() { + // FIXME: use register mapping from the class declaration + bool ok = jcp.ver == ver_4fma + && everyone_is(0, jcp.dilate_h, jcp.dilate_w) + && everyone_is(1, jcp.stride_h, jcp.stride_w); + if (!ok) return false; + if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2) + return false; + + // General code layout: + // + // Blocking over OH -- top level + // (Reduces L2 pressure; not very useful right now) + // Loop over all KHxKW kernel -- emit_kh_kw_loop() + // Loop over OH block -- emit_h_loop() + // Loop over OW blocks -- emit_fma_block() + // (Supports both fully unrolled and partially unrolled versions to + // reduce code size) + // Loop over OW block -- emit_fma_step() + + int max_working_set_size = 128 * 1024; + int pad_ow = jcp.ow; + + int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in; + int out_row_size = jcp.oc_block * pad_ow * jcp.typesize_in; + int row_size = inp_row_size + out_row_size; + + int h_block_size = jcp.oh; + int working_set_size = row_size * h_block_size; + + if (working_set_size > max_working_set_size) { + int opt_working_set_size = 48 * 1024; + assert(opt_working_set_size < max_working_set_size); + + while (working_set_size > opt_working_set_size) { + for (int i = 2; i <= h_block_size; i++) + if (i == h_block_size) + h_block_size = h_block_size / 2; + else if (h_block_size % i == 0) { + h_block_size = h_block_size / i; + break; + } + working_set_size = row_size * h_block_size; + + if (h_block_size == 1 && working_set_size > opt_working_set_size) + return false; + } + } + + // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size (see below) + if (h_block_size < nstl::max(1, jcp.t_pad) + || jcp.b_pad > (jcp.oh % h_block_size == 0 ? h_block_size + : jcp.oh % h_block_size)) + return false; + + // check that we can use simple arithmetic for prefetch address + // calculations + // TODO: we need some traits for this check (Roma) + int cache_line_size = 64; + assert(jcp.ic_block * typesize == 64); + assert(jcp.oc_block * typesize == 64); + + int num_inp_l2_pfs = jcp.tr_iw * h_block_size; + int avg_h_loop_len = h_block_size; + int num_inp_l2_pfs_per_fma_block + = div_up(num_inp_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); + int num_out_l2_pfs = pad_ow * h_block_size; + int num_out_l2_pfs_per_fma_block + = div_up(num_out_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); + + Opmask reg_h_block = k1; // 32-bit only on Intel(R) Xeon Phi(TM) processors + Reg64 reg_kh = rax; + Reg64 reg_kw = rbx; + Reg64 reg_tmp = abi_not_param1; + Reg32 reg_tmp_w = reg_tmp.cvt32(); + Reg64 reg_ohs = rdx; + Reg64 reg_ihs = rsi; + Reg64 reg_h = r8; + Reg64 reg_i = r9; + Reg64 reg_j = r10; + + Reg64 reg_inp = r13; + Reg64 reg_out = r14; + Reg64 reg_ker = r15; + + Reg64 reg_inp_pf_l1 = rbp; + + Reg64 reg_inp_pf_l2 = r11; + Reg64 reg_out_pf_l2 = r12; + + Xmm reg_inp_pf_save = xmm17; + Xmm reg_out_pf_save = xmm18; + + Reg64 reg_inp_save = abi_param1; + Reg64 reg_out_save = reg_tmp; + + auto zmm_out = [&](int oi) { return Zmm(24 + oi % 8); }; + auto zmm_ker = [&](int ic1) { return Zmm(ic1); }; + auto inp_addr = [&](int oi, int ic1) { + return ptr[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in]; + }; + auto out_addr = [&](int oi, int oj = 0) { + assert(jcp.ver == ver_4fma); + return ptr[reg_out + + ((oi + oj * jcp.ow) * jcp.oc_block) * jcp.typesize_in]; + }; + auto ker_addr = [&](int ic1) { + return ptr[reg_ker + ic1 * jcp.oc_block * jcp.typesize_out]; + }; + + auto emit_block = [&](int h_block_size, + bool is_last_block, bool is_last_kh_kw_iter, bool is_last_row) + { + // TODO: add an fma version (Roma) + auto pad_ow = jcp.ow; + + int ow4u = rnd_up(pad_ow, 4); + int def_step_size = 16; + + bool has_w_tail = (pad_ow % def_step_size != 0 + || pad_ow % 4 != 0); + bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail; + + auto emit_step = [&](int ur_ow, + int num_inp_l1_pfs_per_fma_step, + int num_inp_l2_pfs_per_fma_step, + int num_out_l2_pfs_per_fma_step, bool is_w_tail) + { + bool block_wraparound = is_w_tail && is_last_row; + + assert(ur_ow % 4 == 0); + int tail_size = ow4u % ur_ow; + int this_ur_ow + = (is_w_tail && tail_size) ? tail_size : ur_ow; + int ow_last_chunk4 = pad_ow % 4; + int ow_zero_tail4 = ow_last_chunk4 + ? 4 - ow_last_chunk4 : 0; + + auto emit_out_pf = [&](int oi) { +#if 1 + if (oi + def_step_size < ur_ow || !block_wraparound) + mic_prefetcht0(ptr[reg_out + + ((def_step_size + oi) + * jcp.oc_block * jcp.typesize_in)]); + else { + assert(block_wraparound); + assert(oi + def_step_size >= ur_ow); + mic_prefetcht0(ptr[reg_out_save + + ((oi + def_step_size - ur_ow) + * jcp.oc_block * jcp.typesize_in)]); + } +#else + // XXX: This is an alternative prefetching strategy that + // always prefetches the next row. Keeping it here for + // future experiments (Roma) + if (!block_wraparound) + mic_prefetcht0(ptr[reg_out + + (jcp.ow + oi) * jcp.oc_block * jcp.typesize_in]); + else + mic_prefetcht0(ptr[reg_out + reg_ohs + - ((h_block_size - 1) * jcp.ow + - oi) * jcp.oc_block * jcp.typesize_in]); +#endif + if (oi < num_out_l2_pfs_per_fma_step) + mic_prefetcht1(ptr[reg_out_pf_l2 + + oi * jcp.oc_block * jcp.typesize_in]); + }; + + auto emit_inp_pf = [&](int oi4, int ic1) { + int pf_slot_idx = ic1 + oi4 / 4 * jcp.ic_block; + int num_pf_slots = jcp.ic_block * ur_ow / 4; + + int num_pfs = num_inp_l1_pfs_per_fma_step + + num_inp_l2_pfs_per_fma_step; + int pf_freq = nstl::max(1, num_pf_slots / num_pfs); + + if (pf_slot_idx % pf_freq) + return; + + int pf_idx = pf_slot_idx / pf_freq; + + if (pf_idx < num_inp_l2_pfs_per_fma_step) + mic_prefetcht1(ptr[reg_inp_pf_l2 + + pf_idx * jcp.ic_block * jcp.typesize_in]); + else { + pf_idx -= num_inp_l2_pfs_per_fma_step; + // prefetch the 'tail' of the cache line because most of + // the accesses are not aligned + mic_prefetcht0(ptr[reg_inp_pf_l1 + + pf_idx * jcp.ic_block * jcp.typesize_in + + cache_line_size - jcp.typesize_in]); + } + }; + + auto numloads = 4; + + int steps = this_ur_ow; + for (int oi4 = 0; oi4 < steps; oi4 += numloads) { + for (int oi1 = 0; oi1 < numloads; oi1++) { + int oi = oi4 + oi1; + if (!is_w_tail || oi < (this_ur_ow - ow_zero_tail4)) { + vmovups(zmm_out(oi), out_addr(oi)); + emit_out_pf(oi); + } else { + auto zmm = zmm_out(oi); + vpxord(zmm, zmm, zmm); + } + } + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + if (jcp.ver == ver_4fma) { + v4fmaddps(zmm_ker(ic1), + zmm_out(oi4), inp_addr(oi4, ic1)); + } else { + assert(!"unknown convolution version"); + } + emit_inp_pf(oi4, ic1); + } + } + }; + + // Input is transposed and padded but we only access about jcp.iw + // elements so use that to compute the # of cache lines in each 'row' + int num_inp_l1_pfs + = div_up(jcp.iw * jcp.typesize_in, cache_line_size) * jcp.ic_block; + + if (full_w_unroll) { + emit_step(ow4u, num_inp_l1_pfs, + num_inp_l2_pfs_per_fma_block, + num_out_l2_pfs_per_fma_block, true); + add(reg_inp_pf_l2, num_inp_l2_pfs_per_fma_block * cache_line_size); + add(reg_out_pf_l2, num_out_l2_pfs_per_fma_block * cache_line_size); + } else { + Label w_loop; + int num_w_iters = pad_ow / def_step_size; + int num_w_iters_full = num_w_iters + has_w_tail; + int num_inp_l1_pfs_per_fma_step + = div_up(num_inp_l1_pfs, num_w_iters_full); + int num_inp_l2_pfs_per_fma_step + = div_up(num_inp_l2_pfs_per_fma_block, num_w_iters_full); + int num_out_l2_pfs_per_fma_step + = div_up(num_out_l2_pfs_per_fma_block, num_w_iters_full); + mov(reg_i, num_w_iters); + L(w_loop); { + emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, + num_inp_l2_pfs_per_fma_step, + num_out_l2_pfs_per_fma_step, false); + add(reg_inp, def_step_size * jcp.typesize_in); + add(reg_out, def_step_size * jcp.oc_block * jcp.typesize_in); + add(reg_inp_pf_l1, + num_inp_l1_pfs_per_fma_step * cache_line_size); + add(reg_inp_pf_l2, + num_inp_l2_pfs_per_fma_step * cache_line_size); + add(reg_out_pf_l2, + num_out_l2_pfs_per_fma_step * cache_line_size); + sub(reg_i, 1); + jnz(w_loop); + } + if (has_w_tail) { + emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, + num_inp_l2_pfs_per_fma_step, + num_out_l2_pfs_per_fma_step, true); + add(reg_inp_pf_l2, + num_inp_l2_pfs_per_fma_step * cache_line_size); + add(reg_out_pf_l2, + num_out_l2_pfs_per_fma_step * cache_line_size); + } + // reset reg_inp and reg_out because emit_h_loop expects + // unmodified pointers + int w_offset = num_w_iters * def_step_size; + sub(reg_inp, w_offset * jcp.typesize_in); + sub(reg_out, w_offset * jcp.oc_block * jcp.typesize_in); + } + }; + + auto emit_h_loop = [&](int h_block_size, + bool is_last_block, bool is_last_kh_kw_iter) + { + Label h_loop, skip_h_loop; + mov(reg_j, 1); + cmp(reg_j, reg_h); + je(skip_h_loop, T_NEAR); + L(h_loop); { + + lea(reg_inp_pf_l1, + ptr[reg_inp + jcp.tr_iw * jcp.ic_block * jcp.typesize_in]); + emit_block(h_block_size, + is_last_block, is_last_kh_kw_iter, false); + + add(reg_inp, jcp.tr_iw * jcp.ic_block * jcp.typesize_in); + add(reg_out, pad_ow * jcp.oc_block * jcp.typesize_in); + add(reg_j, 1); + cmp(reg_j, reg_h); + jb(h_loop); + } + + L(skip_h_loop); + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + mic_prefetcht0(ker_addr(ic1)); + + lea(reg_inp_pf_l1, ptr[reg_inp_save + reg_kw * jcp.typesize_in]); + emit_block(h_block_size, is_last_block, is_last_kh_kw_iter, true); + }; + + auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block, + int h_block_size) + { + xor_(reg_kh, reg_kh); + Label kh_loop, kh_loop_end; + + int last_oh_block_size + = jcp.oh - rnd_up(jcp.oh - h_block_size, h_block_size); + int oh_block_size = (is_last_block) ? last_oh_block_size : h_block_size; + // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size + int ih_block_size = oh_block_size - 1 + jcp.kh + - is_first_block * jcp.t_pad - is_last_block * jcp.b_pad; + + L(kh_loop); { + // determine starting indices for this block + if (is_first_block) { + xor_(reg_tmp, reg_tmp); + mov(reg_ohs, jcp.t_pad); + sub(reg_ohs, reg_kh); + cmovb(reg_ohs, reg_tmp); + + mov(reg_ihs, reg_ohs); + sub(reg_ihs, jcp.t_pad); + add(reg_ihs, reg_kh); + } else { + xor_(reg_ohs, reg_ohs); + mov(reg_ihs, reg_kh); + } + + // determine effective size of block based on padding + mov(reg_tmp, oh_block_size); + sub(reg_tmp, reg_ohs); + mov(reg_h, ih_block_size); + sub(reg_h, reg_ihs); + cmp(reg_tmp, reg_h); + cmovb(reg_h, reg_tmp); + + Label kh_loop_work; + cmp(reg_h, 0); + jg(kh_loop_work, T_NEAR); + + // empty h loop for this jcp.kh: + // - set the output to 0 if necessary + // - move ker pt + // - jump to the end + sub(reg_h, 1); + Label skip_ker_zeroing; + + // The reg_ker ptr has highest bit set if the output needs to be + // zeroed. Those who have byte-aligned their data will suffer the + // consiquences :( + // TODO: move the flag to a mask register? (Roma) + test(reg_ker, 1); + jz(skip_ker_zeroing, T_NEAR); + + Label zeroing_loop; + vpxord(zmm0, zmm0, zmm0); + and_(reg_ker, ~1); // temporarily clear the zeroing flag + mov(reg_tmp, jcp.kw); + L(zeroing_loop); { + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + vmovups(ker_addr(ic1), zmm0); + add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.typesize_out); + sub(reg_tmp, 1); + jnz(zeroing_loop, T_NEAR); + } + // restore the zeroing flag (it will be cleared after the end of + // emit_kh_kw_loop, but we may need it until then) + or_(reg_ker, 1); + jmp(kh_loop_end, T_NEAR); + + L(skip_ker_zeroing); + add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.kw + * jcp.typesize_out); + jmp(kh_loop_end, T_NEAR); + + L(kh_loop_work); + + mul_by_const(reg_ihs, reg_tmp, + jcp.tr_iw * jcp.ic_block * jcp.typesize_in); + mul_by_const(reg_ohs, reg_tmp, + pad_ow * jcp.oc_block * jcp.typesize_in); + + add(reg_inp, reg_ihs); + add(reg_out, reg_ohs); + + Label kw_loop; + xor_(reg_kw, reg_kw); + L(kw_loop); { + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + vpxord(zmm, zmm, zmm); + mic_prefetcht1(ker_addr(ic1)); + } + + mov(reg_out_save, reg_out); + mov(reg_inp_save, reg_inp); + lea(reg_inp, ptr[reg_inp + reg_kw * jcp.typesize_in]); + +#if 0 + // XXX: Generate code with special prefetches when switching + // blocks or at the end of the last block. Disabled to reduce + // code size and because there's no performance benefit (Roma) + Label regular_h_loop, end_h_loop; + cmp(reg_kw, jcp.kw - 1); + jne(regular_h_loop, T_NEAR); + cmp(reg_kh, jcp.kh - 1); + jne(regular_h_loop, T_NEAR); + + emit_h_loop(oh_block_size, is_last_block, true); + jmp(end_h_loop, T_NEAR); + + L(regular_h_loop); + emit_h_loop(oh_block_size, is_last_block, false); + + L(end_h_loop); +#else + emit_h_loop(oh_block_size, is_last_block, false); +#endif + + mov(reg_out, reg_out_save); + mov(reg_inp, reg_inp_save); + + Label do_store; + // The reg_ker ptr has highest bit set if the output needs to + // be zeroed. Those who have byte-aligned their data will + // suffer the consiquences :( + mov(reg_tmp, reg_ker); + and_(reg_ker, ~1); + test(reg_tmp, 1); + jnz(do_store, T_NEAR); + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + if (jcp.ver == ver_4fma) { + vaddps(zmm, ker_addr(ic1)); + } else { + assert(!"unknown convolution version"); + } + } + + L(do_store); + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + vmovups(ker_addr(ic1), zmm); + } + + mov(reg_ker, reg_tmp); + add(reg_ker, jcp.ic_block * jcp.oc_block * jcp.typesize_out); + add(reg_kw, 1); + cmp(reg_kw, jcp.kw); + jl(kw_loop); + } + + sub(reg_inp, reg_ihs); + sub(reg_out, reg_ohs); + + + L(kh_loop_end); + add(reg_kh, 1); + cmp(reg_kh, jcp.kh); + jl(kh_loop); + } + }; + + mov(reg_inp, ptr[param + GET_OFF(src)]); + mov(reg_out, ptr[param + GET_OFF(dst)]); + mov(reg_ker, ptr[param + GET_OFF(filt)]); + mov(reg_inp_pf_l2, ptr[param + GET_OFF(src_prf)]); + mov(reg_out_pf_l2, ptr[param + GET_OFF(dst_prf)]); + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + or_(reg_ker, reg_tmp); + + bool single_kh_kw_loop = (h_block_size == jcp.oh); + + size_t inp_row_step = jcp.tr_iw * jcp.ic_block * jcp.typesize_in; + size_t first_inp_block_step = inp_row_step * (h_block_size - jcp.t_pad); + size_t inp_block_step = inp_row_step * h_block_size; + size_t out_block_step = pad_ow * jcp.oc_block * jcp.typesize_in + * h_block_size; + + if (!single_kh_kw_loop) { + // Save the original prefetch pointers from the OpenMP driver + vmovq(reg_inp_pf_save, reg_inp_pf_l2); + vmovq(reg_out_pf_save, reg_out_pf_l2); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, first_inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + } + emit_kh_kw_loop(true, single_kh_kw_loop, h_block_size); + + if (!single_kh_kw_loop) { + size_t ker_reset_offset + = jcp.oc_block * jcp.ic_block * jcp.typesize_out * jcp.kw * jcp.kh; + sub(reg_ker, ker_reset_offset); + and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates + + add(reg_inp, first_inp_block_step); + add(reg_out, out_block_step); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + + int num_innermost_iters = div_up(jcp.oh, h_block_size) - 2; + if (num_innermost_iters > 0) { + Label h_block_loop; + + mov(reg_tmp_w, num_innermost_iters); + kmovw(reg_h_block, reg_tmp_w); + L(h_block_loop); { + emit_kh_kw_loop(false, false, h_block_size); + sub(reg_ker, ker_reset_offset); + add(reg_inp, inp_row_step * h_block_size); + add(reg_out, out_block_step); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + kmovw(reg_tmp_w, reg_h_block); + sub(reg_tmp_w, 1); + kmovw(reg_h_block, reg_tmp_w); + jnz(h_block_loop); + } + } + + // Restore the original prefetch pointers that came from the OpenMP + // driver + vmovq(reg_inp_pf_l2, reg_inp_pf_save); + vmovq(reg_out_pf_l2, reg_out_pf_save); + emit_kh_kw_loop(false, true, h_block_size); + } + + return true; +} + +bool jit_avx512_common_conv_bwd_weights_kernel_f32 + ::flat_4ops_compute() { + const auto &j = jcp; + const bool ok = j.ver == ver_4fma && j.is_1stconv + && everyone_is(0, j.dilate_h, j.dilate_w); + if (!ok) return false; + + Reg64 reg_ptr_tr_src = r8; + Reg64 reg_ptr_dst = r9; + Reg64 reg_ptr_wei = r10; + Reg64 reg_ptr_bia = r11; + + Reg64 reg_kh_step = rax; + Reg64 reg_oh = abi_not_param1; + Reg64 reg_kh = rdx; + + Reg32 reg_flag_save = ebx; + Reg32 reg_flag = esi; + + Zmm vbia(31); + + auto zmm_wei = [&](int kh, int kw) { + return Zmm(8 + kh * j.kw + kw); + }; + auto zmm_dst = [&](int ow) { + return Zmm(ow % 8); + }; + + auto addr_tr_src = [&](int kh, int iw) { + return ptr[reg_ptr_tr_src + + (kh * j.stride_w * j.tr_ld + iw) * jcp.typesize_in]; + }; + auto addr_dst = [&](int ow) { + return ptr[reg_ptr_dst + ow * jcp.oc_block * jcp.typesize_in]; + }; + auto addr_wei = [&](int kh, int kw) { + return ptr[reg_ptr_wei + (kh * j.kw + kw) * j.oc_block + * jcp.typesize_out]; + }; + + auto emit_fma_block = [&](int kh_step) { + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) { + auto vwei = zmm_wei(kh, kw); + vpxord(vwei, vwei, vwei); + } + } + + for (int ow = 0; ow < j.ow; ow += 4) { + for (int _ow = ow; _ow < ow + 4; ++_ow) { + auto vdst = zmm_dst(_ow); + if (_ow < j.ow) + vmovups(vdst, addr_dst(_ow)); + else + vpxord(vdst, vdst, vdst); + } + + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) { + const int iw = ow + (kw % j.stride_w) * j.tr_ld + + (kw / j.stride_w); + v4fmaddps(zmm_wei(kh, kw), zmm_dst(ow), + addr_tr_src(kh, iw)); + if (1 && kh == 0 && kw < 4) { + prefetcht1(ptr[reg_ptr_dst + + (j.ow + ow + kw) * jcp.oc_block + * jcp.typesize_in]); + } + if (j.with_bias && kh_step == 1) { /* [bwd_w:b:r1] */ + const int off = kw + 4 - j.kw; + if (off >= 0 && ow + off < j.ow) + vaddps(vbia, vbia, zmm_dst(ow + off)); + } + } + } + } + + Label l_store; + test(reg_flag, FLAG_MB_FIRST); + jnz(l_store, T_NEAR); + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) + vaddps(zmm_wei(kh, kw), addr_wei(kh, kw)); + } + L(l_store); + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) + vmovups(addr_wei(kh, kw), zmm_wei(kh, kw)); + } + }; + + auto emit_kh_loop = [&]() { + const int kh_step_rem = j.kh % j.kh_step; + xor_(reg_kh, reg_kh); + mov(reg_kh_step, j.kh_step); + + Label l_kh_loop; + L(l_kh_loop); { + Label l_done; + + if (kh_step_rem != 0) { + Label l_keep_kh_step; + cmp(reg_kh, j.kh - j.kh_step); + jle(l_keep_kh_step, T_NEAR); + + mov(reg_kh_step, kh_step_rem); + emit_fma_block(kh_step_rem); + jmp(l_done, T_NEAR); + + L(l_keep_kh_step); + } + + emit_fma_block(j.kh_step); + + L(l_done); + + add(reg_ptr_tr_src, j.kh_step * j.stride_w * j.tr_ld + * jcp.typesize_in); + add(reg_ptr_wei, j.kh_step * j.kw * j.oc_block * jcp.typesize_out); + add(reg_kh, j.kh_step); + + cmp(reg_kh, j.kh); + jl(l_kh_loop, T_NEAR); + } + + const int kh_steps = rnd_up(j.kh, j.kh_step); + sub(reg_ptr_tr_src, kh_steps * j.stride_w * j.tr_ld * jcp.typesize_in); + sub(reg_ptr_wei, kh_steps * j.kw * j.oc_block * jcp.typesize_out); + }; + + auto emit_oh_loop = [&]() { + mov(reg_oh, j.oh); + + Label l_oh_loop; + L(l_oh_loop); { + Label l_restore_mb_flag, l_jump; + + cmp(reg_oh, j.oh); + je(l_restore_mb_flag, T_NEAR); + + and_(reg_flag, ~FLAG_MB_FIRST); + jmp(l_jump, T_NEAR); + + L(l_restore_mb_flag); + mov(reg_flag, reg_flag_save); + + L(l_jump); + + emit_kh_loop(); + + add(reg_ptr_tr_src, j.stride_h * j.stride_w * j.tr_ld + * jcp.typesize_in); + add(reg_ptr_dst, j.ow * j.oc_block * jcp.typesize_in); + + dec(reg_oh); + jnz(l_oh_loop, T_NEAR); + } + }; + + auto emit_bia_store = [&]() { + if (!j.with_bias) return; + + Label l_bia_store, l_bia_skip; + test(reg_flag, FLAG_IC_FIRST); + jz(l_bia_skip); + + test(reg_flag, FLAG_MB_FIRST); + jnz(l_bia_store, T_NEAR); + vaddps(vbia, ptr[reg_ptr_bia]); + L(l_bia_store); + vmovups(ptr[reg_ptr_bia], vbia); + L(l_bia_skip); + }; + + mov(reg_ptr_tr_src, ptr[param + GET_OFF(src)]); + mov(reg_ptr_dst, ptr[param + GET_OFF(dst)]); + mov(reg_ptr_wei, ptr[param + GET_OFF(filt)]); + mov(reg_ptr_bia, ptr[param + GET_OFF(bias)]); + mov(reg_flag_save, ptr[param + GET_OFF(flags)]); + + vpxord(vbia, vbia, vbia); + emit_oh_loop(); + emit_bia_store(); + + return true; +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop() +{ + if (flat_4ops_compute()) + return; + if (compute_full_spat_loop()) + return; + + maybe_zero_kernel(); + + if (jcp.ndims == 5) compute_d_loop_common(); + else compute_oh_loop_common(); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::generate() +{ + preamble(); + + mov(reg_input, ptr[param + GET_OFF(src)]); + mov(reg_output, ptr[param + GET_OFF(dst)]); + mov(reg_kernel, ptr[param + GET_OFF(filt)]); + + compute_loop(); + + postamble(); +} + +status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &diff_weights_md, + memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md) { + if (!mayiuse(avx512_common)) + return status::unimplemented; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper diff_weights_d(&diff_weights_md); + const memory_desc_wrapper diff_bias_d(&diff_bias_md); + const memory_desc_wrapper diff_dst_d(&diff_dst_md); + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + + jcp = zero(); + + jcp.simd_w = cpu_isa_traits::vlen / sizeof(float); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2]; + jcp.kw = diff_weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1); + bool ok = true + // general condition to simplify dilations + && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1) + && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1) + && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1) + // special condition to simplify dilations in compute_oh_loop_common + && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih); + if (!ok) + return status::unimplemented; + + jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h + + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1)); + jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1)); + + /* XXX: currently, does not support dilation_d > 0 */ + if (ndims == 5) + if (jcp.dilate_d > 0) + return status::unimplemented; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.aligned_threads = 0; + + /* check for the 1st convolution */ + jcp.is_1stconv = is_1stconv(jcp); + + jcp.oc_block = jcp.simd_w; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + + if (ok_to_pad_channels) + jcp.oc = rnd_up(jcp.oc, jcp.simd_w); + + if (jcp.oc % jcp.oc_block) + return status::unimplemented; + + auto dst_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o) + : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o); + + if (diff_dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag)); + jcp.dst_tag = dst_tag; + } else { + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dst_tag); + } + if (jcp.dst_tag != dst_tag) + return status::unimplemented; + + /* conditions on bias memory */ + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + if (jcp.with_bias) { + if (diff_bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md, x)); + } + + jcp.nb_oc = jcp.oc / jcp.oc_block; + + /* kernel applicability check wrt boundaries + * the conditions are quite general across the kernels we have, + * but ideally the check should belong to a specific kernel... */ + const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2; + const bool boundaries_ok = true + && jcp.t_pad <= max_pad + && jcp.b_pad <= max_pad + && IMPLICATION(jcp.f_pad > 0, jcp.kd < jcp.id + jcp.f_pad) + && jcp.f_pad < jcp.kd; + if (!boundaries_ok) + return status::unimplemented; + + /* yet another common check */ + if (jcp.kw > 14) + return status::unimplemented; + + /* setting register strategy */ + for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) { + if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; } + } + + if (jcp.is_1stconv) { + auto src_tag = pick(ndims - 3, ncw, nchw, ncdhw); + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + if (jcp.ic == 1 && jcp.src_tag != src_tag) + jcp.src_tag = src_d.matches_one_of_tag( + pick(ndims - 3, nwc, nhwc, ndhwc)); + } + if (jcp.src_tag == format_tag::undef) + return status::unimplemented; + + const bool src_ok = true + && utils::everyone_is(data_type::f32, + src_d.data_type(), diff_weights_d.data_type(), + diff_dst_d.data_type()) + && one_of(jcp.ic, 1, 2, 3) + && jcp.ngroups == 1; + if (!src_ok) + return status::unimplemented; + + const int tr_ld = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad, + jcp.stride_w), 16); + const int kh_step = nstl::max((28 - jcp.with_bias) / jcp.kw, 1); + const int kh_step_rem = jcp.kh % kh_step; + + const auto wei_4fma_tag = with_groups + ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o) + : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o); + + auto current_wei_tag = format_tag::undef; + if (diff_weights_d.format_kind() != format_kind::any) + current_wei_tag = diff_weights_d.matches_one_of_tag(wei_4fma_tag); + + const bool use_4fma = true + && one_of(ndims, 3, 4) + && mayiuse(avx512_mic_4ops) + && mkldnn_thr_syncable() + && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) + && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad) + && jcp.kw <= 28 - jcp.with_bias + && jcp.stride_w == 4 + && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */ + && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */ + && IMPLICATION(diff_weights_d.format_kind() != format_kind::any, + current_wei_tag == wei_4fma_tag); + + if (use_4fma) { + jcp.ver = ver_4fma; + jcp.kh_step = kh_step; + jcp.tr_ld = tr_ld; + jcp.ic_block = 1; + if (diff_weights_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_4fma_tag)); + jcp.wei_tag = wei_4fma_tag; + } else { + jcp.ver = ver_fma; + jcp.ic_block = jcp.ic; + + wei_tag = with_groups + ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o) + : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o); + + if (diff_weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + } + + jcp.nb_ic = jcp.ic / jcp.ic_block; + } else { + auto src_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + } + if (jcp.src_tag != src_tag) + return status::unimplemented; + + if (diff_weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + jcp.ic_block = jcp.simd_w; + if (ok_to_pad_channels) + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + jcp.nb_ic = jcp.ic / jcp.ic_block; + if ((mayiuse(avx512_mic) || mayiuse(avx512_core)) + && utils::everyone_is(data_type::f32, + src_d.data_type(), diff_weights_d.data_type(), + diff_dst_d.data_type())) { + jcp.ver = ver_fma; + if (one_of(ndims, 3, 4) && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 && + everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) && + mkldnn_thr_syncable()) { + jcp.ver = ver_4fma; + } + } else { + return status::unimplemented; + } + if (jcp.ver == ver_4fma) { + jcp.ur_w = jcp.ow; + // XXX, BUGBUGBUG, but not a FIXME: this assumes that it's OK to + // cross the right boundary. The only requirement is not to have + // NaNs there because another multiplicand is always guaranteed to + // be zero. This also may require the top-level driver to allocate + // four extra guarding elements at the very end of the buffer. + // I'm not proud of this hack, but it improves performance by + // about 5-10% depending on the dimensions (Roma) + + const int tr_round = 4; + + jcp.tr_iw = rnd_up(jcp.iw + jcp.kw - 1, tr_round); + jcp.tr_src_num_guard_elems = tr_round; // upper bound + } + } + + if (utils::one_of(jcp.ver, ver_4fma, ver_fma)) { + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + } else + return status::unimplemented; + + bool args_ok = true + && jcp.ic % jcp.ic_block == 0 + && jcp.oc % jcp.oc_block == 0 + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + { // balancing + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b); + jcp.nthr = nthr; + jcp.nthr_mb = nthr_mb; + jcp.nthr_g = nthr_g; + jcp.nthr_oc_b = nthr_oc_b; + jcp.nthr_ic_b = nthr_ic_b; + } + + return status::success; +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.ver == ver_4fma) { + if (jcp.is_1stconv) { + const size_t tr_src_size = + jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld; + scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); + } else { + // XXX: See the comment about tr_iw and guarding elements in + // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf() + const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic; + const size_t min_tr_src_size_per_thr + = jcp.ih * jcp.ic_block * jcp.tr_iw; + const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr + + jcp.tr_src_num_guard_elems; + scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); + } + + /* prepare synchronization contexts */ + if (jcp.nthr_oc_b > 1) { + const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b; + scratchpad.book(key_conv_tr_src_bctx, + sizeof(simple_barrier::ctx_t) * tr_src_bctx_size); + } + } + + if (jcp.nthr_mb > 1) { + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic + * jcp.kh * jcp.kw * jcp.kd; + const int bia_size = jcp.ngroups * jcp.oc; + const size_t wei_bia_reduction_size = wei_size + bia_size; + + scratchpad.book(key_conv_wei_bia_reduction, + jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1)); + scratchpad.book(key_conv_wei_bia_reduction_bctx, + sizeof(simple_barrier::ctx_t)); + } + + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::balance( + const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_, + int &nthr_oc_b_, int &nthr_ic_b_) +{ + nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1; + + const int max_threads = mkldnn_get_max_threads(); + + if (max_threads < j.ngroups) { + /* simplification... fortunately it doesn't hurt much */ + return; + } + + if (!mkldnn_thr_syncable() && j.ver == ver_4fma) { + // should not happen -- the driver is not ready + // for TBB-like non-synchronous threading yet + return; + } + + if (j.ver == ver_4fma && j.is_1stconv) { + nthr_g_ = 1; + nthr_oc_b_ = 1; + nthr_ic_b_ = nstl::min(j.nb_ic, max_threads); + nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb); + nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_; + return; + } + + nthr_g_ = j.ngroups; + const int nthr = max_threads / nthr_g_; + + auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + /* calculate per thread memory cost (read/write). high level optimizer + * tries to minimize memory consumption. few notes: + * (n1) unclear why, but that essentially helps first convolution... + * (n2) assuming the reduction over minibatch is always there: + * - instead of 8 it should be 5 here (write ~= 2 read): + * kernel: temporal workspace 1 write + * reduction: 1 read from workspace and 1 write to the diff_wei + * - but experiments showed 8 works better than 5 or 6... */ + + const int src_coef = j.ver == ver_4fma ? 4 : 1; + const int dst_coef = 1; + const int wei_coef = 8; + + return 0 + + src_coef + * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id + / j.stride_d / j.stride_h / j.stride_w /* (n1) */ + + dst_coef + * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od + + wei_coef /* (n2) */ + * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b) + * j.kh * j.kw * j.kd * j.ic_block * j.oc_block; + }; + + int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); + + /* step 1: find the best thread distribution with lowest memory cost */ + const int nthr_mb_max = nstl::min(nthr, j.mb * j.od); + for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); + for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); + + int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + if (mem_cost <= best_mem_cost) { + best_mem_cost = mem_cost; + nthr_mb_ = nthr_mb; + nthr_oc_b_ = nthr_oc_b; + nthr_ic_b_ = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + + if (!mayiuse(avx512_mic)) { + auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + return 1 + * div_up(j.mb, nthr_mb) + * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) + * div_up(j.nb_ic, nthr_ic_b); + }; + + /* step 2: search for a thread distribution with lower compute cost. + * the constrains: + * - memory cost cannot exceed 110% of the best found in the step 1 + * - unless compute cost is 133% lower than the current best case + * note: both constants were found empirically */ + int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); + for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); + for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); + int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + + const bool opt1 = comp_cost <= best_comp_cost + && mem_cost < 1.1 * best_mem_cost; + const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost; + + if (opt1 || opt2) { + best_comp_cost = comp_cost; + nthr_mb_ = nthr_mb; + nthr_oc_b_ = nthr_oc_b; + nthr_ic_b_ = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + } + + if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads) + nthr_mb_ = nstl::min(j.mb * j.od, max_threads); + nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_; + + assert(nthr_ <= max_threads); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1)); +} + +template struct _jit_avx512_common_conv_fwd_kernel; +template struct _jit_avx512_common_conv_fwd_kernel; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp new file mode 100644 index 0000000000..f76770797a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp @@ -0,0 +1,423 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP +#define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _jit_avx512_common_conv_fwd_kernel : public jit_generator { + + _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + generate(); + jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); + } + + ~_jit_avx512_common_conv_fwd_kernel() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel) + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker_)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + }; + + reg64_t param = abi_param1; + reg64_t reg_inp = r8; + reg64_t reg_ker = r9; + reg64_t reg_out = r10; + + reg64_t reg_inp_prf = r11; + reg64_t reg_ker_prf = r12; + reg64_t reg_out_prf = r13; + reg64_t reg_owb = r12; + + reg64_t aux_reg_inp = r14; + reg64_t aux_reg_ker = r15; + + reg64_t aux_reg_inp_prf = rsi; + reg64_t aux_reg_ker_prf = rdx; + + reg64_t reg_channel = rsi; + reg64_t reg_bias = rdx; + + reg64_t aux_reg_ker_d = r9; + reg64_t aux_reg_inp_d = rbx; + reg64_t aux_reg_inp_d_prf = r13; + reg64_t aux_reg_ker_d_prf = abi_not_param1; + reg64_t reg_ki = r10; + + reg64_t reg_kj = rax; + reg64_t reg_relu_ns = rax; + reg64_t reg_oi = rbx; + reg64_t reg_kh = abi_not_param1; + + reg64_t reg_tmp = rbp; + + reg64_t reg_ic_loop = rdx; + reg64_t reg_inp_loop = rsi; + + reg64_t reg_init_flag = r13; + reg64_t reg_bias_ptr = param; + + reg64_t aux_reg_ic = r12; + reg64_t reg_binp = rax; + reg64_t reg_bout = r11; + reg64_t aux1_reg_inp = rbx; + reg64_t aux_reg_out = abi_not_param1; + + reg64_t reg_long_offt = r11; + reg64_t reg_out_long_offt = r14; + + inline Vmm vmm_ker(int i_ic) { + assert(i_ic < 4); + return Vmm(ker_reg_base_idx + i_ic); + } + + inline Vmm vmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < ker_reg_base_idx); + return Vmm(idx); + } + + inline Vmm vmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Vmm(idx); + } + + Xbyak::Reg64 imm_addr64 = r15; + Vmm vmm_wei = Vmm(31); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + inline void prepare_output(int ur_w); + inline void store_output(int ur_w); + inline void compute_loop_fma(int ur_w, int pad_l, int pad_r); + inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r); + inline void compute_loop_4fma(int ur_w, int pad_l, int pad_r); + inline void compute_loop_4fma_1st(int ur_w, int pad_l, int pad_r); + inline void compute_loop(int ur_w, int pad_l, int pad_r); + + void generate(); + + inline size_t get_output_offset(int oi, int n_oc_block) { + return (size_t)jcp.typesize_out * ((size_t)n_oc_block * jcp.oh + * jcp.ow * jcp.od + oi) * jcp.oc_block; + } + + inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) { + size_t iw_str = !jcp.is_1stconv ? jcp.ic_block : 1; + size_t ic_str = !jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id; + return (size_t)jcp.typesize_in * ((size_t)(ki * (jcp.dilate_w + 1) + + oi * jcp.stride_w - pad_l) * iw_str + ic * ic_str); + } + + inline int get_kernel_offset(int ki,int ic,int n_oc_block,int ker_number) { + return jcp.typesize_in * jcp.oc_block + * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd + + (ic + ker_number) + ki * jcp.ic_block); + } + + inline int get_ow_start(int ki, int pad_l) { + return nstl::max(0, + utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); + } + + inline int get_ow_end(int ur_w, int ki, int pad_r) { + return ur_w - nstl::max(0, utils::div_up(pad_r + - (jcp.kw - 1 - ki) + * (jcp.dilate_w + 1), + jcp.stride_w)); + } +}; + +struct jit_avx512_common_conv_fwd_kernel { + + jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : + jit_ker(nullptr), + zmm_kernel_(nullptr), + xmm_kernel_(nullptr) { + int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block; + switch (ch_block) { + case 16: + zmm_kernel_ = + new _jit_avx512_common_conv_fwd_kernel( + ajcp, attr); + jit_ker = zmm_kernel_->jit_ker_; + return; + case 4: + xmm_kernel_ = + new _jit_avx512_common_conv_fwd_kernel( + ajcp, attr); + jit_ker = xmm_kernel_->jit_ker_; + return; + default: + assert(!"invalid channel blocking"); + } + } + + ~jit_avx512_common_conv_fwd_kernel() { + delete xmm_kernel_; + delete zmm_kernel_; + } + + enum { + typesize = sizeof(float) + }; + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_pd, + memory_desc_t &weights_pd, + memory_desc_t &dst_pd, + memory_desc_t &bias_pd, + const primitive_attr_t &attr, + int nthreads); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + void(*jit_ker)(jit_conv_call_s *); + _jit_avx512_common_conv_fwd_kernel *zmm_kernel_; + _jit_avx512_common_conv_fwd_kernel *xmm_kernel_; +}; + +struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator { + + jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + generate(); + jit_ker = (void (*)(jit_conv_call_s *))getCode(); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_data_kernel_f32) + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + }; + + reg64_t param = abi_param1; + reg64_t reg_dst = r8; + reg64_t reg_ker = r9; + reg64_t reg_src = r10; + + reg64_t reg_dst_prf = r11; + reg64_t reg_ker_prf = r12; + reg64_t reg_src_prf = r13; + + reg64_t aux_reg_dst = r14; + reg64_t aux_reg_ker = r15; + + reg64_t aux_reg_dst_prf = rsi; + reg64_t aux_reg_ker_prf = rdx; + + reg64_t aux_reg_dst_d_prf = r13; + reg64_t aux_reg_dst_d = rbx; + reg64_t aux_reg_ker_d_prf = abi_not_param1; + reg64_t aux_reg_ker_d = r9; + reg64_t reg_ki = r10; + + reg64_t reg_kj = rax; + reg64_t reg_oi = rbx; + reg64_t reg_kh = abi_not_param1; + + reg64_t reg_channel = rsi; + + reg64_t reg_tmp = rbp; + reg64_t reg_long_offt = r14; + + inline Xbyak::Zmm zmm_ker(int i_ic) { + assert(i_ic < 4); + return Xbyak::Zmm(ker_reg_base_idx + i_ic); + } + inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Xbyak::Zmm(idx); + } + inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < ker_reg_base_idx); + return Xbyak::Zmm(idx); + } + + Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); + + inline void prepare_output(int ur_w); + inline void store_output(int ur_w); + inline void compute_loop_4fma(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop_fma_core(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop(int ur_w, int l_overflow, int r_overflow); + void generate(); + + inline int get_iw_start(int ki, int l_overflow) + { + int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return res; + } + + inline int get_iw_end(int ur_w, int ki, int r_overflow) + { + if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return ur_w - res; + } +}; + +struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator { + + jit_avx512_common_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) + : jcp(ajcp) + { + generate(); + jit_ker = (void (*)(jit_conv_call_s *))getCode(); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32) + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_md, + memory_desc_t &diff_weights_md, + memory_desc_t &diff_bias_md, + memory_desc_t &diff_dst_md); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum {typesize = sizeof(float)}; + static const int max_ur_w; + + reg64_t param = abi_param1; + reg64_t reg_input = rax; + reg64_t reg_kernel = rdx; + reg64_t reg_output = rsi; + reg64_t b_ic = abi_not_param1; + reg64_t kj = r8; + reg64_t reg_kh = r9; + reg64_t reg_ur_w_trips = r10; + reg64_t reg_oj = r15; + reg64_t reg_ih_count = rbx; + reg64_t reg_tmp = r14; + reg64_t reg_long_offt = r14; + + reg64_t ki = r11; + reg64_t reg_kd_count = r12; + reg64_t reg_oi = r12; + reg64_t reg_d_index = r13; + reg64_t reg_input_d = r15; + reg64_t reg_output_d = rbx; + reg64_t aux_reg_input = r12; + reg64_t aux_reg_kernel = r13; + reg64_t reg_bias = rbx; + + inline void bias_kernel(); + inline void maybe_zero_kernel(); + inline void compute_oh_step_unroll_ow_icblock(int ic_block_step, + int max_ur_w); + inline void od_step_comeback_pointers(); + inline void oh_step_comeback_pointers(); + inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); + inline void compute_ic_block_step(int ur_w, + int pad_l, int pad_r, int ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound = false); + inline void compute_ic_block_step_fma(int ur_w, + int pad_l, int pad_r, int ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound); + inline void compute_ic_block_step_4fma(int ur_w, + int pad_l, int pad_r, int ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound); + inline void compute_oh_step_common(int ic_block_step, int max_ur_w); + inline void compute_oh_step_disp(); + inline void compute_oh_loop_common(); + inline void compute_d_loop_common(); + + inline bool compute_full_spat_loop(); + inline bool flat_4ops_compute(); + + inline void compute_loop(); + + void generate(); + + static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb, + int &nthr_g, int &nthr_oc_b, int &nthr_ic_b); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp new file mode 100644 index 0000000000..1bdcd0d6a8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp @@ -0,0 +1,1163 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +#ifndef KERNEL_SIZE_THRESHOLD +#define KERNEL_SIZE_THRESHOLD 16 +#endif + +#define MIN_REQUIRED_DIMN_REG_BLOCK 14 + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace mkldnn::impl::utils; + +unsigned int L1_cache_size = get_cache_size(1, true); +unsigned int L2_cache_size = get_cache_size(2, true); +unsigned int LLC_data_size = get_cache_size(3, false); + +// the test funtion takes jcp, the candidate and the current best. +// it returns true if the new candidate is better +int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, + int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) +{ + int best_divisor = default_best; + auto test_num + = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { + if (test(jcp, num, best_divisor)) { + best_divisor = num; + } + }; + + for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { + if (number % divisor == 0) { + test_num(jcp, divisor); + test_num(jcp, number / divisor); + } + } + + return best_divisor; +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { + if (jcp.ver == ver_4fma) + return jcp.mb >= 32; + else + return jcp.mb >= 16; +} +} + +/* assumes 512 bits registers */ +/* TODO: add support for strides */ +/* TODO: handle the prefetch distance automatically */ +typedef enum cache_t_ { L1, L2, L3 } cache_t; + +template +struct prefetcher_t { + prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr, + cache_t cache_type, size_t block_size, /* in number of elements*/ + int nb_instructions_in_block, int fma_ipc) + : cg_(generator) + , reg_base_addr_(reg_base_addr) + , cache_type_(cache_type) + , cache_block_size_(block_size) + { + nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t)); + prefetch_spread_ + = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_); + prefetch_blk_ + = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block); + + /* assumption: when fetch in Li, data is already in L(i+1) */ + int cache_latency; + switch (cache_type_) { + case L1: cache_latency = 14; break; + case L2: + case L3: + default: cache_latency = 250; break; + } + + prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_); + } + + void prefetch(int instruction_number) + { + if (instruction_number % prefetch_spread_ == 0) { + for (int i = 0; (i < prefetch_blk_) + && (prefetches_issued_ < nb_cache_lines_to_prefetch_); + i++, prefetches_issued_++) { + prefetch_inst_(cg_->EVEX_compress_addr( + reg_base_addr_, (cache_block_size_ * prefetch_distance_) + * sizeof(data_t) + + (prefetches_issued_ * 64))); + } + } + } + +private: + void prefetch_inst_(const Xbyak::Address &addr) + { + switch (cache_type_) { + case L1: cg_->prefetcht0(addr); break; + case L2: cg_->prefetcht1(addr); break; + case L3: cg_->prefetcht2(addr); break; + default: + break; // TODO: raise an exception or put an assert + } + } + + jit_generator *cg_; + Xbyak::Reg64 reg_base_addr_; + cache_t cache_type_; + int cache_block_size_ = 0; + int nb_cache_lines_to_prefetch_ = 0; + int prefetches_issued_ = 0; + int prefetch_spread_ = 0; + int prefetch_blk_ = 0; + int prefetch_distance_ = 0; +}; + +// utilities to support kernel parameter selection +bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimN_reg_block * dimM_simd_block + + dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} + +bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimK_block * dimK_reg_block * dimM_simd_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} + +bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, + int dimK_block, int dimK_reg_block, int dimM_block, int dimM_simd_block, + float C) +{ + float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block + + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block + + nb_dimN_reg_block * dimK_nb_block * dimK_block + * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L2_cache_size; + return (lhs < rhs); +} +} + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate( + bool is_beta_zero) +{ + // const int dimK_simd_block = jcp.dimK_reg_block; + + // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++) + // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++) + // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block; + // dimK_reg_block++) + // for (int tile =0; tile < jcp.dimN_reg_block; tile++) + // C[dimM_block][tile] += + // A[dimM_block][dimK_block][dimK_reg_block] * + // broadcast(B[dimK_block][tile][dimK_reg_block]); + // 1) We do register blocking on A[dimM_block][dimK_block][dimK_reg_block], + // so we load it before the loop on tile + // 2) the loop on tile must be fully unrolled. Don't know about the one on + // dimK_reg_block. I think it should be + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop; + const int inc_dimK_reg_block = jcp.ver == ver_4fma ? 4 : 1; + const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2; + + prefetcher_t L1_pf(this, reg_srcB, L1, + jcp.dimN_reg_block * jcp.dimK_reg_block, + jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block, + fma_ipc); + prefetcher_t L2_pf(this, reg_srcB, L2, + jcp.dimN_reg_block * jcp.dimK_reg_block, + jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block, + fma_ipc); + + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + { + // First, we zero the accumulators if first nb_ic iteration, + // otherwise we load them + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (is_beta_zero) + vpxord(zmm, zmm, zmm); + else + vmovups(zmm, zword[reg_dstC + 64 * tile]); + } + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + { + auto load_A = [=](int reg_idx, int offset) { + for (int i = 0; i < inc_dimK_reg_block; i++) + vmovups(Zmm(reg_idx + i), + zword[reg_srcA + 64 * (offset + i)]); + }; + + // Used when doing double buffering + int next = 0; + if (jcp.double_buffering) { + load_A(next, 0); + } + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block += inc_dimK_reg_block) { + int current; + /* Loading the next vector from A */ + current = next; + if (jcp.double_buffering) { + next = (dimK_reg_block + inc_dimK_reg_block) + % (2 * inc_dimK_reg_block); + load_A(next, dimK_reg_block + inc_dimK_reg_block); + } else { + next = 0; + load_A(next, dimK_reg_block); + } + /* Performing the fmas */ + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (jcp.ver != ver_avx512_core) + L1_pf.prefetch( + dimK_reg_block * jcp.dimN_reg_block + tile); + if (jcp.ver == ver_4fma) + v4fmaddps(zmm, Zmm(current), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4)); + else + vfmadd231ps(zmm, Zmm(current), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4, + true)); + if (jcp.ver != ver_avx512_core) + L2_pf.prefetch( + dimK_reg_block * jcp.dimN_reg_block + tile); + } + } + + add(reg_srcA, jcp.dimK_reg_block * 64); + add(reg_srcB, jcp.dimN_reg_block * 64); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + } + + + auto store_output = [=](bool output_is_aligned) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (output_is_aligned + && jcp.dimK_nb_block == 1 + && (jcp.dimN * jcp.dimM * alpha * alpha + * sizeof(float) > 2 * LLC_data_size)) + vmovntps(zword[reg_dstC + 64 * tile], zmm); + else + vmovups(zword[reg_dstC + 64 * tile], zmm); + } + }; + + Label unaligned_store, end_store; + test(reg_dstC, cpu_isa_traits::vlen - 1); + jnz(unaligned_store, T_NEAR); + store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64); + add(reg_dstC, jcp.dimN_reg_block * 64); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + } + }; + + /* Preamble */ + preamble(); + + /* kernel */ + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) +{ + + if (mayiuse(avx512_core)) + return status::unimplemented; + else if (!mayiuse(avx512_common)) + return status::unimplemented; + else if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + else + jcp.ver = ver_fma; + + jcp.nthr = mkldnn_get_max_threads(); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + // Checking conditions not supported by these kernels + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + return status::success; +} + + +status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) { + + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK) + && (dimN_reg_block < jcp.nb_reg) + && (dimN_reg_block < current_best); + }; + jcp.dimN_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block); + + if (jcp.dimN_reg_block >= jcp.nb_reg) { + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return (dimN_reg_block < jcp.nb_reg) + && (dimN_reg_block > current_best); + }; + + jcp.dimN_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimN, 1, test_cond_dimN_reg_block); + } + + //********************* Choosing dimK_block **********************// + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, + 1, jcp.dimM_simd_block, .75f) + && (dimK_block > current_best); + }; + + auto test_cond1_bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, dimK_block, + jcp.dimK_reg_block, 1, jcp.dimM_simd_block, .9f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); + // If we are not able to use streams, we fall back to condition [1] + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); + jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; + + //********************* Choosing dimM_block **********************// + jcp.dimM_simd_block = 16; + /*XXX: Why C=0.5 here but C=0.75 for dimK_block?*/ + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .5f) + && (dimM_block > current_best); + }; + + auto test_cond1_bis_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .3f) + && (dimM_block > current_best); + }; + + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); + else + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_bis_dimM_block); + jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; + + //******************* Choosing dimN_block *******************// + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, + jcp.dimM_simd_block, .5f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); + jcp.sched_policy = WSCHED_DATA_W_S_G_D; + return status::success; +} + +status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) +{ + jcp.dimK_reg_block = 16; + jcp.dimM_simd_block = 16; + + // TODO: replace double buffering with nuple buffering to maximize register + // usage. + // the choice of the number of buffers will then come after choosing + // dimN_reg_block + jcp.double_buffering = true; + if (jcp.double_buffering) + jcp.zmm_start = 2 * ((jcp.ver == ver_4fma) ? 4 : 2); + else + jcp.zmm_start = 1; + jcp.nb_reg = 32 - jcp.zmm_start; + + jcp.dimN = dimN; + jcp.dimK = dimK; + jcp.dimM = dimM; + + jcp.sched_policy = WSCHED_INVALID; + set_wsched_DATA_W_S_G_D_avx512_common(jcp); + + assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D); + return status::success; +} + +bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_relu(0) || is_sum(0); // relu or sum + case 2: return (is_sum(0) && is_relu(1)) || + (is_relu(0) && is_sum(1)); // sum->relu or relu->sum + case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu + default: return false; + } + + return false; +} + +status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) { + status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d); + + if (st != status::success) + return st; + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise; + jcp.with_sum = p.find(primitive_kind::sum, 0) != -1; + + status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic); + jcp.ic_simd_block = jcp.dimK_reg_block; + jcp.ic_block = jcp.dimK_block; + jcp.nb_ic = jcp.dimK_nb_block; + jcp.oc_simd_block = jcp.dimM_simd_block; + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + jcp.tile_4fma_padding = 0; // only relevant for backward weights + + return res; +} + +status_t jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); + + if (st != status::success) + return st; + + jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); + jcp.oc_simd_block = jcp.dimK_reg_block; + jcp.oc_block = jcp.dimK_block; + jcp.nb_oc = jcp.dimK_nb_block; + jcp.ic_simd_block = jcp.dimM_simd_block; + jcp.ic_block = jcp.dimM_block; + jcp.nb_ic = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + jcp.tile_4fma_padding = 0; // only relevant for backward weights + + return res; +} + +void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::transpose_ker_generate() +{ + auto load_B = [=](int reg_idx, int offset) { + for (int i = 0; i < 4; i++) { + vmovups(Zmm(reg_idx + i), zword[reg_origB + (offset + i) * jcp.dimN_reg_block * sizeof(float)]); + } + }; + + preamble(); + int curr = 0; + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + int origB_offset = (j * alpha + i) * jcp.dimK_4fma; + size_t transB_offset = (size_t)(j * alpha + i) * jcp.dimK_nb_block * + jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block * + jcp.dimK_4fma * jcp.dimN_reg_block * sizeof(float); + mov(reg_transB_idx, transB_offset); + for (int tb = 0; tb < jcp.dimK_4fma; tb+=4) { + /*double buffering to hide load latencies*/ + int next = (curr + 4) % 8; + if (i == 0 && tb == 0) { + load_B(0, origB_offset); + } + if (tb + 4 < (jcp.dimK_4fma -1)) { + load_B(next, origB_offset + 4); + } else if (i < alpha - 1) { + load_B(next, origB_offset + jcp.dimK_4fma); + } + + vunpcklps(Zmm(8), Zmm(curr), Zmm(curr + 1)); + vunpcklps(Zmm(9), Zmm(curr + 2), Zmm(curr + 3)); + vunpckhps(Zmm(curr), Zmm(curr), Zmm(curr + 1)); + vunpckhps(Zmm(curr + 1), Zmm(curr + 2), Zmm(curr + 3)); + + vunpcklpd(Zmm(curr + 2), Zmm(8), Zmm(9)); + vunpckhpd(Zmm(curr + 3), Zmm(8), Zmm(9)); + + vunpcklpd(Zmm(8), Zmm(curr), Zmm(curr + 1)); + vunpckhpd(Zmm(9), Zmm(curr), Zmm(curr + 1)); + + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * tb * jcp.dimN_reg_block], + Zmm(curr+2)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 1) * jcp.dimN_reg_block], + Zmm(curr+3)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 2) * jcp.dimN_reg_block], + Zmm(8)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 3) * jcp.dimN_reg_block], + Zmm(9)); + curr = next; + + } + } + } + postamble(); + ret(); +} +void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate( + bool is_first_tile) +{ + // for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) + // for (int ifm2 = 0; ifm2 < jcp.ic_block; ifm2++) + // for (int nb_tile_block_ur = 0; nb_tile_block_ur < + // jcp.nb_tile_block_ur; nb_tile_block_ur++) + // for (int tile_block_ur = 0; tile_block_ur < + // jcp.tile_block_ur; tile_block_ur++) + // for (int ifm3 = 0; ifm3 < jcp.ic_reg_block; ++ifm3) + // U[ofm2][ifm2][ofm3][ifm3][0:oc_simd_block] += + // M[ofm2][ofm3][nb_tile_block_ur][tile_block_ur][0:oc_simd_block] + // * + // broadcast(V[ifm2][nb_tile_block_ur][ifm3][tile_block_ur]) + auto inner_loops = [=]() { + int inc_fma = jcp.ver == ver_4fma ? 4 : 1; + const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2; + prefetcher_t L1_pf(this, reg_srcB, L1, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma + / inc_fma, + fma_ipc); + prefetcher_t L2_pf(this, reg_srcB, L2, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma + / inc_fma, + fma_ipc); + + auto load_A = [=](int reg_idx, int offset) { + for (int i = 0; i < inc_fma; i++) { + vmovups(Zmm(reg_idx + i), + zword[reg_srcA + + sizeof(float) * jcp.dimM_simd_block * (offset + i)]); + } + }; + + Label dimM_block_loop, dimK_block_loop, dimN_block_loop; + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + { /************* OC_block (M) loop ***********/ + if (jcp.dimN_block > 1) { + mov(reg_dimN_block_loop_cnt, jcp.dimN_block); + L(dimN_block_loop); + } + { /*************** IC_block (N) loop *********/ + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) { + Zmm zmm(jcp.zmm_start + dimN_reg_block); + if (is_first_tile) + vpxord(zmm, zmm, zmm); + else + vmovups(zmm, zword[reg_dstC + + dimN_reg_block * jcp.dimM_simd_block * + sizeof(float)]); + } + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + { /************* nb_tile_ur(K) loop ********/ + int next = 0; + if (jcp.double_buffering) { + load_A(next, 0); + } + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block++) { + int srcB_offset = dimK_reg_block * jcp.dimK_4fma + * jcp.dimN_reg_block; + for (int dimK_4fma = 0; dimK_4fma < jcp.dimK_4fma; + dimK_4fma += inc_fma) { + int current = next; + if (jcp.double_buffering) { + next = (dimK_reg_block * jcp.dimK_4fma + + dimK_4fma + inc_fma) + % (2 * inc_fma); + load_A(next, dimK_reg_block * jcp.dimK_4fma + + dimK_4fma + inc_fma); + } else { + next = 0; + load_A(next, dimK_reg_block * jcp.dimK_4fma + + dimK_4fma); + } + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; + ++dimN_reg_block) { + L1_pf.prefetch(srcB_offset / inc_fma + + dimK_4fma / inc_fma + * jcp.dimN_reg_block + + dimN_reg_block); + L2_pf.prefetch(srcB_offset / inc_fma + + dimK_4fma / inc_fma + * jcp.dimN_reg_block + + dimN_reg_block); + if (jcp.ver == ver_4fma) { + int srcB_trans_offset = (dimK_4fma / 4) * 64 + + dimK_4fma % 4; + v4fmaddps( + Zmm(jcp.zmm_start + dimN_reg_block), + Zmm(current), + EVEX_compress_addr(reg_srcB, + sizeof(float) * ( + srcB_offset + + srcB_trans_offset + + (dimN_reg_block % 4) * 16 + + (dimN_reg_block / 4) * 4))); + } else { + vfmadd231ps( + Zmm(jcp.zmm_start + dimN_reg_block), + Zmm(current), + EVEX_compress_addr(reg_srcB, + sizeof(float) * (srcB_offset + dimN_reg_block), + true)); + } + } + } + } + } + + add(reg_srcA, jcp.dimK_reg_block * jcp.dimK_4fma + * jcp.dimM_simd_block * sizeof(float)); + add(reg_srcB, jcp.dimK_reg_block * jcp.dimN_reg_block + * jcp.dimK_4fma * sizeof(float)); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + /******** Write C back to memory *******/ + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) { + Zmm zmm(jcp.zmm_start + dimN_reg_block); + vmovups(zword[reg_dstC + + dimN_reg_block * jcp.dimM_simd_block * sizeof(float)], + zmm); + } + + sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block * + jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float)); + add(reg_dstC, jcp.dimN_reg_block * jcp.dimM_simd_block + * sizeof(float)); + if (jcp.dimN_block > 1) { + sub(reg_dimN_block_loop_cnt, 1); + jnz(dimN_block_loop); + } + } + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimN_block * jcp.dimK_block + * jcp.dimK_reg_block * jcp.dimN_reg_block + * jcp.dimK_4fma * sizeof(float)); + add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float)); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + } + }; + + /* Preamble */ + // register used to handle long fma encoding + preamble(); + mov(reg_srcA, reg_srcA_const); + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +namespace { +bool check_cond1_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) +{ + float lhs = 1.0f * dimM_block * dimN_reg_block * dimM_simdw; + lhs += dimM_block * dimK_block * dimK_reg_block * dimK_4fma * dimM_simdw; + lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; + lhs *= sizeof(float); + float rhs = C * L1_cache_size; + return (lhs <= rhs); +} + +bool check_cond1bis_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) +{ + float lhs = 1.0f * dimM_block * dimK_block * dimK_reg_block * dimK_4fma + * dimM_simdw; + lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; + lhs *= sizeof(float); + float rhs = C * L1_cache_size; + return (lhs <= rhs); +} + +bool check_cond2bis_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, + float C) +{ + float lhs = 1.0f * dimM_block * dimM_simdw * dimK_block * dimK_reg_block + * dimK_4fma; + lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block + * dimN_reg_block; + lhs *= sizeof(float); + float rhs = C * L2_cache_size; + return (lhs <= rhs); +} + +bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, + float C) +{ + float lhs = 1.0f * dimM_block * dimM_simdw * dimN_block * dimN_reg_block; + lhs += dimM_block * dimM_simdw * dimK_block * dimK_reg_block * dimK_4fma; + lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block + * dimN_reg_block; + lhs *= sizeof(float); + float rhs = C * L2_cache_size; + return (lhs <= rhs); +} +} // namespace + +status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp) +{ + /*************** Choose dimN_reg_block (ic_simd_block) + * *******************************/ + jcp.dimN = jcp.ic; + /*Hardcoded to 16 because N = ic for bwd weights and + innermost dimension for ic is assumed 16 in src transforms. This + choice covers load latencies while maintaining simplicity of kernel + for POR topologies. FIXME in future??: Will not work for future topologies + when ic%16 != 0*/ + jcp.dimN_reg_block = jcp.ic_simd_block; + + /****************************** Choose dimK_block + * **************************/ + // No freedom for choosing dimM_simd_block because ic_simd_block + // is determined by input data format + jcp.dimM_simd_block = jcp.oc_simd_block; + + auto test_cond1bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) + && (dimK_block > current_best); + }; + + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) + && (dimK_block > current_best); + }; + + auto test_cond2bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond2bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.5f) + && (dimK_block > current_best); + }; + + auto test_cond2_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond2_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.1f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2bis_dimK_block); + if (jcp.dimK_block < jcp.dimK / jcp.dimK_4fma) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2_dimK_block); + + jcp.dimK_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimK_block, 1, test_cond1bis_dimK_block); + if (jcp.dimK_reg_block < jcp.dimK_block) { + jcp.dimK_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimK_block, 1, test_cond1_dimK_block); + } + jcp.dimK_block /= jcp.dimK_reg_block; + jcp.dimK_nb_block + = jcp.dimK / jcp.dimK_4fma / jcp.dimK_reg_block / jcp.dimK_block; + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; + + /***************************** Chose dimN block + * ****************************/ + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2_wu(1, jcp.dimM_simd_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimK_4fma, dimN_block, + jcp.dimN_reg_block, 0.5f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.ic_block = jcp.dimN_block; + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_reg_block / jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + /********************************* Choose dimM block + * ************************/ + jcp.dimM = jcp.oc; + + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_wu(dimM_block, jcp.dimM_simd_block, 1, + jcp.dimK_reg_block, jcp.dimK_4fma, jcp.dimN_reg_block, + 1.0f) + && (dimM_block > current_best) + && (jcp.dimM / jcp.dimM_simd_block / dimM_block) >= 2; + }; + + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); + jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; + + jcp.sched_policy = WSCHED_WEI_S_D_G_W; + return status::success; +} + +status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d) +{ + jcp.nthr = mkldnn_get_max_threads(); + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + jcp.kh = diff_weights_d.dims()[with_groups + 2]; + jcp.kw = diff_weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef); + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + if (mayiuse(avx512_core)) + return status::unimplemented; + if (!mayiuse(avx512_common)) + return status::unimplemented; + else if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + else + jcp.ver = ver_fma; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + // Winograd kernel works only for 3x3 convolution with stride 1 + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + /*************************** New Kernel Parameters + * *****************************/ + jcp.ic_simd_block = simd_w; + jcp.oc_simd_block = simd_w; + jcp.dimK_4fma = 1; + jcp.tile_4fma_padding = 0; + +#define MAX_4FMA_UR 8 + if (jcp.ver == ver_4fma) { + auto test_cond_4fma = [](jit_conv_winograd_conf_t &jcp, int dimK_4fma, + int current_best) { + return (dimK_4fma % 4 == 0) && (dimK_4fma <= MAX_4FMA_UR) + && (dimK_4fma > current_best); + }; + jcp.dimK_4fma = get_divisor_satisfying_cond( + jcp, jcp.itiles * jcp.jtiles, 4, test_cond_4fma); + if (jcp.dimK_4fma == 1) + jcp.dimK_4fma = 4; + if ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma != 0) + jcp.tile_4fma_padding = jcp.dimK_4fma + - ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma); + } + + jcp.tile_4fma = jcp.dimK_4fma; + /*NOTE: When (itiles * jtiles) % dimK_4fma != 0, transpose in diff_src + * transform + * will not work correctly, this is solved by applying padding.*/ + jcp.dimK = jcp.mb * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + jcp.dimN = jcp.ic; + jcp.dimM = jcp.oc; + + jcp.double_buffering = true; + if (jcp.double_buffering) + jcp.zmm_start = jcp.ver == ver_4fma ? 8 : 2; + else + jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1; + jcp.nb_reg = 32 - jcp.zmm_start; + + jcp.sched_policy = WSCHED_INVALID; + status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp); + assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W); + + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; + + jcp.ic_block = jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + + return res; + +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp new file mode 100644 index 0000000000..6c117143f5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp @@ -0,0 +1,179 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP +#define JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +//alpha determines the output tile_size +constexpr int alpha = 6; +constexpr int tile_size = 4; +//simd length used for vectorization +constexpr int simd_w = 16; + +struct _jit_avx512_common_conv_winograd_data_kernel_f32 : public jit_generator { + _jit_avx512_common_conv_winograd_data_kernel_f32( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + //******************* First iter kernel ********************// + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter + = (decltype(gemm_loop_ker_first_iter)) this->getCode(); + + //************** Subsequent iterations kernel **************// + if (jcp.dimK_nb_block > 1) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_data_kernel_f32) + + static status_t init_conf_common(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d); + + static status_t init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + +protected: + using reg64_t = const Xbyak::Reg64; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(bool is_beta_zero); + + /* registers used for GEMM */ + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r11; +}; + +struct jit_avx512_common_conv_winograd_fwd_kernel_f32 + : _jit_avx512_common_conv_winograd_data_kernel_f32 { + using _jit_avx512_common_conv_winograd_data_kernel_f32:: + _jit_avx512_common_conv_winograd_data_kernel_f32; + + static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); +}; + +struct jit_avx512_common_conv_winograd_bwd_data_kernel_f32 + : public _jit_avx512_common_conv_winograd_data_kernel_f32 { + using _jit_avx512_common_conv_winograd_data_kernel_f32:: + _jit_avx512_common_conv_winograd_data_kernel_f32; + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); +}; + +struct jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 + : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_bwd_weights_kernel_f32) + + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + + //******************* First iter kernel ********************// + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))addr; + } + + if (jcp.tile_block > 1) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + + if (jcp.ver == ver_4fma) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->transpose_ker_generate(); + transpose_4fma_ker = (decltype(transpose_4fma_ker))addr; + } + } + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + void (*transpose_4fma_ker)(float *, float *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(bool is_first_tile); + void transpose_ker_generate(); + + reg64_t reg_origB = abi_param2; + reg64_t reg_transB = abi_param1; + + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA_const = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_sp = rsp; + reg64_t reg_srcA = r9; + reg64_t reg_nb_ic = r10; + reg64_t reg_loop_cpt = r11; + reg64_t reg_transB_idx = r13; + + /* Registers used by new kernel */ + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r12; + reg64_t reg_dimN_block_loop_cnt = r11; +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp new file mode 100644 index 0000000000..abddc19221 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp @@ -0,0 +1,1526 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace nstl; + +using jit_conv_ker_t = void (*)(jit_conv_call_s *); + +#define PIPELINE(field) \ + do { \ + p.field = p.field ## _prf; \ + p.field ## _prf = field; \ + } while (0) + +inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + + if (p.src) + ker(&p); +} +// The special case for the driver with ow-parallelization (FWD) +// TODO: implement it for BWD_D and BWD_W too +inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding, int owb) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(owb); + + if (p.src) + ker(&p); +} + +inline void jit_conv_3d_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding, int kd_padding) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(kd_padding); + + if (p.src) + ker(&p); +} +// The special case for the driver with ow-parallelization (FWD) +// TODO: implement it for BWD_D and BWD_W too +inline void jit_conv_3d_ker_pipeline_ow_thr(jit_conv_ker_t ker, + jit_conv_call_s &p, const void *src, const void *dst, const void *filt, + const void *bias, int channel, int kh_padding, int kd_padding, int owb) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(kd_padding); + PIPELINE(owb); + + if (p.src) + ker(&p); +} + +void jit_conv_3d_ker_bwd_w_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int d_index, int d_worksize, + int kd_padding /* kd_work_size */, size_t kd_offset) { + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kd_padding); + PIPELINE(d_worksize); + PIPELINE(d_index); + PIPELINE(kd_offset); + + if (p.src) + ker(&p); +} +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? (d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +template +void jit_avx512_common_convolution_fwd_t::prepare_padded_bias(const dst_data_t *&bias, + const memory_tracking::grantor_t &scratchpad) const { + if (!pd()->wants_padded_bias()) return; + + auto padded_bias = scratchpad.template get( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding); + utils::array_set(padded_bias + pd()->jcp_.oc_without_padding, + (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding); + bias = padded_bias; +} + +template +void jit_avx512_common_convolution_fwd_t:: +execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.nb_ow; + + int nthr; + if (jcp.aligned_threads) + nthr = jcp.aligned_threads; + else + nthr = mkldnn_get_max_threads(); + + parallel(nthr, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) { + int dummy{0}; + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gncw) { + int dummy{0}; + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, occ, + oc_chunks, owb, jcp.nb_ow, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + auto bias_w = bias ? bias + g_oc : nullptr; + auto dst_w = dst + dst_d.blk_off(n, g_ocb, ow_s); + auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2); + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) { + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src_w, dst_w, wht_w, bias_w, icb, 1, owb); + + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + if (jcp.loop_order == loop_cwgn) { + int dummy{0}; + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gncw) { + int dummy{0}; + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, + occ, oc_chunks, owb, jcp.nb_ow, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + } + } + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template +void jit_avx512_common_convolution_fwd_t:: +execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh * jcp.nb_ow; + + int nthr; + if (jcp.aligned_threads) + nthr = jcp.aligned_threads; + else + nthr = mkldnn_get_max_threads(); + + parallel(nthr, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, oh_s{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, + occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int work_rem = end - start; + + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + auto bias_w = bias ? bias + g_oc : nullptr; + + for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) { + int ih_b = -jcp.t_pad + oh_b * jcp.stride_h; + + auto dst_w = dst + dst_d.blk_off(n, g_ocb, oh_b, ow_s); + auto src_w + = src + src_d.blk_off(n, g_icb + icb_l2, ih_b, iw_s); + auto wht_w + = weights + wht_blk_off(weights_d, g, ocb, icb_l2); + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); + ++icb) { + auto src_c = src_w; + auto dst_c = dst_w; + for (int oj = oh_b, ij = ih_b; + oj < min(oh_e, oh_b + jcp.h_blocking); + ++oj, ij += jcp.stride_h) { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = div_up(max(0, -ij), dilate_h); + int i_b_overflow = div_up(max(0, ij - jcp.ih + + (jcp.kh - 1) * dilate_h + 1), dilate_h); + int kh_padding = nstl::max( + 0, jcp.kh - i_t_overflow - i_b_overflow); + + auto aux_src = src_c + + i_t_overflow * dilate_h * src_h_stride; + auto aux_wht = wht_w + i_t_overflow * wht_h_stride; + + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, + par_conv, aux_src, dst_c, aux_wht, bias_w, icb, + kh_padding, owb); + + src_c += src_h_stride * jcp.stride_h; + dst_c += dst_h_stride; + } + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + } + + if (jcp.loop_order == loop_cwgn) + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, occ, + oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template +void jit_avx512_common_convolution_fwd_t:: +execute_forward_3d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + parallel(0, [&](const int ithr, const int nthr) { + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int start{0}, end{0}, start_copy; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.od * jcp.oh + * jcp.nb_ow; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_d_stride = src_d.blk_off(0, 0, 1); + size_t src_h_stride = src_d.blk_off(0, 0, 0, 1); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1); + size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, oh_s{0}, od_s{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) + nd_iterator_init(start, + occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb, + od_s, jcp.od, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow, + od_s, jcp.od, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int work_rem = end - start; + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + + int id_s = -jcp.f_pad + od_s * jcp.stride_d; + + int dilate_d = jcp.dilate_d + 1; + int d_t_overflow = div_up(max(0, -id_s), dilate_d); + int d_b_overflow = div_up( + max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1), + dilate_d); + int kd_padding = nstl::max(0, + jcp.kd - d_t_overflow - d_b_overflow); + + auto bias_w = bias ? bias + bias_d.blk_off(g_oc) : 0; + auto dst_w = dst + dst_d.blk_off(n, g_ocb, od_s, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, id_s, ih_s, + iw_s) + d_t_overflow * dilate_d * src_d_stride; + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2) + + d_t_overflow * wht_d_stride; + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) { + auto src_c = src_w; + auto dst_c = dst_w; + for (int oj = oh_s, ij = ih_s; + oj < oh_e; ++oj, ij += jcp.stride_h) + { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = div_up(max(0, -ij), dilate_h); + int i_b_overflow = div_up( + max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + + 1), + dilate_h); + int kh_padding = nstl::max(0, + jcp.kh - i_t_overflow - i_b_overflow); + jit_conv_3d_ker_pipeline_ow_thr(kernel_->jit_ker, + par_conv, + src_c + i_t_overflow * dilate_h * src_h_stride, + dst_c, wht_w + i_t_overflow * wht_h_stride, + bias_w, icb, kh_padding, kd_padding, owb); + + src_c += src_h_stride * jcp.stride_h; + dst_c += dst_h_stride; + } + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + + if (jcp.loop_order == loop_cwgn) + nd_iterator_jump(start, end, + occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb, + od_s, jcp.od, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow, + od_s, jcp.od, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + } + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template struct jit_avx512_common_convolution_fwd_t; + +template +void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_1d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}; + if (jcp.loop_order == loop_cgn) { + int dummy{0}; + nd_iterator_init(start, icc, ic_chunks, g, jcp.ngroups, n, + jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gnc) { + int dummy{0}; + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, icc, + ic_chunks, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb); + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb); + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w, diff_dst_w, wht_w, 0, ocb, 1); + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) { + int dummy{0}; + nd_iterator_jump(start, end, icc, ic_chunks, g, jcp.ngroups, + n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gnc) { + int dummy{0}; + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, icc, + ic_chunks, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + } + } + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1); + }); +} + +template +void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_2d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1); + size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1; + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}, ih_s{0}; + if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + int work_rem = end - start; + int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem; + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb); + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb); + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + for (int ij = ih_s; ij < ih_e; ++ij) { + int oj, k_len, k_lo; + if (is_fast_path) { // dilate == 0 && stride == 1 + int i_t_overflow = max(0, jcp.kh - 1 - ij + - jcp.t_pad); + int i_b_overflow = max(0, jcp.kh - jcp.ih + ij + - jcp.b_pad); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow; + } else if (jcp.dilate_h != 0) { // stride == 1 + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int i_t_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + - ij - jcp.t_pad), dilate_h); + int i_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 + - jcp.ih + ij - jcp.b_pad), dilate_h); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow * dilate_h; + } else { // dilate == 0 + int i_t_overflow = max(0, (jcp.kh - 1 - ij + - jcp.t_pad) / jcp.stride_h); + int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij + - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ij) % jcp.stride_h); + int overflow_kh_lo = (ij + jcp.t_pad) + % jcp.stride_h; + + k_len = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow + - i_b_overflow; + k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h; + oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h; + } + assert(k_len >= 0); + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w + ij * diff_src_h_stride, + diff_dst_w + oj * diff_dst_h_stride, + wht_w + k_lo * wht_h_stride, + 0, ocb, k_len); + } + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1); + }); +} + +template +void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_3d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.id * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1); + size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1); + size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1); + size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); + size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1; + bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1; + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}, ih_s{0}, id_s{0}; + if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id, + ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id, + ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + int work_rem = end - start; + int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem; + int d_len = 0, d_lo = 0, d_oj = 0; + if (is_fast_path_d) { // dilate == 0 && stride == 1 + int d_t_overflow = max(0, jcp.kd - 1 - id_s + - jcp.f_pad); + int d_b_overflow = max(0, jcp.kd - jcp.id + id_s + - jcp.back_pad); + d_len = jcp.kd - d_t_overflow - d_b_overflow; + d_lo = d_b_overflow; + d_oj = id_s + jcp.f_pad - d_b_overflow; + } else if (jcp.dilate_d != 0) { // stride == 1 + int dilate_d = jcp.dilate_d + 1; + // Note: use div_up to account for "holes" in filter + int d_t_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + - id_s - jcp.f_pad), dilate_d); + int d_b_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + 1 + - jcp.id + id_s - jcp.back_pad), dilate_d); + d_len = jcp.kd - d_t_overflow - d_b_overflow; + d_lo = d_b_overflow; + d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d; + } else { // dilate == 0 + int d_t_overflow = max(0, (jcp.kd - 1 - id_s + - jcp.f_pad) / jcp.stride_d); + int d_b_overflow = max(0, (jcp.kd - jcp.id + id_s + - jcp.back_pad) / jcp.stride_d); + int overflow_kd_hi = jcp.kd - 1 - abs((jcp.id - 1 + + jcp.back_pad - id_s) % jcp.stride_d); + int overflow_kd_lo = (id_s + jcp.f_pad) + % jcp.stride_d; + + d_len = (overflow_kd_hi - overflow_kd_lo) + / jcp.stride_d + 1 - d_t_overflow + - d_b_overflow; + d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d; + d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d; + } + assert(d_len >= 0); + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb) + + id_s * diff_src_d_stride; + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2) + + d_oj * diff_dst_d_stride; + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb) + + d_lo * wht_d_stride; + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + for (int ij = ih_s; ij < ih_e; ++ij) { + int oj, k_len, k_lo; + if (is_fast_path_h) { // dilate == 0 && stride == 1 + int i_t_overflow = max(0, jcp.kh - 1 - ij + - jcp.t_pad); + int i_b_overflow = max(0, jcp.kh - jcp.ih + ij + - jcp.b_pad); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow; + } else if (jcp.dilate_h != 0) { // stride == 1 + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int i_t_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + - ij - jcp.t_pad), dilate_h); + int i_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 + - jcp.ih + ij - jcp.b_pad), dilate_h); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow * dilate_h; + } else { // dilate == 0 + int i_t_overflow = max(0, (jcp.kh - 1 - ij + - jcp.t_pad) / jcp.stride_h); + int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij + - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ij) % jcp.stride_h); + int overflow_kh_lo = (ij + jcp.t_pad) + % jcp.stride_h; + + k_len = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow + - i_b_overflow; + k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h; + oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h; + } + assert(k_len >= 0); + + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w + ij * diff_src_h_stride, + diff_dst_w + oj * diff_dst_h_stride, + wht_w + k_lo * wht_h_stride, + 0, ocb, k_len, d_len); + } + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id, + ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id, + ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1, 1); + }); +} + +template struct jit_avx512_common_convolution_bwd_data_t; + +template +jit_avx512_common_convolution_bwd_weights_t:: +jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) + , trans_kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr) +{ + const auto &j = pd()->jcp_; + + nthr_ = j.nthr; + nthr_mb_ = j.nthr_mb; + nthr_g_ = j.nthr_g; + nthr_oc_b_ = j.nthr_oc_b; + nthr_ic_b_ = j.nthr_ic_b; + + kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j); + + if (j.ver == ver_4fma) + trans_kernel_ = create_trans_src(&j); + + if (nthr_mb_ > 1) + acc_ker_ = new cpu_accumulator_1d_t(); + + reducer_bias_ = + new cpu_reducer_t(pd()->reducer_bia_conf_); +} + +template +struct jit_avx512_common_convolution_bwd_weights_t::thread_info_t { + const src_data_t *src; + const diff_dst_data_t *diff_dst; + const diff_weights_data_t *diff_weights; + diff_weights_data_t *diff_bias; + + const memory_tracking::grantor_t scratchpad; + + src_data_t *tr_src; + simple_barrier::ctx_t *tr_src_bctx; + + diff_dst_data_t *tr_diff_dst; + simple_barrier::ctx_t *tr_diff_dst_bctx; + + diff_weights_data_t *wei_bia_reduction; + simple_barrier::ctx_t *wei_bia_reduction_bctx; + + int ithr; + int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb; + int ithr_but_oc; + int ithr_but_ic; + + int img_start = 0, img_end = 0, img_work; + int g_start = 0, g_end = 0, g_work; + int oc_b_start = 0, oc_b_end = 0, oc_b_work; + int ic_b_start = 0, ic_b_end = 0, ic_b_work; + + thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self, + const exec_ctx_t &ctx, int ithr) + : scratchpad(self->scratchpad(ctx)), ithr(ithr) + { + diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + diff_weights = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + diff_bias = self->pd()->wants_padded_bias() + ? scratchpad.template get( + key_conv_padded_bias) + : CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS); + + tr_src = scratchpad.template get(key_conv_tr_src); + tr_src_bctx = scratchpad.template get( + key_conv_tr_src_bctx); + + tr_diff_dst = scratchpad.template get( + key_conv_tr_diff_dst); + tr_diff_dst_bctx = scratchpad.template get( + key_conv_tr_diff_dst_bctx); + + wei_bia_reduction = scratchpad.template get( + key_conv_wei_bia_reduction); + wei_bia_reduction_bctx = scratchpad.template get( + key_conv_wei_bia_reduction_bctx); + + ithr_ic_b = ithr % self->nthr_ic_b_; + ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_; + ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_; + ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_; + + ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_ + + ithr_ic_b; + + ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_ + + ithr_oc_b; + + const auto &jcp = self->kernel_->jcp; + + /* reduction dimension */ + balance211(jcp.mb*jcp.od, self->nthr_mb_, ithr_mb, img_start, img_end); + img_work = img_end - img_start; + + /* independent dimensions */ + balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end); + g_work = g_end - g_start; + + balance211(jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start, + oc_b_end); + oc_b_work = oc_b_end - oc_b_start; + + balance211(jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start, + ic_b_end); + ic_b_work = ic_b_end - ic_b_start; + } +}; + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_weights(const thread_info_t *ti) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd; + + diff_weights_data_t *diff_wei = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_weights + : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; + diff_weights_data_t *diff_bia = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_bias + : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size + + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc; + + // TODO: use memory descriptor with the same fmt as src (or use a macro :)) + auto tr_src_off = [&](int ithr_mb, int ic, int ij) { + const size_t tr_row_size = jcp.tr_iw * jcp.ic_block; + const size_t tr_chn_size = tr_row_size * jcp.ih; + const size_t tr_img_size = tr_chn_size * jcp.nb_ic * jcp.ngroups; + + return ti->ithr_mb * tr_img_size + ic * tr_chn_size + ij * tr_row_size; + }; + + auto uker_trans = [&](int img) { + const int work_amount = ti->g_work * ti->ic_b_work * jcp.ih; + + int start{0}, end{0}; + balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end); + const int my_work = end - start; + + int g{0}, ic_b{0}, j{0}; + nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, j, jcp.ih); + g += ti->g_start; + ic_b += ti->ic_b_start; + + const int _ic = g * jcp.nb_ic + ic_b; + src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)]; + src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)]; + + assert(jcp.ic_block == 16); + const int src_stride = jcp.iw * jcp.ic_block; + const int tr_src_stride = jcp.tr_iw * jcp.ic_block; + + const int pf_depth = 2; + struct { src_data_t *src, *tr_src; } pf_circ_buf[pf_depth]; + + for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) { + pf_circ_buf[iwork % pf_depth] = {src1, tr_src1}; + + if (iwork >= pf_depth - 1) { + int old_idx = (iwork - pf_depth + 1) % pf_depth; + auto ctx = jit_trans_src_t::ctx_t(); + ctx.src = pf_circ_buf[old_idx].src; + ctx.tr_src = pf_circ_buf[old_idx].tr_src; + ctx.src_prf = src1; + ctx.tr_src_prf = tr_src1; + (*trans_kernel_)(&ctx); + } + src1 += src_stride; + tr_src1 += tr_src_stride; + } +#if 0 + // reference transposition + const int l_pad = jcp.l_pad; + const int iwlp = l_pad + jcp.iw; + const int tr_iw = jcp.tr_iw; + + for (size_t iwork = start; iwork < end; iwork++) { + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = 0; i < l_pad; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0; + + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = l_pad; i < iwlp; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] + = (src_data_t)src1[(i - l_pad) * 16 + j]; + + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = iwlp; i < tr_iw; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0; + + src1 += src_stride; + tr_src1 += tr_src_stride; + } +#endif + }; + + if (jcp.is_1stconv && jcp.ver == ver_4fma) { + /* prepare contexts */ + auto tr_ctx = jit_trans_src_t::ctx_t(); + tr_ctx.tr_src = ti->tr_src + + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld; + + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1)); + tr_ctx.nthr_oc_b = nthr_oc_b_; + int ih_start{0}, ih_end{0}; + balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end); + tr_ctx.tr_src_ih_start = ih_start; + tr_ctx.tr_src_ih_end = ih_end; + tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc; + + auto p = jit_conv_call_s(); + p.src = tr_ctx.tr_src; + + /* zero diff_bias if applicable */ + if (jcp.with_bias && ti->ithr_ic_b == 0) { + assert(jcp.oc_block == 16); + for (int oc_b = ti->ic_b_start; oc_b < ti->oc_b_end; ++oc_b) { + diff_weights_data_t *db = &diff_bia[oc_b * 16]; + for (int o = 0; o < 16; ++o) + db[o] = 0; + } + } + + for (int img = ti->img_start; img < ti->img_end; ++img) { + p.flags = (img == ti->img_start) * FLAG_MB_FIRST; + + for (int g = ti->g_start; g < ti->g_end; ++g) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _ic = g * jcp.nb_ic + ic_b; + tr_ctx.src = &ti->src[src_d.blk_off(img, _ic)]; + + (*trans_kernel_)(&tr_ctx); + + if (ic_b == 0) + p.flags |= FLAG_IC_FIRST; + else + p.flags &= ~FLAG_IC_FIRST; + + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + const int _oc = g * jcp.nb_oc + oc_b; + p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)]; + + const size_t off = + wht_blk_off(diff_weights_d, g, oc_b, ic_b); + p.filt = diff_wei + off; + p.bias = diff_bia + _oc * jcp.oc_block; + + kernel_->jit_ker(&p); + } + } + } + } + } else { + for (int img = ti->img_start; img < ti->img_end; ++img) { + auto p = jit_conv_call_s(); + + if (jcp.ver == ver_4fma) { + /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */ + using simple_barrier::barrier; + if (nthr_oc_b_ > 1) + barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); + uker_trans(img); + if (nthr_oc_b_ > 1) + barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); + } + + for (int g = ti->g_start; g < ti->g_end; ++g) { + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _oc = g * jcp.nb_oc + oc_b; + const int _ic = g * jcp.nb_ic + ic_b; + + jit_conv_ker_pipeline(kernel_->jit_ker, p, + jcp.ver == ver_4fma + ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)] + : &ti->src[src_d.blk_off(img, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img, _oc)], + diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), + 0, (img == ti->img_start), 0); + + } + } + } + + const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start; + const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start; + jit_conv_ker_pipeline(kernel_->jit_ker, p, + jcp.ver == ver_4fma + ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)] + : &ti->src[src_d.blk_off(img + 1, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)], + diff_wei + wht_blk_off( + diff_weights_d, ti->g_start, + ti->oc_b_start, ti->ic_b_start), + 0, 0, 0); + } + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_weights_3d(const thread_info_t *ti) const +{ + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size + = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd; + + diff_weights_data_t *diff_wei = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_weights + : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; + diff_weights_data_t *diff_bia = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_bias + : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size + + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc; + + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + const int input_step = jcp.ih * jcp.iw * inp_mult; + const int output_step = jcp.ow * jcp.oh * jcp.oc_block; + int img{0}, od_s{0}; + int img_start = ti->img_start, img_end = ti->img_end; + nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od); + const int img_first = img; + + while (img_start < img_end) { + auto p = jit_conv_call_s(); + + int work_rem = img_end - img_start; + const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem; + const int id_s = od_s * jcp.stride_d; + const int ik_overlap = nstl::max(0, id_s - jcp.f_pad); + const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s); + const int kd_back_pad + = nstl::max(0, id_s - jcp.f_pad - jcp.id + jcp.kd); + int kd_pad_off = nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw + * jcp.ic_block * jcp.oc_block * jcp.typesize_out; + + for (int g = ti->g_start; g < ti->g_end; ++g) { + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _oc = g * jcp.nb_oc + oc_b; + const int _ic = g * jcp.nb_ic + ic_b; + + auto src = &ti->src[src_d.blk_off(img, _ic) + + ik_overlap * input_step]; + auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc) + + od_s * output_step]; + + jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, src, dst, + diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), + diff_bia + _oc * 16, (img == img_first), od_s, od_e, + jcp.kd - kd_front_pad - kd_back_pad, kd_pad_off); + + if (ic_b == 0) p.flags = 0; + else p.flags = 1; + } + } + } + + const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start; + const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start; + jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, + &ti->src[src_d.blk_off(img + 1, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)], + diff_wei + wht_blk_off(diff_weights_d, ti->g_start, + ti->oc_b_start, ti->ic_b_start), + diff_bia, 0, 0, 0, 0, 0); + nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od); + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::reduce_diff_weights(const thread_info_t *ti) const { + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + const int bia_size = jcp.ngroups * jcp.oc; + const diff_weights_data_t *diff_bias_ws + = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size; + + /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */ + simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_); + + const int ic_b_kh_work = ti->ic_b_work * jcp.kh; + const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work; + + int start{0}, end{0}; + balance211(work, nthr_mb_, ti->ithr_mb, start, end); + if (start == end) return; + + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + int w = start; + int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0}; + nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + while (w < end) { + const int g = ti->g_start + sub_g_start; + const int oc_b = ti->oc_b_start + sub_oc_b_start; + const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh; + const int kh = sub_ic_b_kh_start % jcp.kh; + + const int acc_size + = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start) + * jcp.kw * jcp.ic_block * jcp.oc_block; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh); + + diff_weights_data_t *d + = (diff_weights_data_t *)ti->diff_weights + off; + diff_weights_data_t *s + = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off; + + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + } + + if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) { + if (ti->ithr == 0) + acc_ker_->accumulate((diff_weights_data_t *)ti->diff_bias, + diff_bias_ws, bia_size); + diff_bias_ws += bia_size; + } + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::reduce_diff_weights_3d(const thread_info_t *ti) const { + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw + * jcp.kd; + + /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */ + simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_); + + const int ic_b_kh_work = ti->ic_b_work * jcp.kd; + const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work; + + int start{0}, end{0}; + balance211(work, nthr_mb_, ti->ithr_mb, start, end); + if (start == end) return; + + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + int w = start; + int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0}; + nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + while (w < end) { + const int g = ti->g_start + sub_g_start; + const int oc_b = ti->oc_b_start + sub_oc_b_start; + const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd; + const int kd = sub_ic_b_kh_start % jcp.kd; + + const int acc_size + = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start) + * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd); + diff_weights_data_t *d + = (diff_weights_data_t *)ti->diff_weights + off; + diff_weights_data_t *s + = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off; + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + } + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_bias(const thread_info_t *ti) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + auto rb = this->reducer_bias_; + assert(nthr_ == rb->balancer().nthr_); + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t( + ti->scratchpad, prefix_reducer_bia); + + const auto &jcp = kernel_->jcp; + + if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return; + + const int b_job_start = rb->balancer().ithr_job_off(ti->ithr); + const int b_njobs = rb->balancer().ithr_njobs(ti->ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ti->ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc); + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + + const diff_dst_data_t *d_dst + = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)]; + diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr, + ti->diff_bias, reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 16; ++o) + d_bias[o] = 0; + for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 16; ++o) + d_bias[o] += d_dst[o]; + d_dst += 16; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc); + } + } + + rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad); +} + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_bias_3d(const thread_info_t *ti) const { + + const auto &jcp = kernel_->jcp; + + const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh + * jcp.kw * jcp.kd; + const int bia_size = jcp.ngroups * jcp.oc; + const diff_weights_data_t *diff_bias_ws + = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size; + + if (nthr_mb_ > 1) mkldnn_thr_barrier(); + + if (ti->ithr == 0) + { + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size); + diff_bias_ws += bia_size; + } + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::prepare_scratchpad_data(const exec_ctx_t &ctx) const +{ + const auto &j = pd()->jcp_; + auto scratchpad = this->scratchpad(ctx); + + if (j.ver == ver_4fma) { + if (!j.is_1stconv) { + // XXX: See the comment about tr_iw and guarding elements in + // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf() + const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic; + const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw; + + auto tr_src = scratchpad.template get(key_conv_tr_src); + /* to avoid NaNs in computations we zero tail num_guard_elems for + * each possible thread group */ + + for (int ithr = 1; ithr <= max_nthr; ++ithr) { + src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr]; + for (int i = 0; i < j.tr_src_num_guard_elems; ++i) + ts[i] = 0; + } + } + + if (j.nthr_oc_b > 1) { + const int tr_src_bctx_size = j.nthr / j.nthr_oc_b; + auto tr_src_bctx = scratchpad.template get( + key_conv_tr_src_bctx); + for (int i = 0; i < tr_src_bctx_size; ++i) + simple_barrier::ctx_init(&tr_src_bctx[i]); + } + } + + if (nthr_mb_ > 1) { + simple_barrier::ctx_init(scratchpad.template get( + key_conv_wei_bia_reduction_bctx)); + } + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); +} + +template +void jit_avx512_common_convolution_bwd_weights_t::execute_backward_weights(const exec_ctx_t &ctx) const { + prepare_scratchpad_data(ctx); + + parallel(nthr_, [&](const int ithr, const int nthr) { + assert(nthr_ == nthr); + + thread_info_t thread_info(this, ctx, ithr); + + if (utils::one_of(pd()->ndims(), 3, 4)) { + compute_diff_weights(&thread_info); + if (nthr_mb_ > 1) reduce_diff_weights(&thread_info); + if (pd()->with_bias()) compute_diff_bias(&thread_info); + } else if (pd()->ndims() == 5) { + compute_diff_weights_3d(&thread_info); + if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info); + if (pd()->with_bias()) compute_diff_bias_3d(&thread_info); + } else { + assert(false); + } + }); + + /* TODO: put that into compute_diff_bias() */ + if (pd()->wants_padded_bias()) { + auto diff_bias = scratchpad(ctx).template get( + key_conv_padded_bias); + auto diff_bias_in = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS); + for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +template struct jit_avx512_common_convolution_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp new file mode 100644 index 0000000000..3341c3ebe0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp @@ -0,0 +1,302 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP +#define CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_transpose_src_utils.hpp" +#include "jit_avx512_common_conv_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_common_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, dst_type, dst_type, + data_type::undef) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_fwd_kernel::init_conf( + jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, + *attr(), mkldnn_get_max_threads()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_fwd_kernel::init_scratchpad(scratchpad, + jcp_); + + return status; + } + + jit_conv_conf_t jcp_; + }; + + jit_avx512_common_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_common_conv_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + ~jit_avx512_common_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->ndims() == 3) + execute_forward_1d(ctx); + else if (pd()->ndims() == 4) + execute_forward_2d(ctx); + else if (pd()->ndims() == 5) + execute_forward_3d(ctx); + else + assert(false); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); + + return status::success; + } + +private: + void prepare_padded_bias(const dst_data_t *&bias, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + void execute_forward_3d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_conv_fwd_kernel *kernel_; +}; + +template +struct jit_avx512_common_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, + data_type::undef, diff_dst_type, data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(jcp_, + *desc(), *diff_src_md(), *weights_md(), *diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16o16i, gOIw16o16i, OIhw16o16i, gOIhw16o16i, + OIdhw16o16i, gOIdhw16o16i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_avx512_common_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + { kernel_ = new jit_avx512_common_conv_bwd_data_kernel_f32(pd()->jcp_); } + ~jit_avx512_common_convolution_bwd_data_t() { delete kernel_; }; + + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_src_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->ndims() == 3) + execute_backward_data_1d(ctx); + else if (pd()->ndims() == 4) + execute_backward_data_2d(ctx); + else if (pd()->ndims() == 5) + execute_backward_data_3d(ctx); + else + assert(false); + return status::success; + } + +private: + void execute_backward_data_1d(const exec_ctx_t &ctx) const; + void execute_backward_data_2d(const exec_ctx_t &ctx) const; + void execute_backward_data_3d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_conv_bwd_data_kernel_f32 *kernel_; +}; + +template +struct jit_avx512_common_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, diff_weights_type, + diff_weights_type, diff_dst_type, data_type::undef) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_bwd_weights_kernel_f32:: + init_conf(jcp_, *desc(), src_md_, diff_weights_md_, + diff_bias_md_, diff_dst_md_); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + return status; + } + + jit_conv_conf_t jcp_; + typename cpu_reducer_t::conf_t reducer_bia_conf_; + + private: + void init_balancers() { + const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb, + max_buffer_size)); + } + } + }; + + jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd); + ~jit_avx512_common_convolution_bwd_weights_t() { + delete kernel_; + if (trans_kernel_) + delete trans_kernel_; + if (acc_ker_) + delete acc_ker_; + delete reducer_bias_; + } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type diff_weights_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + void prepare_scratchpad_data(const exec_ctx_t &ctx) const; + struct thread_info_t; + void compute_diff_weights(const thread_info_t *) const; + void compute_diff_weights_3d(const thread_info_t *) const; + void reduce_diff_weights(const thread_info_t *) const; + void reduce_diff_weights_3d(const thread_info_t *) const; + void compute_diff_bias(const thread_info_t *) const; + void compute_diff_bias_3d(const thread_info_t *) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_; + + jit_avx512_common_conv_bwd_weights_kernel_f32 *kernel_; + jit_trans_src_t *trans_kernel_; + cpu_accumulator_1d_t *acc_ker_; + cpu_reducer_t *reducer_bias_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp new file mode 100644 index 0000000000..62247c0264 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp @@ -0,0 +1,1215 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifdef __INTEL_COMPILER +#include +#endif + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_convolution_winograd.hpp" + +#ifndef _MSC_VER +#define pragma_unroll _Pragma("unroll") +#else +#define pragma_unroll +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +namespace { + +unsigned int LLC_cache_size = get_cache_size(3, false); + +void inline load_ps(float *dest, const float *src_mem) { +#ifdef __INTEL_COMPILER + __m512 *Iv512 = (__m512 *)dest; + Iv512[0] = _mm512_load_ps(src_mem); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) dest[v] = src_mem[v]; +#endif +} + +void inline store_output(float *dest, const float *data, bool streamout) { +#ifdef __INTEL_COMPILER + if (streamout) + _mm512_stream_ps(dest, *((__m512 *)data)); + else + _mm512_store_ps(dest, *((__m512 *)data)); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + dest[v] = data[v]; +#endif +} + +void inline accum_output( + float *dest, float *data, bool streamout, bool with_relu_postsum) { +#ifdef __INTEL_COMPILER + __m512 _data = _mm512_loadu_ps(data); + __m512 _dest = _mm512_loadu_ps(dest); + _data = _mm512_add_ps(_data, _dest); + if (with_relu_postsum) + _data = _mm512_max_ps(_data, _mm512_setzero_ps()); + if (streamout) + _mm512_stream_ps(dest, _data); + else + _mm512_store_ps(dest, _data); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + data[v] += dest[v]; + + if (with_relu_postsum) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + if (data[v] < 0.f) + data[v] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + dest[v] = data[v]; +#endif +} +} + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]) { + float Fw[6][16]; + float T[6][3][16]; + float t0[16]; + float t1[16]; + float t2[16]; + + for (int j = 0; j < 16; j++) { +#pragma unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < 16; k++) { + t0[k] = 0.26890756302521f * F[2][i][j][k]; + t1[k] = -t0[k] - 0.688403361344538f * F[0][i][j][k]; + t2[k] = t0[k] + 0.119514472455649f * F[0][i][j][k]; + + T[0][i][k] = 1.13777777777778f * F[0][i][j][k]; + T[1][i][k] = t1[k] - 0.430252100840336f * F[1][i][j][k]; + T[2][i][k] = t1[k] + 0.430252100840336f * F[1][i][j][k]; + T[3][i][k] = t2[k] + 0.179271708683473f * F[1][i][j][k]; + T[4][i][k] = t2[k] - 0.179271708683473f * F[1][i][j][k]; + T[5][i][k] = F[2][i][j][k]; + } + } +#pragma unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < 16; k++) { + t0[k] = 0.26890756302521f * T[i][2][k]; + t1[k] = -t0[k] - 0.688403361344538f * T[i][0][k]; + t2[k] = t0[k] + 0.119514472455649f * T[i][0][k]; + + Fw[0][k] = 1.13777777777778f * T[i][0][k]; + Fw[1][k] = t1[k] - 0.430252100840336f * T[i][1][k]; + Fw[2][k] = t1[k] + 0.430252100840336f * T[i][1][k]; + Fw[3][k] = t2[k] + 0.179271708683473f * T[i][1][k]; + Fw[4][k] = t2[k] - 0.179271708683473f * T[i][1][k]; + Fw[5][k] = T[i][2][k]; +#pragma unroll + for (int l = 0; l < 6; l++) { + Fw_[i][l][j][k] = Fw[l][k]; + } + } + } + } +} + +void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]) { + float T[4][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + +#pragma unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = Mw[1][i][v] + Mw[2][i][v]; + t1[v] = Mw[3][i][v] + Mw[4][i][v]; + t2[v] = Mw[1][i][v] - Mw[2][i][v]; + t3[v] = Mw[3][i][v] - Mw[4][i][v]; + + T[0][i][v] = t0[v] + t1[v] + Mw[0][i][v]; + T[1][i][v] = t2[v] * 0.625f + t3[v] * 1.5f; + T[2][i][v] = t0[v] * 0.390625f + t1[v] * 2.25f; + T[3][i][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + Mw[5][i][v]; + } + } +#pragma unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][1][v] + T[i][2][v]; + t1[v] = T[i][3][v] + T[i][4][v]; + t2[v] = T[i][1][v] - T[i][2][v]; + t3[v] = T[i][3][v] - T[i][4][v]; + + O[i][0][v] = t0[v] + t1[v] + T[i][0][v]; + O[i][1][v] = t2[v] * 0.625f + t3[v] * 1.5f; + O[i][2][v] = t0[v] * 0.390625f + t1[v] * 2.25f; + O[i][3][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + T[i][5][v]; + } + } +} + + +void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]) +{ + const float rcp3 = 1.0f / 3.0f; + const float rcp4 = 1.0f / 4.0f; + const float rcp6 = 1.0f / 6.0f; + const float rcp12 = 1.0f / 12.0f; + const float rcp24 = 1.0f / 24.0f; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + float T[6][4][16]; + +pragma_unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < 16; j++) { + t0[j] = F[2][i][j] * rcp6; + t1[j] = F[0][i][j] * -rcp6 - t0[j]; + t2[j] = F[0][i][j] * rcp24 + t0[j]; + t3[j] = (F[1][i][j] + F[3][i][j]) * rcp6; + t4[j] = F[1][i][j] * rcp12 + F[3][i][j] * rcp3; + + T[0][i][j] = F[0][i][j] * rcp4; + T[1][i][j] = t1[j] - t3[j]; + T[2][i][j] = t1[j] + t3[j]; + T[3][i][j] = t2[j] + t4[j]; + T[4][i][j] = t2[j] - t4[j]; + T[5][i][j] = F[3][i][j]; + } + } +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < 16; j++) { + t0[j] = T[i][2][j] * rcp6; + t1[j] = T[i][0][j] * -rcp6 - t0[j]; + t2[j] = T[i][0][j] * rcp24 + t0[j]; + t3[j] = (T[i][1][j] + T[i][3][j]) * rcp6; + t4[j] = T[i][1][j] * rcp12 + T[i][3][j] * rcp3; + + Fw[i][0][j] = T[i][0][j] * rcp4; + Fw[i][1][j] = t1[j] - t3[j]; + Fw[i][2][j] = t1[j] + t3[j]; + Fw[i][3][j] = t2[j] + t4[j]; + Fw[i][4][j] = t2[j] - t4[j]; + Fw[i][5][j] = T[i][3][j]; + } + } +} + +void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]) +{ + float T[4][6][16]; + float M_[3][16]; + float t0[16]; + float t1[16]; + float t2[16]; + + for (int j = 0; j < 16; j++) { +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int l = 0; l < 16; l++) { + t0[l] = Mw[1][i][j][l] + Mw[2][i][j][l]; + t1[l] = Mw[3][i][j][l] + Mw[4][i][j][l]; + t2[l] = t1[l] * 4.0f + Mw[5][i][j][l]; + + T[0][i][l] = Mw[0][i][j][l] + t0[l] + t1[l]; + T[1][i][l] = (Mw[1][i][j][l] - Mw[2][i][j][l]) + + 2.0f * (Mw[3][i][j][l] - Mw[4][i][j][l]); + T[2][i][l] = t0[l] + t2[l]; + } + } +pragma_unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int l = 0; l < 16; l++) { + t0[l] = T[i][1][l] + T[i][2][l]; + t1[l] = T[i][3][l] + T[i][4][l]; + t2[l] = t1[l] * 4.0f + T[i][5][l]; + + M_[0][l] = T[i][0][l] + t0[l] + t1[l]; + M_[1][l] = (T[i][1][l] - T[i][2][l]) + + 2.0f * (T[i][3][l] - T[i][4][l]); + M_[2][l] = t0[l] + t2[l]; + + for (int k = 0; k < 3; k++) { + M[i][k][j][l] = M_[k][l]; + } + } + } + } +} + +void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]) +{ + float T[6][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + float t5[16]; + +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = I[2][i][v] * -2.25f + I[4][i][v]; + t1[v] = I[1][i][v] * -2.25f + I[3][i][v]; + t2[v] = I[2][i][v] * -0.390625f + I[4][i][v]; + t3[v] = I[1][i][v] * -0.390625f + I[3][i][v]; + t4[v] = I[0][i][v] * 0.87890625f + I[4][i][v]; + t5[v] = I[1][i][v] * 0.87890625f + I[5][i][v]; + + T[0][i][v] = I[2][i][v] * -2.640625f + t4[v]; + T[1][i][v] = t1[v] * 0.625f + t0[v]; + T[2][i][v] = t1[v] * -0.625f + t0[v]; + T[3][i][v] = t3[v] * 1.5f + t2[v]; + T[4][i][v] = t3[v] * -1.5f + t2[v]; + T[5][i][v] = I[3][i][v] * -2.640625f + t5[v]; + } + } + +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][2][v] * -2.25f + T[i][4][v]; + t1[v] = T[i][1][v] * -2.25f + T[i][3][v]; + t2[v] = T[i][2][v] * -0.390625f + T[i][4][v]; + t3[v] = T[i][1][v] * -0.390625f + T[i][3][v]; + t4[v] = T[i][0][v] * 0.87890625f + T[i][4][v]; + t5[v] = T[i][1][v] * 0.87890625f + T[i][5][v]; + + Iw[i][0][v] = T[i][2][v] * -2.640625f + t4[v]; + Iw[i][1][v] = t1[v] * 0.625f + t0[v]; + Iw[i][2][v] = t1[v] * -0.625f + t0[v]; + Iw[i][3][v] = t3[v] * 1.5f + t2[v]; + Iw[i][4][v] = t3[v] * -1.5f + t2[v]; + Iw[i][5][v] = T[i][3][v] * -2.640625f + t5[v]; + } + } +} + +void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]) +{ + float T[6][4][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + +pragma_unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = F[2][i][v] * 0.26890756302521f; + t1[v] = F[0][i][v] * -0.688403361344538f - t0[v]; + t2[v] = F[0][i][v] * 0.119514472455649f + t0[v]; + t3[v] = F[1][i][v] * 0.430252100840336f + + F[3][i][v] * 0.168067226890756f; + t4[v] = F[1][i][v] * 0.179271708683473f + + F[3][i][v] * 0.403361344537815f; + + T[0][i][v] = F[0][i][v] * 1.13777777777778f; + T[1][i][v] = t1[v] - t3[v]; + T[2][i][v] = t1[v] + t3[v]; + T[3][i][v] = t2[v] + t4[v]; + T[4][i][v] = t2[v] - t4[v]; + T[5][i][v] = F[3][i][v]; + } + } +pragma_unroll + for (int i = 0; i < 6; i++) { + for (int v = 0; v < 16; v++) { + t0[v] = T[i][2][v] * 0.26890756302521f; + t1[v] = T[i][0][v] * -0.688403361344538f - t0[v]; + t2[v] = T[i][0][v] * 0.119514472455649f + t0[v]; + t3[v] = T[i][1][v] * 0.430252100840336f + + T[i][3][v] * 0.168067226890756f; + t4[v] = T[i][1][v] * 0.179271708683473f + + T[i][3][v] * 0.403361344537815f; + + Fw[i][0][v] = T[i][0][v] * 1.13777777777778f; + Fw[i][1][v] = t1[v] - t3[v]; + Fw[i][2][v] = t1[v] + t3[v]; + Fw[i][3][v] = t2[v] + t4[v]; + Fw[i][4][v] = t2[v] - t4[v]; + Fw[i][5][v] = T[i][3][v]; + } + } +} + +void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]) +{ + float T[3][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float M_[3][16]; + + for (int j = 0; j < 16; j++) { +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = Mw[1][i][j][v] + Mw[2][i][j][v]; + t1[v] = Mw[3][i][j][v] + Mw[4][i][j][v]; + t2[v] = t1[v] * 2.25f + Mw[5][i][j][v]; + + T[0][i][v] = Mw[0][i][j][v] + t0[v] + t1[v]; + T[1][i][v] = 0.625f * (Mw[1][i][j][v] - Mw[2][i][j][v]) + + 1.5f * (Mw[3][i][j][v] - Mw[4][i][j][v]); + T[2][i][v] = t0[v] * 0.390625f + t2[v]; + } + } +pragma_unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][1][v] + T[i][2][v]; + t1[v] = T[i][3][v] + T[i][4][v]; + t2[v] = t1[v] * 2.25f + T[i][5][v]; + + M_[0][v] = T[i][0][v] + t0[v] + t1[v]; + M_[1][v] = 0.625f * (T[i][1][v] - T[i][2][v]) + + 1.5f * (T[i][3][v] - T[i][4][v]); + M_[2][v] = t0[v] * 0.390625f + t2[v]; + } + +pragma_unroll + for (int k = 0; k < 3; k++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + M[i][k][j][v] = M_[k][v]; + } + } + } + } +} + +template +void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp, bool streamout = true) +{ + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; + const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh; + const int wp_max = inpw + l_pad; + const int hp_max = inph + t_pad; + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + + array_offset_calculator input(inp, + jcp.mb, jcp.dimK/simd_w, inph, inpw, + simd_w); + array_offset_calculator output(tinp, + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if ((t_pad <= ydim) && (ydim < hp_max)) { + float *pinp_j = inp + (ydim - t_pad) * inpw * 16 ; + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if ((l_pad <= xdim) && (xdim < wp_max)) { + float *pinp_i = pinp_j + (xdim - l_pad) * 16; + load_ps(I[j][i], pinp_i); + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + + trans_I_4x4_3x3(Iw, I); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(tile_block, j, i, + nb_tile_block_ur, 0, 0, + tile_block_ur, 0)), + Iw[j][i], streamout); + } + } + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) +{ + const int kh = 3; + const int kw = 3; + array_offset_calculator input(wp, + jcp.oc/jcp.oc_simd_block, + jcp.ic/jcp.ic_simd_block, + jcp.kh, jcp.kw, + simd_w, simd_w); + array_offset_calculator output(twp, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block, jcp.dimK_block, + simd_w, simd_w); + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + for (int v1 = 0; v1 < simd_w; v1++) { + float *base_inp = is_fwd + ? &(input(0, 0, j, i, v1, 0)) + : &(input(0, 0, 2 - j, 2 - i, v1, 0)); + PRAGMA_OMP_SIMD() + for (int v2 = 0; v2 < simd_w; v2++) { + if (is_fwd) + F[j][i][v1][v2] = *(base_inp + v2); + else + F[j][i][v2][v1] = *(base_inp + v2); + } + } + } + } + + trans_W_4x4_3x3(Fw, F); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int v1 = 0; v1 < simd_w; v1++) { + PRAGMA_OMP_SIMD() + for (int v2 = 0; v2 < simd_w; v2++) { + output(0, j, i, 0, 0, 0, v1, v2) = Fw[j][i][v1][v2]; + } + } + } + } +} + +template +void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias, + bool streamout = true) { + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? jcp.oh : jcp.ih; + + /* Prepare for PostOps */ + bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1; + + array_offset_calculator input(toutp, + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Ow[j][i][v] = input(tile_block, 0, + j, i, + nb_tile_block_ur, 0, + tile_block_ur, v); + } + } + } + + trans_O_4x4_3x3(Ow, O); + + for (int j = 0; j < tile_size; j++) { + int ydim = tj * tile_size + j; + if (ydim < outh) { + float *pout_j = pout_b + ydim * outw * simd_w; + for (int i = 0; i < tile_size; i++) { + int xdim = ti * tile_size + i; + if (xdim < outw) { + float *pout_i = pout_j + xdim * simd_w; + if (is_fwd) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + O[j][i][v] += with_bias ? bias[v] : 0.f; + O[j][i][v] = true + && with_relu_presum && O[j][i][v] < 0.f + ? O[j][i][v] + * jcp.eltwise.alpha + : O[j][i][v]; + } + } + if (with_sum) + accum_output(pout_i, O[j][i], streamout, + with_relu_postsum); + else + store_output(pout_i, O[j][i], streamout); + } + } + } + } + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv, + float *inp, float *tinp, float *Iw_temp, + void (*transpose_4fma_ker)(float *, float *)) +{ + + const int ifwp = conv.iw + conv.l_pad; + const int ifhp = conv.ih + conv.t_pad; + float I[alpha][alpha][simd_w]; + float Iw[alpha][alpha][simd_w]; + + array_offset_calculator Iw_trans_temp(Iw_temp, + alpha, alpha, conv.tile_4fma, simd_w); + array_offset_calculator input(inp, + conv.mb, conv.ic/simd_w, conv.ih, conv.iw, simd_w); + array_offset_calculator output(tinp, + conv.nb_ic, alpha, alpha, + conv.tile_block, conv.ic_block, + conv.nb_tile_block_ur, conv.tile_block_ur, + conv.ic_simd_block * conv.tile_4fma); + + int tile_base_index = + image * (conv.itiles * conv.jtiles + conv.tile_4fma_padding); + int tile_4fma = 0; + int tile_block_ur = (tile_base_index / conv.tile_4fma) % conv.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / conv.tile_4fma / conv.tile_block_ur) + % conv.nb_tile_block_ur; + int tile_block = (tile_base_index / conv.tile_4fma / conv.tile_block_ur) + / conv.nb_tile_block_ur; + + for (int tj = 0; tj < conv.jtiles; tj++) { + for (int ti = 0; ti < conv.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if ((conv.t_pad <= ydim) && ydim < ifhp) { + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if ((conv.l_pad <= xdim) && xdim < ifwp) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = input(0, 0, + ydim - conv.t_pad, + xdim - conv.l_pad, v); + } + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + trans_I_4x4_3x3(Iw, I); + + if (ver_4fma) { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + float *Iw_temp_base = &(Iw_trans_temp(j, i, + tile_4fma, 0)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Iw_temp_base[v] = Iw[j][i][v]; + } + } + } + tile_4fma++; + if (tile_4fma == conv.tile_4fma) { + float *outp = &(output(0, 0, 0, + tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)); + transpose_4fma_ker(outp, (float *)Iw_temp); + tile_4fma = 0; + tile_block_ur++; + } + } else { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(0, j, i, + tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)), + Iw[j][i], true); + } + } + tile_block_ur++; + } + + if (tile_block_ur == conv.tile_block_ur) { + tile_block_ur = 0; + ++nb_tile_block_ur; + } + if (nb_tile_block_ur == conv.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } + + if (ver_4fma && tile_4fma < conv.tile_4fma && conv.tile_4fma_padding != 0) { + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int tb = tile_4fma; tb < conv.tile_4fma; tb++) { + float *Iw_temp_base = &(Iw_trans_temp(j, i, tb, 0)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Iw_temp_base[v] = 0; + } + } + } + } + float *outp = &(output(0, 0, 0, + tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)); + transpose_4fma_ker(outp, (float *)Iw_temp); + } +} + +template +void diff_dst_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv, + float *inp, float *tinp, float *dbias) +{ + + const int total_tiles = conv.itiles * conv.jtiles + conv.tile_4fma_padding; + float I[alpha][alpha][simd_w]; + float Iw[alpha][alpha][simd_w]; + + array_offset_calculator input(inp, + conv.mb, conv.oc/simd_w, conv.oh, conv.ow, conv.oc_simd_block); + array_offset_calculator output(tinp, + conv.nb_oc, alpha, alpha, + conv.tile_block, conv.oc_block, + conv.nb_tile_block_ur, + conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block); + + int tile_base_index = image * total_tiles; + int tile_block_ur = tile_base_index % (conv.tile_block_ur * conv.tile_4fma); + int nb_tile_block_ur = + (tile_base_index / conv.tile_block_ur / conv.tile_4fma) + % conv.nb_tile_block_ur; + int tile_block = (tile_base_index / conv.tile_block_ur / conv.tile_4fma) + / conv.nb_tile_block_ur; + + for (int tj = 0; tj < conv.jtiles; tj++) { + for (int ti = 0; ti < conv.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if (ydim < conv.oh) { + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if (xdim < conv.ow) { + float *input_base = &(input(0, 0, ydim, xdim, 0)); + + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = input_base[v]; + } + if (with_bias && j < tile_size && i < tile_size) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + dbias[v] += input_base[v]; + } + } + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + + trans_W_3x3_4x4_wu(Iw, I); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(0, j, i, + tile_block, 0, + nb_tile_block_ur, + tile_block_ur, 0)), + Iw[j][i], true); + } + } + tile_block_ur++; + if (tile_block_ur >= conv.tile_block_ur * conv.tile_4fma) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= conv.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +void diff_weights_transform_bwd_weights(jit_conv_winograd_conf_t conv, + float *wp, float *twp) +{ + const int kh = 3; + const int kw = 3; + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + + array_offset_calculator input(twp, + conv.nb_ic, conv.nb_oc, + alpha, alpha, + conv.oc_block, conv.ic_block, + conv.ic_simd_block, conv.oc_simd_block); + array_offset_calculator output(wp, + conv.oc/simd_w, conv.ic/simd_w, + conv.kh, conv.kw, + conv.ic_simd_block, conv.oc_simd_block); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int v = 0; v < conv.ic_simd_block; v++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < conv.oc_simd_block; k++) { + Fw[j][i][v][k] = input(0, 0, j, i, 0, 0, v, k); + } + } + } + } + + trans_O_3x3_4x4_wu(Fw, F); + + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + for (int v = 0; v < conv.ic_simd_block; v++) { + store_output(&(output(0, 0, j, i, v, 0)), + F[j][i][v], true); + } + } + } +} + +template +void _jit_avx512_common_convolution_winograd_t::_execute_data_W_S_G_D( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + /* Note that jcp.with_eltwise is true for both fused conv+relu primitive + * and conv primitive with PostOps with relu before sum + * (PostOps relu after sum is handled later) */ + auto output_transform = jcp.with_bias + ? (jcp.with_eltwise + ? (jcp.with_sum + ? output_transform_data + : output_transform_data) + : (jcp.with_sum + ? output_transform_data + : output_transform_data)) + : (jcp.with_eltwise + ? (jcp.with_sum + ? output_transform_data + : output_transform_data) + : (jcp.with_sum + ? output_transform_data + : output_transform_data)); + + /* Notation: + FWD: dimM:oc, dimN:ntiles, dimK:ic, + BWD: dimM:ic, dimN:ntiles, dimK:oc, + FWD/BWD: V: src/diff_dst transform, U:weight transform, + M:dst/diff_src transform */ + array_offset_calculator input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, + jcp.dimK_reg_block); + array_offset_calculator output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, + jcp.dimM_simd_block); + array_offset_calculator weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator bias(bias_ptr, + jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block); + + array_offset_calculator M(is_fwd + ? scratchpad.template get(key_wino_M) + : scratchpad.template get(key_wino_V), + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + array_offset_calculator U( + scratchpad.template get(key_wino_U), + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + array_offset_calculator V(is_fwd + ? scratchpad.template get(key_wino_V) + : scratchpad.template get(key_wino_M), + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); + + bool V_streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float) + > 2 * LLC_cache_size ? true : false; + + const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0; + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + { + parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, + [&](int img, int K_blk1, int K_blk2) { + input_transform_data(img, jcp, + &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), + &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), V_streamout); + }); + + parallel_nd(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block, + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data(jcp, + &(weights(ofm1 * jcp.oc_block + ofm2, + ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr); + }); + + parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block, + [&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) { + + kernel_->gemm_loop_ker_first_iter( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + 0, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, 0, 0, 0, 0))); + for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) { + kernel_->gemm_loop_ker( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + K_blk1, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, K_blk1, + 0, 0, 0))); + } + + }); + + parallel_nd(jcp.mb, jcp.dimM_nb_block, jcp.dimM_block, + [&](int img, int M_blk1, int M_blk2) { + + const int M_blk = M_blk1 * jcp.dimM_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? last_slice_bias : &bias(M_blk, 0); + + output_transform(img, jcp, p_ops, + &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(img, M_blk, 0, 0, 0)), + bias_ptr, output_is_aligned); + + }); + + } +} + +template struct _jit_avx512_common_convolution_winograd_t; +template struct _jit_avx512_common_convolution_winograd_t; + +void jit_avx512_common_convolution_winograd_bwd_weights_t:: +_maybe_execute_diff_bias_copy(float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc) + diff_bias[oc] = padded_bias[oc]; + } +} + +void jit_avx512_common_convolution_winograd_bwd_weights_t:: +_execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, + const memory_tracking::grantor_t &scratchpad) const { + auto ptr_diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto ptr_src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto ptr_diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS); + auto ptr_diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ? + diff_src_transform_bwd_weights : + diff_src_transform_bwd_weights; + auto diff_dst_transform_bwd_weights_ver = jcp.with_bias + ? diff_dst_transform_bwd_weights + : diff_dst_transform_bwd_weights; + + array_offset_calculator src((float *)ptr_src, + jcp.mb, jcp.ic/simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator diff_weights(ptr_diff_weights, + jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + array_offset_calculator diff_bias(pd()->wants_padded_bias() + ? scratchpad.get(key_conv_padded_bias) : ptr_diff_bias, + jcp.oc/simd_w, simd_w); + + array_offset_calculator U( + scratchpad.get(key_wino_U), + jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, jcp.oc_simd_block); + + array_offset_calculator M( + scratchpad.get(key_wino_M), + jcp.nb_oc, alpha, alpha, + jcp.tile_block, jcp.oc_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma, + jcp.oc_simd_block); + array_offset_calculator V( + scratchpad.get(key_wino_V), + jcp.nb_ic, alpha, alpha, + jcp.tile_block, jcp.ic_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur, + jcp.ic_simd_block * jcp.tile_4fma); + + const int trans_buffer_size = alpha * alpha * jcp.tile_4fma + * jcp.ic_simd_block; + array_offset_calculator trans_buffer( + scratchpad.get(key_conv_tr_src), + nthreads, + trans_buffer_size); + + array_offset_calculator diff_bias_prv( + scratchpad.get(key_conv_bia_reduction), + nthreads, + jcp.oc); + +PRAGMA_OMP(parallel num_threads(nthreads)) + { + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) { + diff_bias_prv(ithr, ofm) = 0.0f; + }); + +PRAGMA_OMP(for nowait) + for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + diff_bias(bofm, v) = 0.0f; + } + } + + const int ithread = mkldnn_get_thread_num(); + + parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block, + [&](int img, int ifm1, int ifm2) { + float *transb = jcp.ver == ver_4fma + ? &(trans_buffer(ithread, 0)) + : NULL; + diff_src_transform_bwd_weights_ver(img, jcp, + &(src(img, ifm1 * jcp.ic_block + ifm2, + 0, 0, 0)), + &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)), + transb, + kernel_->transpose_4fma_ker); + }); + + parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block, + [&](int img, int ofm1, int ofm2) { + float *dbias = jcp.with_bias + ? &(diff_bias_prv(ithread, + simd_w * (ofm1 * jcp.oc_block + ofm2))) + : NULL; + diff_dst_transform_bwd_weights_ver(img, jcp, + &(diff_dst(img, ofm1 * jcp.oc_block + ofm2, + 0, 0, 0)), + &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)), + dbias); + }); + +PRAGMA_OMP(barrier) + + for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) { + parallel_nd_in_omp(alpha, alpha, jcp.nb_oc, + [&](int oj, int oi, int ofm1) { + kernel_->gemm_loop_ker_first_iter( + (float *)&(U(ifm1, ofm1, oj, oi, + 0, 0, 0, 0)), + (const float *)&(M(ofm1, oj, oi, + 0, 0, 0, 0, 0)), + (const float *)&(V(ifm1, oj, oi, + 0, 0, 0, 0, 0))); + for (int tile_block = 1; tile_block < jcp.tile_block; + tile_block++) { + kernel_->gemm_loop_ker((float *)&(U(ifm1, ofm1, + oj, oi, + 0, 0, 0, 0)), + (const float *)&(M(ofm1, oj, oi, tile_block, + 0, 0, 0, 0)), + (const float *)&(V(ifm1, oj, oi, tile_block, + 0, 0, 0, 0))); + } + }); + } + +PRAGMA_OMP(barrier) + + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, + [&](int ifm1, int ofm1, int ofm2, int ifm2) { + diff_weights_transform_bwd_weights(jcp, + &(diff_weights(ofm1 * jcp.oc_block + ofm2, + ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), + &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0))); + }); + + if (jcp.with_bias) { +PRAGMA_OMP(for) + for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) { + for (int ithr = 0; ithr < nthreads; ithr++) { + float* base_bias_ptr = &(diff_bias(ofm1, 0)); + float* base_bias_prv_ptr = &(diff_bias_prv( + ithr * jcp.oc + ofm1 * simd_w)); + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < simd_w; ofm2++) { + base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2]; + } + } + } + } + } + + _maybe_execute_diff_bias_copy(ptr_diff_bias, scratchpad); +} + +} +} +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp new file mode 100644 index 0000000000..6c76f37c72 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp @@ -0,0 +1,318 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP +#define CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace winograd_avx512_common { +inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_winograd_conf_t &jcp) { + using namespace memory_tracking::names; + + size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; + size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic + * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc + * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + + scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); + + if (jcp.sched_policy == WSCHED_WEI_S_D_G_W) { + const int nthr = mkldnn_get_max_threads(); + + size_t tr_src_sz = jcp.ver != ver_4fma ? 0 : (size_t)nthr + * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block; + scratchpad.book(key_conv_tr_src, sizeof(float) * tr_src_sz, PAGE_2M); + + size_t br_sz = jcp.with_bias ? nthr * jcp.oc : 0; + scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); + + size_t padded_bias_sz = + jcp.with_bias && jcp.oc_without_padding != jcp.oc ? jcp.oc : 0; + scratchpad.book(key_conv_padded_bias, sizeof(float) * padded_bias_sz); + } +} +} + +template +struct _jit_avx512_common_convolution_winograd_t { + _jit_avx512_common_convolution_winograd_t( + const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) + : kernel_(nullptr), attr_(attr) { + kernel_ = new _jit_avx512_common_conv_winograd_data_kernel_f32(jcp); + } + + ~_jit_avx512_common_convolution_winograd_t() { delete kernel_; } + + protected: + void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + _jit_avx512_common_conv_winograd_data_kernel_f32 *kernel_; + const primitive_attr_t *attr_; +}; + +struct jit_avx512_common_convolution_winograd_fwd_t + : _jit_avx512_common_convolution_winograd_t + , public cpu_primitive_t + { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_winograd_fwd_kernel_f32:: + init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(), + *attr()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_fwd_t(const pd_t *apd) + : _jit_avx512_common_convolution_winograd_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) {} + + ~jit_avx512_common_convolution_winograd_fwd_t(){}; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, + (float *)bias, this->scratchpad(ctx)); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_common_convolution_winograd_bwd_data_t + : _jit_avx512_common_convolution_winograd_t, + public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && !has_zero_dim_memory() + && set_default_formats() + && mkldnn_thr_syncable(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( + jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *apd) + : _jit_avx512_common_convolution_winograd_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) {} + + ~jit_avx512_common_convolution_winograd_bwd_data_t(){}; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); + this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, + (float *)weights, nullptr, this->scratchpad(ctx)); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_common_convolution_winograd_bwd_weights_t + : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, + hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats() + && mkldnn_thr_syncable(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32:: + init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), + *diff_weights_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true), kernel_(nullptr) + { + kernel_ = new jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( + pd()->jcp_); + } + + ~jit_avx512_common_convolution_winograd_bwd_weights_t() + { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + _execute_backward_weights_S_D_G_W(ctx, scratchpad(ctx)); + return status::success; + } + +private: + void _execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, + const memory_tracking::grantor_t &scratchpad) const; + void _maybe_execute_diff_bias_copy(float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 *kernel_; +}; + +void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]); +void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]); +void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]); +void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]); +void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]); +void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]); +void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]); + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp new file mode 100644 index 0000000000..d4a451c021 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp @@ -0,0 +1,853 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_lrn.hpp" + +#include "jit_generator.hpp" + +#define FWD_RBC 4 +#define BWD_RBC 3 + +#define XMM_SIZE (4*sizeof(float)) +#define ZMM_SIZE (vlen) +#define BUFFER_BLOCK (XMM_SIZE + ZMM_SIZE + XMM_SIZE) +#define BUFFER_NEXT_OFFSET (XMM_SIZE + ZMM_SIZE) +#define SRC_PREV_OFFSET (vlen - XMM_SIZE) + +#define IRB_LOOP(statement) for(int irb = 0; irb < loop_size; irb++) { \ + statement;\ +} + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +enum params { vsize = 16, vlen = 64}; + +typedef struct { + const float *src; + float *dst, *ws0, *ws1; +} jit_args_fwd_t; + +typedef struct { + const float *src, *diff_dst, *ws0, *ws1; + float *diff_src; +} jit_args_bwd_t; + +struct nChw16c_across { +/* version: + * -1: channels 0..15, + * 1: channels C-16 .. C-1, + * 0: other channels + * 3: channels only for this kernel(without prev and next) + */ + int H, W, version; + nChw16c_across(int h, int w, int v) : H(h), W(w), version(v) {} +}; + +struct jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_kernel_f32: + public jit_generator { + int HW, W; + bool is_first; + bool is_last; + bool is_single; + + Reg64 src = rax; + Reg64 dst = r8; + Reg64 scratch0 = rdx; + Reg64 scratch1 = rsi; + Reg64 imm_addr64 = rbx; + + Zmm zalpha = zmm0; + Xmm xalpha = xmm0; + Zmm zk = zmm1; + Xmm xk = xmm1; + + Reg64 param = abi_param1; + Reg64 t = rsp; + Reg64 hw = r9; + + int xsrc_prev = 2; + int zsrc = 7; + int xsrc_next = 3; + int zc = 7; + + int za = 2; + int zb = 3; + int zd = 5; + int ze = 6; + int zsum = 4; + int zdst = 2; + int zbase = 3; + int zsum2 = 5; + + prop_kind_t pk; + int use_h_parallelism; + + float alpha, k; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32) + + void (*ker)(jit_args_fwd_t *); + void operator()(jit_args_fwd_t *arg) { ker(arg); } + + enum { + prf0_offt = 1*FWD_RBC, + prf2_offt = 8*FWD_RBC + }; + + inline void compute_loop(int loop_size_param) + { + // loop_size - param for IRB_LOOP macro + int loop_size = FWD_RBC; + + auto xreg = [=](int irb, int i) { + return Xmm(irb*3 + i); + }; + + auto zreg = [=](int irb, int i) { + return Zmm(irb*7 + i); + }; + + if (!is_first && !is_single) { + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt - HW)*vlen])); + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt - HW)*vlen])); + } + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(src, (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(src, (irb + prf2_offt)*vlen))); + if (!is_last && !is_single) { + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt + HW)*vlen])); + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt + HW)*vlen])); + } + if (pk != prop_kind::forward_inference) { + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch0, + (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch0, + (irb + prf2_offt)*vlen))); + } + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(dst, (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(dst, (irb + prf2_offt)*vlen))); + if (pk != prop_kind::forward_inference) { + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch1, + (irb + prf0_offt) * vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch1, + (irb + prf2_offt) * vlen))); + } + + loop_size = loop_size_param; + if (loop_size == 0) + return; + if (!is_first && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xsrc_prev), + ptr[src + (irb - HW) * vlen + SRC_PREV_OFFSET])); + } + IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src,irb*vlen))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xsrc_next), + ptr[src + (irb + HW) * vlen])); + } + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK], + xreg(irb, xsrc_prev))); + } + IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE), + zreg(irb, zsrc))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xreg(irb, xsrc_next))); + } + + IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 2*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 2*sizeof(float)))); + + assert(zc == zsrc); + IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zc), zreg(irb, zc))); + + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, za), zreg(irb, za))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zb), zreg(irb, zb))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zd), zreg(irb, zd))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, ze), zreg(irb, ze))); + + IRB_LOOP(vfmadd132ps(zreg(irb, zsum), zk, zalpha)); + + IRB_LOOP(vmovaps(zreg(irb, zbase), zreg(irb, zsum))); + + IRB_LOOP(vmulps(zreg(irb, zsum2), zreg(irb, zsum), zreg(irb, zsum))); + IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zsum), zreg(irb, zsum2))); + + IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum))); + IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum))); + + if (pk != prop_kind::forward_inference) { + IRB_LOOP(vmovups(EVEX_compress_addr(scratch0, irb*vlen), + zreg(irb, zsum))); + } + IRB_LOOP(vdivps(zreg(irb, zdst), zreg(irb, zsrc), zreg(irb, zsum))); + IRB_LOOP(vmovups(EVEX_compress_addr(dst, irb*vlen), zreg(irb, zdst))); + if (pk != prop_kind::forward_inference) { + /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */ + IRB_LOOP(vdivps(zreg(irb, zsum), zreg(irb, zdst), zreg(irb, zbase))); + IRB_LOOP(vmovups(EVEX_compress_addr(scratch1, irb*vlen), + zreg(irb, zsum))); + } + } + + jit_avx512_common_lrn_kernel_f32( + const struct nChw16c_across &J, + prop_kind_t prop_kind, + int use_h_parallel, + float A, + float K, + void *code_ptr = nullptr, + size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + , pk(prop_kind) + , use_h_parallelism(use_h_parallel) + , alpha(A) + , k(K) + { + this->preamble(); + + mov(src, ptr[param + 0]); + mov(dst, ptr[param + 8]); + if (pk != prop_kind::forward_inference) + { + mov(scratch0, ptr[param + 16]); + mov(scratch1, ptr[param + 24]); + } + is_first = J.version == -1 || J.version == -2; + is_last = J.version == +1 || J.version == -2; + is_single = J.version == 3; + + W = J.W; + HW = J.W*J.H; + int LSB = use_h_parallelism ? W : HW; + + sub(t, FWD_RBC*BUFFER_BLOCK); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(zalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(zk, xk); + + if (is_first || is_single) { + vxorps(xmm2, xmm2, xmm2); + for(int irb = 0; irb < FWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK], xmm2); + } + } + if (is_last || is_single) { + vxorps(xmm2, xmm2, xmm2); + for(int irb = 0; irb < FWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xmm2); + } + } + + int LSREST = LSB % FWD_RBC; + int LS = LSB - LSREST; + + Label lrn_loop; + + if (LS > 0) { + mov(hw, LS); + + L(lrn_loop); + { + compute_loop(FWD_RBC); + + add(src, FWD_RBC*vlen); + add(dst, FWD_RBC*vlen); + if (pk != prop_kind::forward_inference) + { + add(scratch0, FWD_RBC*vlen); + add(scratch1, FWD_RBC*vlen); + } + + for(int irb = 0; irb < FWD_RBC; irb++) + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + } + + compute_loop(LSREST); + + add(t, FWD_RBC*BUFFER_BLOCK); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); + } +}; + +status_t jit_avx512_common_lrn_fwd_t::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(avx512_common) + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, data_d.data_type()) + && data_d.ndims() == 4 + && data_d.dims()[1] % vsize == 0 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + if (desc()->prop_kind == forward_training) { + dims_t ws_dims = { MB(), C(), H(), 2*W() }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32, + format_tag::nChw16c); + } + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && desc()->lrn_beta == 0.75 + && data_d.matches_tag(format_tag::nChw16c); + + return args_ok_across ? success : unimplemented; +} + +jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr) + , ker_last_(nullptr) { + using namespace alg_kind; + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + const float alpha = pd()->desc()->lrn_alpha / ls; + const float k = pd()->desc()->lrn_k; + + auto pk = pd()->desc()->prop_kind; + + use_h_parallelism = H > 28 ? 1 : 0; + + if (C / vsize == 1) { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), pk, + use_h_parallelism, alpha, k); + } else { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), pk, + use_h_parallelism, alpha, k); + ker_first_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, -1), pk, use_h_parallelism, alpha, k); + ker_last_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, +1), pk, use_h_parallelism, alpha, k); + } +} + +jit_avx512_common_lrn_fwd_t::~jit_avx512_common_lrn_fwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +void jit_avx512_common_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + const int C16 = C / vsize; + const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16; + + balance211(work_amount, nthr, ithr, start, end); + if (use_h_parallelism) { + int n{0}, c16{0}, h{0}; + nd_iterator_init(start, n, N, c16, C16, h, H); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize + + h*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize + + h*2*W*vsize; + auto ws_offset1 = ws_offset0 + W*vsize; + + jit_args_fwd_t args; + args.src = &src[offset]; + args.dst = &dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + nd_iterator_step(n, N, c16, C16, h, H); + } + } else { + int n{0}, c16{0}; + nd_iterator_init(start, n, N, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize; + auto ws_offset1 = ws_offset0 + H*W*vsize; + + jit_args_fwd_t args; + args.src = &src[offset]; + args.dst = &dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + + nd_iterator_step(n, N, c16, C16); + } + } + }); +} + +struct jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_kernel_f32: + public jit_generator { + int HW, W; + bool is_first; + bool is_last; + bool is_single; + + Reg64 src = rax; + Reg64 diffsrc = r8; + Reg64 diffdst = r9; + Reg64 workspace0 = rdx; + Reg64 workspace1 = rsi; + Reg64 imm_addr64 = rbx; + + Zmm znalphabeta = zmm0; + Xmm xnalphabeta = xmm0; + + Reg64 param = abi_param1; + Reg64 t = rsp; + Reg64 hw = r10; + + int xws1_prev = 1; + int xdiffdst_prev = 2; + int zws1 = 1; + + int zsrc = 1; + int zdiffdst = 5; + int zdiffsrc = 6; + + int xws1_next = 1; + int xdiffdst_next = 3; + + int za = 1; + int zb = 2; + int zd = 3; + int ze = 4; + int zws0 = 2; + + float nalphabeta; + + int use_h_parallelism; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32) + + void (*ker)(jit_args_bwd_t *); + void operator()(jit_args_bwd_t *arg) { ker(arg); } + + enum { + prf0_offt = 1*BWD_RBC, + prf2_offt = 8*BWD_RBC + }; + + inline void compute_loop(int loop_size_param, int prefetchL1, + int prefetchL2) + { + // loop_size - param for IRB_LOOP macro + int loop_size = loop_size_param; + + auto xreg = [=](int irb, int i) { + return Xmm(irb*6 + i); + }; + + auto zreg = [=](int irb, int i) { + return Zmm(irb*6 + i); + }; + +// ---- prefetching ------------------------------------------- + if (!is_first && !is_single) { + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt + - 2 * HW) * vlen])); + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt + - HW) * vlen])); + } + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt)*vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt)*vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt)*vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt)*vlen])); + + if (!is_last && !is_single) { + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt + + 2 * HW) * vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[workspace1 + (irb + prf2_offt + + 2 * HW) * vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt + + HW) * vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[diffdst + (irb + prf2_offt + + HW) * vlen])); + } + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace0 + (irb + prf0_offt)*vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[workspace0 + (irb + prf2_offt)*vlen])); +// ----------------------------------------------------------- + + if (loop_size_param == 0) + return; + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xws1_prev), ptr[workspace1 + (irb + - 2 * HW) * vlen + SRC_PREV_OFFSET])); + IRB_LOOP(vmovups(xreg(irb, xdiffdst_prev), ptr[diffdst + (irb + - HW) * vlen + SRC_PREV_OFFSET])); + IRB_LOOP(vmulps(xreg(irb, xdiffdst_prev), xreg(irb, xdiffdst_prev), + xreg(irb, xws1_prev))); + } + + IRB_LOOP(vmovups(zreg(irb, zws1), + EVEX_compress_addr(workspace1, irb*vlen))); + IRB_LOOP(vmovups(zreg(irb, zdiffdst), + EVEX_compress_addr(diffdst, irb*vlen))); + IRB_LOOP(vmulps(zreg(irb, zdiffsrc), zreg(irb, zdiffdst), + zreg(irb, zws1))); + + if (!is_last && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xws1_next), ptr[workspace1 + (irb + + 2 * HW) * vlen])); + IRB_LOOP(vmovups(xreg(irb, xdiffdst_next), ptr[diffdst + (irb + + HW) * vlen])); + IRB_LOOP(vmulps(xreg(irb, xdiffdst_next), xreg(irb, xdiffdst_next), + xreg(irb, xws1_next))); + } + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK], + xreg(irb, xdiffdst_prev))); + } + IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE), + zreg(irb, zdiffsrc))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xreg(irb, xdiffdst_next))); + } + + IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 2*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 1*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 1*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 2*sizeof(float)))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, za))); + assert(zsrc == za); + IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src, irb*vlen))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, zb))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, zd))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, ze))); + IRB_LOOP(vmulps(zreg(irb, zsrc), zreg(irb, zsrc), znalphabeta)); + + IRB_LOOP(vmovups(zreg(irb, zws0), + EVEX_compress_addr(workspace0, irb*vlen))); + IRB_LOOP(vdivps(zreg(irb, zdiffdst), zreg(irb, zdiffdst), + zreg(irb, zws0))); + IRB_LOOP(vfmadd213ps(zreg(irb, zdiffsrc), zreg(irb, zsrc), + zreg(irb, zdiffdst))); + + Label unaligned_store, end_store; + test(diffsrc, vlen - 1); + jnz(unaligned_store, T_NEAR); + IRB_LOOP(uni_vmovntps(EVEX_compress_addr(diffsrc, irb*vlen), + zreg(irb, zdiffsrc))); + jmp(end_store, T_NEAR); + L(unaligned_store); { + IRB_LOOP(uni_vmovups(EVEX_compress_addr(diffsrc, irb*vlen), + zreg(irb, zdiffsrc))); + } + L(end_store); + } + + jit_avx512_common_lrn_kernel_f32( + const struct nChw16c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + , nalphabeta(-2*A*B) + , use_h_parallelism(use_h_parallel) + { + this->preamble(); + + mov(src, ptr[param + 0]); + mov(diffdst, ptr[param + 8]); + mov(workspace0, ptr[param + 16]); + mov(workspace1, ptr[param + 24]); + mov(diffsrc, ptr[param + 32]); + + W = J.W; + HW = J.H*J.W; + int LSB = this->use_h_parallelism ? W : HW; + + sub(t, BWD_RBC*BUFFER_BLOCK); + mov(imm_addr64, float2int(this->nalphabeta)); + movq(xnalphabeta, imm_addr64); + vbroadcastss(znalphabeta, xnalphabeta); + + is_first = J.version == -1 || J.version == -2; + is_last = J.version == +1 || J.version == +2; + is_single = J.version == 3; + + if (is_first || is_single) { + vxorps(xmm1, xmm1, xmm1); + for(int irb = 0; irb < BWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK], xmm1); + } + } + if (is_last || is_single) { + vxorps(xmm1, xmm1, xmm1); + for(int irb = 0; irb < BWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], xmm1); + } + } + + int LSREST = LSB % BWD_RBC; + int LS = LSB - LSREST; + + Label lrn_loop; + + if (LS > 0) { + mov(hw, LS); + + L(lrn_loop); + { + compute_loop(BWD_RBC, 1, 1); + + add(src, BWD_RBC*vlen); + add(diffsrc, BWD_RBC*vlen); + add(diffdst, BWD_RBC*vlen); + add(workspace0, BWD_RBC*vlen); + add(workspace1, BWD_RBC*vlen); + + for(int irb = 0; irb < BWD_RBC; irb++) + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + } + + compute_loop(LSREST, 1, this->use_h_parallelism ? 0 : 1); + + add(t, BWD_RBC*BUFFER_BLOCK); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); + } + +}; + +status_t jit_avx512_common_lrn_bwd_t::pd_t::init() { + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(avx512_common) + && !is_fwd() + && utils::everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % vsize == 0 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + dims_t ws_dims = { MB(), C(), H(), 2*W() }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32, + format_tag::nChw16c); + + if (!compare_ws(hint_fwd_pd_)) return unimplemented; + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && desc()->lrn_beta == 0.75 + && data_d.matches_tag(format_tag::nChw16c); + + return args_ok_across ? success : unimplemented; +} + +jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr) + , ker_last_(nullptr) { + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + const float alpha = pd()->desc()->lrn_alpha / ls; + const float beta = pd()->desc()->lrn_beta; + + use_h_parallelism = H > 28 ? 1 : 0; + + if (C / vsize == 1) { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), + alpha, beta, use_h_parallelism); + } else { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), + alpha, beta, use_h_parallelism); + ker_first_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, -1), alpha, beta, use_h_parallelism); + ker_last_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, +1), alpha, beta, use_h_parallelism); + } +} + +jit_avx512_common_lrn_bwd_t::~jit_avx512_common_lrn_bwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +void jit_avx512_common_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + const int C16 = C / vsize; + const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16; + + balance211(work_amount, nthr, ithr, start, end); + if (use_h_parallelism) { + int n{0}, c16{0}, h{0}; + nd_iterator_init(start, n, N, h, H, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize + + h*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize + + h*2*W*vsize; + auto ws_offset1 = ws_offset0 + W*vsize; + + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + args.diff_src = &diff_src[offset]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + nd_iterator_step(n, N, h, H, c16, C16); + } + } else { + int n{0}, c16{0}; + nd_iterator_init(start, n, N, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize; + auto ws_offset1 = ws_offset0 + H*W*vsize; + + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + args.diff_src = &diff_src[offset]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + + nd_iterator_step(n, N, c16, C16); + } + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp new file mode 100644 index 0000000000..37fbb9b3e5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp @@ -0,0 +1,96 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_LRN_HPP +#define CPU_JIT_AVX512_COMMON_LRN_HPP + +#include "c_types_map.hpp" + +#include "cpu_isa_traits.hpp" +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_common_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_lrn_fwd_t); + + status_t init(); + }; + + jit_avx512_common_lrn_fwd_t(const pd_t *apd); + ~jit_avx512_common_lrn_fwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int use_h_parallelism; + struct jit_avx512_common_lrn_kernel_f32; + jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +struct jit_avx512_common_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_lrn_bwd_t); + + status_t init(); + }; + + jit_avx512_common_lrn_bwd_t(const pd_t *apd); + ~jit_avx512_common_lrn_bwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int use_h_parallelism; + struct jit_avx512_common_lrn_kernel_f32; + jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp new file mode 100644 index 0000000000..c58d3fa0a6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp @@ -0,0 +1,1103 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_fp32_wino_conv_2x3.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_kind; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +/// SRC TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + + struct call_params_t { + const void *src; + const void *wino_src; + const void *v_y_masks; + const void *v_x_masks; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp) { + generate(); + ker_ = + reinterpret_cast(const_cast(getCode())); + } + + void generate(); + + Zmm vreg_inp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Zmm vreg_tmp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(15 - i); + } + + Zmm vreg_out(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert (id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_ic_block = r8; + +}; + +void jit_avx512_core_fp32_wino_conv_2x3_src_trans_t::generate() { + Label ic_block_label; + + const int load_block = 16; + int out_offset = 0, inp_offset = 0; + preamble(); + +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_aux_ptr_src, src); + READ_PARAM(reg_aux_ptr_dst, wino_src); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); +#undef READ_PARAM + + for (int i = 0; i < jcp.alpha; i++) { + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); + } + mov(reg_ic_block, jcp.ic / load_block); + L(ic_block_label); + { + for (int y = 0; y < jcp.alpha; y++) { + kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]); + for (int x = 0; x < jcp.alpha; x++) { + Zmm zmm = vreg_inp(y * jcp.alpha + x); + + vxorps(zmm, zmm, zmm); + kandw(r_mask, y_mask, x_mask(x)); + inp_offset = sizeof(float) + * ((-jcp.t_pad + y) * jcp.iw * load_block + + (-jcp.l_pad + x) * load_block); + vmovups(zmm | r_mask, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } + } + for (int y = 0; y < jcp.alpha; y++) { + vsubps(vreg_tmp(y * jcp.alpha + 0), vreg_inp(y * jcp.alpha + 0), + vreg_inp(y * jcp.alpha + 2)); + vaddps(vreg_tmp(y * jcp.alpha + 1), vreg_inp(y * jcp.alpha + 1), + vreg_inp(y * jcp.alpha + 2)); + vsubps(vreg_tmp(y * jcp.alpha + 2), vreg_inp(y * jcp.alpha + 2), + vreg_inp(y * jcp.alpha + 1)); + vsubps(vreg_tmp(y * jcp.alpha + 3), vreg_inp(y * jcp.alpha + 1), + vreg_inp(y * jcp.alpha + 3)); + } + for (int x = 0; x < jcp.alpha; x++) { + vsubps(vreg_out(x + 0 * jcp.alpha), vreg_tmp(x + jcp.alpha * 0), + vreg_tmp(x + jcp.alpha * 2)); + vaddps(vreg_out(x + 1 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1), + vreg_tmp(x + jcp.alpha * 2)); + vsubps(vreg_out(x + 2 * jcp.alpha), vreg_tmp(x + jcp.alpha * 2), + vreg_tmp(x + jcp.alpha * 1)); + vsubps(vreg_out(x + 3 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1), + vreg_tmp(x + jcp.alpha * 3)); + } + + for (int i = 0; i < 16; i++) { + out_offset = sizeof(float) * (jcp.inp_stride * i); + vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), + vreg_out(i)); + } + + add(reg_aux_ptr_src, sizeof(float) * jcp.ih * jcp.iw * load_block); + add(reg_aux_ptr_dst, sizeof(float) * load_block); + } + dec(reg_ic_block); + cmp(reg_ic_block, 0); + jg(ic_block_label, T_NEAR); + postamble(); +} + +/// DST TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *wino_dst; + const void *dst; + const void *v_y_masks; + const void *v_x_masks; + + const void *bias; + const void *scales; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) { + generate(); + ker_ = reinterpret_cast( + const_cast(getCode())); + } + + void generate(); + bool maybe_relu(int position); + + Zmm vreg_inp(int i) { // 16 + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Zmm vreg_stg(int id) { // 8 + const int id_reg_stg = jcp.alpha * jcp.alpha + id; + assert(id_reg_stg < jcp.alpha * jcp.alpha + 8); + return Zmm(31 - id_reg_stg); + } + + Zmm vreg_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id_reg_out < jcp.alpha * jcp.alpha + 12); + return Zmm(31 - id_reg_out); + } + + Zmm vreg_tmp(int id) { // 2 + const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; + assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14); + return Zmm(31 - id_reg_tmp); + } + + Zmm vreg_zero = Zmm(0); + Zmm vreg_prev_dst = Zmm(0); + Zmm vreg_bias = Zmm(2); + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert (id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_oc_block = r8; + + Reg64 reg_ptr_bias = rbx; + Reg64 reg_ptr_scales = abi_not_param1; + Reg64 reg_ptr_sum_scale = rdx; +}; + +bool jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* relu before sum */ + return false + || p.contain(eltwise, 0); + } else if (position == 1) { + /* relu after sum */ + const int sum_idx = p.contain(sum, 0) + ? 0 : (p.contain(sum, 1) ? 1 : -1); + if (sum_idx == -1) + return false; + + return false + || p.contain(eltwise, sum_idx + 1); + } + + return false; +} + +void jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::generate() { + Label oc_block_label; + + const int load_block = 16; + + auto loop_body = [=]() { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? &p.entry_[sum_idx].sum.scale + : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + for (int i = 0; i < 16; i++) { + int internal_offset = sizeof(float) * jcp.out_stride * i; + vmovups(vreg_inp(i), + EVEX_compress_addr(reg_aux_ptr_src, internal_offset)); + } + for (int y = 0; y < jcp.alpha; y++) { + vaddps(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1)); + vaddps(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2)); + + vsubps(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2)); + vsubps(vreg_stg(y * 2+1), vreg_tmp(1), vreg_inp(y * 4 + 3)); + } + for (int x = 0; x < jcp.m; x++) { + vaddps(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2 * 1)); + vaddps(vreg_out(x), vreg_tmp(0), vreg_stg(x+2 * 2)); + + vsubps(vreg_tmp(1), vreg_stg(x+2 * 1), vreg_stg(x+2 * 2)); + vsubps(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2 * 3)); + } + + + if (jcp.with_bias) { + auto bias_addr = ptr [ reg_ptr_bias ]; + vmovups(vreg_bias, bias_addr); + } + for (int y = 0; y < jcp.m; y++) { + kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]); + for (int x = 0; x < jcp.m; x++) { + kandw(r_mask, y_mask, x_mask(x)); + + int i = y * jcp.m + x; + int offset = sizeof(float) * + (y * jcp.ow * jcp.oc_block + x * jcp.oc_block); + Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset); + + Zmm zmm = vreg_out(i); + if (jcp.with_bias) + vaddps(zmm, zmm, vreg_bias); + vmulps(zmm, zmm, ptr [reg_ptr_scales]); + + if (maybe_relu(0)) { + vxorps(vreg_zero, vreg_zero, vreg_zero); + vmaxps(zmm, vreg_zero, zmm); + } + if (p_sum_scale) { // post_op: sum + vxorps(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst); + vmovups(vreg_prev_dst | r_mask, addr); + if (*p_sum_scale == 1.f) + vaddps(zmm, vreg_prev_dst); + else + vfmadd231ps(zmm, vreg_prev_dst, + zword_b[reg_ptr_sum_scale]); + } + if (maybe_relu(1)) { + vxorps(vreg_zero, vreg_zero, vreg_zero); + vmaxps(zmm, vreg_zero, zmm); + } + + vmovups(addr, zmm | r_mask); + } + } + }; + + preamble(); + +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_aux_ptr_src, wino_dst); + READ_PARAM(reg_aux_ptr_dst, dst); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); + READ_PARAM(reg_ptr_bias, bias); + READ_PARAM(reg_ptr_scales, scales); +#undef READ_PARAM + + for (int i = 0; i < jcp.alpha * jcp.alpha; i++) + vxorps(vreg_inp(i), vreg_inp(i), vreg_inp(i)); + + for (int i = 0; i < jcp.alpha; i++) + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); + + int oc_blocks = 1; + oc_blocks = jcp.oc / load_block; + mov(reg_oc_block, oc_blocks); + L(oc_block_label); + { + loop_body(); + add(reg_aux_ptr_src, sizeof(float) * load_block); + add(reg_aux_ptr_dst, sizeof(float) * jcp.oh * jcp.ow * load_block); + + add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + add(reg_ptr_bias, jcp.typesize_bia * load_block); + } + dec(reg_oc_block); + cmp(reg_oc_block, 0); + jg(oc_block_label, T_NEAR); + + sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + sub(reg_ptr_bias, oc_blocks * jcp.typesize_bia * load_block); + + postamble(); + +} + +/// GEMM kernel //////////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t) + jit_conv_conf_2x3_wino_t jcp; + + struct call_params_t { + const void *src; + const void *dst; + const void *wei; + const void *dst_b; + }; + void (*ker_)(const call_params_t *); + + void generate(); + static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp, + const primitive_attr_t &attr); + + jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp) { + generate(); + ker_ = reinterpret_cast( + const_cast(getCode())); + } + + static status_t init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, + memory_desc_t& expect_wei_md); + + Zmm vreg_out(int n, int m) { + const int id_reg_out = n * jcp.m_block + m; + assert(id_reg_out < jcp.n2_block * jcp.m_block); + return Zmm(31 - id_reg_out); + } + Zmm vreg_wei(int i) { + assert (31 - jcp.n2_block * jcp.m_block - i > 1); + return Zmm(31 - jcp.n2_block * jcp.m_block - i); + } + + Zmm vreg_src = Zmm(0); + Zmm vreg_one = Zmm(1); + Zmm vreg_tmp = Zmm(2); + + Reg64 reg_ptr_src = r15; + + Reg64 reg_aux_dst = r12; + Reg64 reg_aux_dst2 = r11; + Reg64 reg_aux_wei = r10; + Reg64 reg_aux_wei2 = r9; + Reg64 reg_aux_src = r8; + Reg64 reg_aux_src2 = rax; + + Reg64 reg_mb = rbx; + Reg64 reg_nnb = rdx; + Reg64 reg_K = rsi; + +}; + +bool jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::post_ops_ok( + jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_relu(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_relu(1)) || + (p.contain(sum, 1) && is_relu(0)); + case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2); + default: return false; + } + + return false; +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::generate() { + Label nnb_loop_label, K_loop_label, mb_loop_label; + + preamble(); +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_aux_dst, dst); + READ_PARAM(reg_aux_wei, wei); +#undef READ_PARAM + + if (!jcp.small_mb) { + mov(reg_nnb, jcp.n_chunks); + L(nnb_loop_label); + } + mov(reg_aux_dst2, reg_aux_dst); + mov(reg_aux_src, reg_ptr_src); + mov(reg_mb, jcp.M / jcp.m_block); + L(mb_loop_label); + { + int nb2 = 0; + for (nb2 = 0; nb2 < jcp.n2_block; nb2++) { + for (int m = 0; m < jcp.m_block; m++) { + vxorps(vreg_out(nb2, m), vreg_out(nb2, m), vreg_out(nb2, m)); + } + } + mov(reg_aux_src2, reg_aux_src); + mov(reg_aux_wei2, reg_aux_wei); + + mov(reg_K, jcp.k_chunks); + L(K_loop_label); { + int wei_offset = 0; + for (int _i = 0; _i < jcp.k2_block; _i++) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + if (jcp.small_mb) { + int wei_offset = sizeof(float) + * ((nb2 * jcp.nb_ic * jcp.ic_block + * jcp.oc_block) + + _i * jcp.oc_block); + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, wei_offset)); + } else { + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, + sizeof(float) * wei_offset)); + wei_offset += jcp.oc_block; + } + } + for (int m = 0; m < jcp.m_block; m++) { + int inp_offset = sizeof(float) * (m * jcp.K + _i); + if (jcp.n2_block > 1) { + vbroadcastss(vreg_src, + EVEX_compress_addr(reg_aux_src2, inp_offset)); + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) + vfmadd231ps(vreg_out(nb2, m), vreg_wei(nb2), + vreg_src); + } else { + vfmadd231ps(vreg_out(0, m), vreg_wei(0), + EVEX_compress_addr(reg_aux_src2, inp_offset, true)); + } + } + } + add(reg_aux_src2, sizeof(float) * jcp.ic_block); + if (jcp.small_mb) + add(reg_aux_wei2, sizeof(float) * jcp.oc_block * jcp.ic_block); + else + add(reg_aux_wei2, + sizeof(float) * jcp.k2_block * jcp.n2_block + * jcp.oc_block); + } + dec(reg_K); + cmp(reg_K, 0); + jg(K_loop_label, T_NEAR); + + for (int m = 0; m < jcp.m_block; m++) { + int nb2 = 0; + for (nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int offset = sizeof(float) * + (m * jcp.N + nb2 * jcp.oc_block); + vmovups(EVEX_compress_addr(reg_aux_dst2,offset), + vreg_out(nb2, m)); + } + } + add(reg_aux_src, sizeof(float) * jcp.m_block * jcp.K); + add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N); + } + dec(reg_mb); + cmp(reg_mb, 0); + jg(mb_loop_label, T_NEAR); + + if (!jcp.small_mb) { + add(reg_aux_dst, sizeof(float) * jcp.n2_block * jcp.oc_block); + add(reg_aux_wei, + sizeof(float) * jcp.k_chunks * jcp.ic_block * jcp.n2_block + * jcp.oc_block); + + dec(reg_nnb); + cmp(reg_nnb, 0); + jg(nnb_loop_label, T_NEAR); + } + postamble(); +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) { + return jcp.mb >= 4; +} +} + +status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &wei_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, memory_desc_t &expect_wei_md) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper wei_d(&wei_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = wei_d.ndims() == src_d.ndims() + 1; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ngroups = with_groups ? wei_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = wei_d.dims()[with_groups + 2]; + jcp.kw = wei_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.m = 2; + jcp.r = 3; + jcp.alpha = jcp.m + jcp.r - 1; + int simdw = 16; + + format_tag_t dat_tag = format_tag::nChw16c; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simdw); + jcp.ic = rnd_up(jcp.ic, simdw); + } + + jcp.ver = ver_avx512_core; + if (!(mayiuse(avx512_core))) + return status::unimplemented; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (src_d.data_type() != data_type::f32) + return status::unimplemented; + if (wei_d.data_type() != data_type::f32) + return status::unimplemented; + if (dst_d.data_type() != data_type::f32) + return status::unimplemented; + + jcp.ic_block = simdw; + jcp.oc_block = simdw; + + bool ok = true && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1 + && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 + && jcp.stride_h == 1 && jcp.stride_w == 1 && jcp.dilate_h == 0 + && jcp.dilate_w == 0 && jcp.t_pad == jcp.b_pad + && jcp.l_pad == jcp.r_pad && jcp.t_pad < 2 && jcp.t_pad >= 0 + && jcp.l_pad < 2 && jcp.l_pad >= 0; + if (!ok) + return status::unimplemented; + + const int L2_cap = get_cache_size(2, true) / sizeof(float); + const int L3_capacity = get_cache_size(3, false) / sizeof(float); + int a = jcp.alpha; + int aa = a * a; + int mb = jcp.mb; + int ic = jcp.ic; + int oc = jcp.oc; + int ih = jcp.ih; + int iw = jcp.iw; + auto wei_sz = (float)aa * ic * oc; + auto inp_sz = (float)mb * ih * iw * ic; + auto sp_sz = (float)mb * ih * iw; + + /* Heuristics here. Numbers '28','196' is an observation from data. */ + if (wei_sz / inp_sz > 5) + jcp.small_mb = true; + else + jcp.small_mb = false; + + if (mb > nstl::min(jcp.nthr, 28) + || (!jcp.small_mb + && (wei_sz >= 0.9f * L2_cap + || inp_sz > L2_cap * jcp.nthr + L3_capacity)) + || (jcp.small_mb && sp_sz > 196)) + return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_bia + = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; + + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + const int skx_free_regs = 30; + + auto find_m_n2_blocks = [=](int xb, int yb, int &M, int &m_block, + int &n2_block, float ®_eff) { + M = (xb * yb) / jcp.alpha; + int max_m_block = m_block = nstl::min(M, skx_free_regs); + int max_n2_block = n2_block = nstl::min(jcp.nb_oc, skx_free_regs); + reg_eff = 0; + for (int im = max_m_block; im > 0; im--) { + for (int in2 = max_n2_block; in2 > 0; in2--) { + int used_regs = in2 * im + in2; + float cur_reg_eff = ((float)in2 * im) / (im + in2) / 2.5f; + if (M % im || jcp.nb_oc % in2 || used_regs > skx_free_regs + || cur_reg_eff <= reg_eff) + continue; + reg_eff = cur_reg_eff; + m_block = im; + n2_block = in2; + } + } + }; + + int oh = jcp.oh; + int ow = jcp.ow; + int nb_oc = jcp.nb_oc; + int Z = ic + oc; + int Y = ic * oc; + const int L3_cap_per_core = get_cache_size(3, true) / sizeof(float); + + /* Selecting xb and yb blocking */ + int min_yb = jcp.alpha; + int min_xb = jcp.alpha; + int max_yb = nstl::max(min_yb, rnd_up(ih, 2)); + int max_xb = nstl::max(min_xb, rnd_up(iw, 2)); + float best_eff = 0.f; + for (int ix = max_xb; ix >= min_xb; ix -= 2) { + if (rnd_up(ow, ix) < iw - 2) + continue; + for (int iy = max_yb; iy >= min_yb; iy -= 2) { + if (rnd_up(oh, iy) < ih - 2) + continue; + int ex_y = rnd_up(oh, iy); + int ex_x = rnd_up(ow, ix); + float work_eff = (float)(ih * iw) / (ex_y * ex_x); + + int M, m_block, n2_b; + float reg_eff, thr_eff, par_eff, mem_eff, req_mem; + + find_m_n2_blocks(ix, iy, M, m_block, n2_b, reg_eff); + + /* outer parallelization */ + int nblocks = mb * div_up(ih, iy) * div_up(iw, ix); + thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr); + + mem_eff = 1.f; + req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y; + if (req_mem > L2_cap / 2) { + if (req_mem > ((L2_cap + L3_cap_per_core) * 4) / 7) + mem_eff /= (n2_b + 1) / 2.f; + else + mem_eff /= (n2_b + 1) / 3.f; + } + + float outer_eff = thr_eff + work_eff + reg_eff + mem_eff; + + /* inner parallelization */ + int bsz = iy * ix / a; + int gemmw = aa * (nb_oc / n2_b); + int bsz_r = rnd_up(bsz, jcp.nthr); + int gemmw_r = rnd_up(gemmw, jcp.nthr); + thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y); + + req_mem = (float)ix * iy * (ic + simdw * n2_b) + simdw * n2_b * ic; + mem_eff = nstl::min(1.f, L2_cap / req_mem); + int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr)); + int oc_per_thr = + nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr)); + req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z; + if (req_mem > L2_cap) + mem_eff = 0.1f; + par_eff = 1 / (2.f * nblocks); + + float inner_eff = thr_eff + work_eff + mem_eff + par_eff; + + float eff = jcp.small_mb ? inner_eff : outer_eff; + if (eff > best_eff) { + best_eff = eff; + jcp.yb = iy; + jcp.xb = ix; + jcp.M = M; + jcp.m_block = m_block; + jcp.n2_block = n2_b; + } + } + } + + assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); + + jcp.inp_stride = jcp.M * jcp.ic; + jcp.out_stride = jcp.M * jcp.oc; + jcp.wei_stride = jcp.ic * jcp.oc; + jcp.bia_stride = jcp.oc; + + jcp.N = jcp.oc; + jcp.K = jcp.ic; + + jcp.n_block = jcp.oc_block; + jcp.k_block = jcp.ic_block; + + assert(jcp.M % jcp.m_block == 0); + assert(jcp.nb_oc % jcp.n2_block == 0); + + jcp.n_chunks = jcp.nb_oc / jcp.n2_block; + jcp.k2_block = jcp.ic_block; + jcp.k_chunks = jcp.K / jcp.k2_block; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::f32; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format + = jcp.small_mb ? mkldnn_wino_wei_aaOio : mkldnn_wino_wei_aaOBiOo; + wd.r = jcp.r; + wd.alpha = jcp.alpha; + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.ic_block; + wd.oc_block = jcp.oc_block; + wd.oc2_block = jcp.n2_block; + wd.ic2_block = 1; + wd.adj_scale = 1.f; + size_t max_size = sizeof(float) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; + wd.size = max_size; + + return status::success; +} +//////////////////////////////////////////////////////////////////////////////// + +status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_t + ::pd_t::jit_conf(memory_desc_t& expect_wei_md) { + return jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf( + jcp_, *this->desc(), this->src_md_, this->weights_md_, + this->dst_md_,this->bias_md_, *this->attr(), expect_wei_md); +} + +jit_avx512_core_fp32_wino_conv_2x3_fwd_t:: + jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) +{ + kernel_ = new jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t( + pd()->jcp_, *pd()->attr()); + src_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_src_trans_t( + pd()->jcp_, *pd()->attr()); + dst_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t( + pd()->jcp_, *pd()->attr()); +} + +jit_avx512_core_fp32_wino_conv_2x3_fwd_t + ::~jit_avx512_core_fp32_wino_conv_2x3_fwd_t() { + delete kernel_; + delete src_trans_; + delete dst_trans_; +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_mbN( + const float *src, const float *wei, const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const +{ + const auto &jcp = kernel_->jcp; + const auto &oscales = pd()->attr()->output_scales_; + + const size_t wino_size_offset = + (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) + (pd()->jcp_.xb); + const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16; + const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + utils::array_copy(padded_bias, bia, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bia = padded_bias; + } + + auto ptr_V = scratchpad.get(key_wino_V); + auto ptr_M = scratchpad.get(key_wino_M); + + parallel_nd(jcp.mb, div_up(jcp.oh,jcp.yb), div_up(jcp.ow, jcp.xb), + [&](int mb, int tile_y_b, int tile_x_b) { + int tile_y = tile_y_b * jcp.yb; + int tile_x = tile_x_b * jcp.xb; + + int ithr = mkldnn_get_thread_num(); + auto wino_src = ptr_V + size_wino_src * ithr; + auto wino_dst = ptr_M + size_wino_dst * ithr; + + auto src_trans_p = + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t + ::call_params_t(); + auto dst_trans_p = + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t + ::call_params_t(); + auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t :: + call_params_t(); + + /* transformation of input tensor to winograd domain */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; + x_in_block += 2) { + + unsigned short v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min(jcp.alpha, + nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min(jcp.alpha, + nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; + v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; + } + auto local_s = src + + mb * jcp.nb_ic * jcp.ih * jcp.iw + * jcp.ic_block + + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + } + } + /* gemms */ + for (int tile_ij = 0; tile_ij < 16; tile_ij++) { + int offset = (tile_ij + ithr) % 16; + gemm_p.src = wino_src + jcp.inp_stride * offset; + gemm_p.dst = wino_dst + jcp.out_stride * offset; + gemm_p.wei = wei + jcp.wei_stride * offset; + + kernel_->ker_(&gemm_p); + } + + /* transformation from winograd domain to output tensor */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; + x_in_block += 2) { + unsigned short v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; + v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; + } + auto local_d = dst + + mb * jcp.nb_oc * jcp.oh * jcp.ow + * jcp.oc_block + + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales.scales_; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + } + } + }); +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_small_mb( + const float *src, const float *wei, const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const +{ + const auto &jcp = kernel_->jcp; + const auto &oscales = pd()->attr()->output_scales_; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + utils::array_copy(padded_bias, bia, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bia = padded_bias; + } + + auto ptr_V = scratchpad.get(key_wino_V); + auto ptr_M = scratchpad.get(key_wino_M); + + for (int mb = 0; mb < jcp.mb; mb++) { + for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) { + for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { + /* transformation of input tensor to winograd domain */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), + [&](int y_in_block_b, int x_in_block_b) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto src_trans_p = jit_avx512_core_fp32_wino_conv_2x3_src_trans_t :: + call_params_t(); + + unsigned short v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min( + jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min( + jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; + v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; + } + auto local_s = src + + mb * jcp.nb_ic * jcp.ih * jcp.iw * jcp.ic_block + + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; + auto local_w = ptr_V + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + }); + + /* gemms */ + parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { + auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t :: + call_params_t(); + + gemm_p.src = ptr_V + jcp.inp_stride * tile_ij; + gemm_p.dst = ptr_M + jcp.out_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + gemm_p.wei = wei + jcp.wei_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block * jcp.K; + + kernel_->ker_(&gemm_p); + }); + + /* transformation from winograd domain to output tensor */ + + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), + [&](int y_in_block_b, int x_in_block_b) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto dst_trans_p = jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t :: + call_params_t(); + + unsigned short v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; + v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; + } + auto local_d = dst + + mb * jcp.nb_oc * jcp.oh * jcp.ow * jcp.oc_block + + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; + auto local_w = ptr_M + m * jcp.oc; + + auto scales = oscales.scales_; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + }); + }}} +} + +} // namespace cpu +} // namespace impl +} // namespace mkldnn diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp new file mode 100644 index 0000000000..7e38b07f5a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp @@ -0,0 +1,144 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP +#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t; +struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t; +struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t; + +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_fp32_wino_2x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_2x3_fwd_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::forward_inference + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + memory_desc_t expect_wei_md = *weights_md(); + status_t jit_conf_result = jit_conf(expect_wei_md); + if (jit_conf_result != status::success) return jit_conf_result; + set_default_alg_kind(alg_kind::convolution_winograd); + + if (weights_md_.format_kind == format_kind::any) + weights_md_ = expect_wei_md; + if (weights_md_ != expect_wei_md) + return status::unimplemented; + + init_scratchpad(); + + return status::success; + } + + jit_conv_conf_2x3_wino_t jcp_; + + protected: + status_t jit_conf(memory_desc_t& expect_wei_md); + + void init_scratchpad() { + using namespace memory_tracking::names; + + auto scratchpad = scratchpad_registry().registrar(); + + int wino_size_offset = (jcp_.yb / 2) * (jcp_.xb / 2) + jcp_.xb; + + size_t V_sz = (size_t)jcp_.ic * 16 * wino_size_offset * jcp_.nthr; + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_4K); + + size_t M_sz = (size_t)jcp_.oc * 16 * wino_size_offset * jcp_.nthr; + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_4K); + + if (wants_padded_bias()) { + assert(jcp_.ngroups == 1); + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp_.oc); + } + } + + bool set_default_formats() { + using namespace format_tag; + return set_default_formats_common(nChw16c, any, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd); + ~jit_avx512_core_fp32_wino_conv_2x3_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto wei = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bia = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + + if (pd()->jcp_.small_mb) + execute_forward_small_mb(src, wei, bia, dst, this->scratchpad(ctx)); + else + execute_forward_mbN(src, wei, bia, dst, this->scratchpad(ctx)); + + return status::success; + } + +private: + void execute_forward_small_mb(const float *src, const float *wei, + const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_mbN(const float *src, const float *wei, + const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t *kernel_; + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t *src_trans_; + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t *dst_trans_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp new file mode 100644 index 0000000000..96325e3ade --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp @@ -0,0 +1,1020 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifdef __INTEL_COMPILER +#include +#endif + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_fp32_wino_conv_4x3.hpp" + +#ifndef _MSC_VER +#define pragma_unroll _Pragma("unroll") +#else +#define pragma_unroll +#endif + + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t +::weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) const +{ + float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f, + 1.13777777777778f, 0.430252100840336f, 0.179271708683473f}; + const int kh = 3; + const int kw = 3; + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + float T[alpha][3][simd_w]; + auto p = jit_wino_transform_call_s(); + + p.src = wp; + p.dst = twp; + p.G = G; + p.M = F; + p.Mw = Fw; + p.T = T; + + kernel_->weights_transform_data_ker(&p); +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t::output_transform_data +(int image, const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const { + + float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + float T[tile_size][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + p.src = toutp; + p.dst = pout_b; + p.G = G; + p.M = O; + p.Mw = Ow; + p.T = T; + p.bias = bias; + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + + kernel_->output_transform_data_ker(&p); + + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t +::output_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, + float *toutp, float *outp, float *bias) const { + + float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + float T[tile_size][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + p.src = toutp; + p.dst = outp; + p.G = G; + p.M = O; + p.Mw = Ow; + p.T = T; + p.bias = bias; + + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? jcp.oh : jcp.ih; + + int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; + + for (int nb_tile_block_ur = 0; + nb_tile_block_ur < jcp.nb_tile_block_ur; + nb_tile_block_ur++) { + + for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; + tile_block_ur++) { + int img = tile_index / (jcp.jtiles * jcp.itiles); + int ti = tile_index % jcp.itiles; + int tj = (tile_index / jcp.itiles) % jcp.jtiles; + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block) + * outh * outw * jcp.dimM_simd_block; + + kernel_->output_transform_data_ker(&p); + + tile_index++; + } + } +} + + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t + ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const +{ + float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + + p.src = inp; + p.dst = tinp; + p.G = G; + p.M = I; + p.Mw = Iw; + p.T = T; + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + + kernel_->input_transform_data_ker(&p); + + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t + ::input_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const +{ + float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + + array_offset_calculator input(inp, + jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w); + array_offset_calculator output(tinp, + alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + auto p = jit_wino_transform_call_s(); + + p.dst = tinp; + p.G = G; + p.M = I; + p.Mw = Iw; + p.T = T; + + + int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; + + for (int nb_tile_block_ur = 0; + nb_tile_block_ur < jcp.nb_tile_block_ur; + nb_tile_block_ur++) { + + for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; + tile_block_ur++) { + + int img = tile_index / (jcp.jtiles * jcp.itiles); + int ti = tile_index % jcp.itiles; + int tj = (tile_index / jcp.itiles) % jcp.jtiles; + float *pinp_b = &(input(img, 0, 0, 0, 0)); + + p.src = pinp_b; + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tj = tj; + p.ti = ti; + + kernel_->input_transform_data_ker(&p); + + tile_index++; + } + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t::_execute_data_W_S_G_D( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + /* Notation: + FWD: dimM:oc, dimN:ntiles, dimK:ic, + BWD: dimM:ic, dimN:ntiles, dimK:oc, + FWD/BWD: V: src/diff_dst transform, U:weight transform, + M:dst/diff_src transform */ + array_offset_calculator input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, + jcp.dimK_reg_block); + array_offset_calculator output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, + jcp.dimM_simd_block); + array_offset_calculator weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator bias(bias_ptr, + jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block); + + array_offset_calculator M(is_fwd + ? scratchpad.template get(key_wino_M) + : scratchpad.template get(key_wino_V), + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + + auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference) + ? wei_ptr + : scratchpad.template get(key_wino_U); + + array_offset_calculator U(wino_wei, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + array_offset_calculator V(is_fwd + ? scratchpad.template get(key_wino_V) + : scratchpad.template get(key_wino_M), + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + { + + parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, + [&](int img, int K_blk1, int K_blk2) { + input_transform_data(img, jcp, + &(input(img, K_blk1 * jcp.dimK_block + K_blk2, + 0, 0, 0)), + &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0))); + }); + + if (jcp.prop_kind != prop_kind::forward_inference) { + parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), + (jcp.ic_block * jcp.ic_reg_block), + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data(jcp, + &(weights( + ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2, + ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2, + 0, 0, 0, 0)), + U_base_ptr); + }); + } + + parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, + [&](int N_blk1, int oj, int oi, int M_blk1) { + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; + K_blk1++) + for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++) + kernel_->gemm_loop_ker( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + K_blk1, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, K_blk1, 0, 0, 0)), K_blk1); + }); + + parallel_nd(jcp.mb, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block), + [&](int img, int M_blk1, int M_blk2) { + const int M_blk = + M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? last_slice_bias : &bias(M_blk, 0); + output_transform_data(img, jcp, p_ops, + &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(img, M_blk, 0, 0, 0)), bias_ptr); + }); + + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t::_execute_data_W_SGD( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + array_offset_calculator input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block); + array_offset_calculator output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block); + array_offset_calculator weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator bias(bias_ptr, + jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block); + + auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference) + ? wei_ptr + : scratchpad.template get(key_wino_U); + + array_offset_calculator U(wino_wei, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + + array_offset_calculator M(is_fwd + ? scratchpad.template get(key_wino_M) + : scratchpad.template get(key_wino_V), + 0, jcp.dimM_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + array_offset_calculator V(is_fwd + ? scratchpad.template get(key_wino_V) + : scratchpad.template get(key_wino_M), + 0, alpha, alpha, jcp.dimN_block, + jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + if (jcp.prop_kind != prop_kind::forward_inference) { + + parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), (jcp.ic_block * jcp.ic_reg_block), + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data(jcp, + &(weights( + ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2, + ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2, + 0, 0, 0, 0)), + U_base_ptr); + }); + } + + parallel_nd(jcp.tile_block, [&](int tile_block) { + int ithr = mkldnn_get_thread_num(); + + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) { + for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) { + + input_transform_tileblock_data( + tile_block, jcp, + &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), + &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0))); + } + } + + for (int oj = 0; oj < alpha; oj++) { + for (int oi = 0; oi < alpha; oi++) { + for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) + for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++) + kernel_->gemm_loop_ker( + (float *)&(M(ithr, M_blk1, oj, oi, + N_blk, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, K_blk1, + 0, 0, 0, 0)), + (const float *)&(V(ithr, oj, oi, + N_blk, K_blk1, 0, 0, 0)), K_blk1); + } + } + + for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) { + for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block; + M_blk2++) { + const int M_blk = + M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? last_slice_bias : &bias(M_blk, 0); + + output_transform_tileblock_data(tile_block, jcp, p_ops, + &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(0, M_blk, 0, 0, 0)), bias_ptr); + } + } + }); +} + +template struct _jit_avx512_core_fp32_wino_conv_4x3_t; +template struct _jit_avx512_core_fp32_wino_conv_4x3_t; + +namespace { + +void subarray_sum(size_t num_arrs, float *output, size_t nelems, + float *input_ptrs[], size_t input_starts[], size_t input_ends[]) { + using namespace nstl; + const size_t block_size = 16 * 1024 / sizeof(float); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + +PRAGMA_OMP(parallel) + { + const int ithr = mkldnn_get_thread_num(); + const int nthr = mkldnn_get_num_threads(); + size_t start{ 0 }, end{ 0 }; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + size_t input_start = max(start_e, min(input_starts[0], end_e)); + size_t input_end = max(start_e, min(input_ends[0], end_e)); + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < input_start; e++) { + output[e] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] = input_ptrs[0][e]; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_end; e < end_e; e++) { + output[e] = 0.f; + } + + for (size_t a = 1; a < num_arrs; a++) { + input_start = max(start_e, input_starts[a]); + input_end = min(input_ends[a], end_e); + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + size_t input_start = max(start_e, min(input_starts[0], end_e)); + size_t input_end = max(start_e, min(input_ends[0], end_e)); + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < input_start; e++) { + output[e] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] = input_ptrs[0][e]; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_end; e < end_e; e++) { + output[e] = 0.f; + } + + for (size_t a = 1; a < num_arrs; a++) { + input_start = max(start_e, input_starts[a]); + input_end = min(input_ends[a], end_e); + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + } +} + +const int max_threads_number = 1024; + +// Sum to the first buffer array +void array_sum(size_t num_arrs, float *output, + size_t nelems, float *input_ptrs[], bool reduce_to_first = true) { + const size_t block_size = 16 * 1024 / sizeof(float); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + +PRAGMA_OMP(parallel) + { + const size_t ithr = mkldnn_get_thread_num(); + const size_t nthr = mkldnn_get_num_threads(); + size_t start{ 0 }, end{ 0 }; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + if (!reduce_to_first) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = input_ptrs[0][e]; + } + } + for (size_t a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + if (!reduce_to_first) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = input_ptrs[0][e]; + } + } + for (size_t a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + } +} +} //bwdw namespace + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t:: +_execute_backward_weights_SDGtWo(const float *ptr_src, + const float *ptr_diff_dst, float *ptr_diff_weights, + float *ptr_diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + array_offset_calculator src((float *)ptr_src, + jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator diff_weights(ptr_diff_weights, + jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + + array_offset_calculator Us(scratchpad.get(key_wino_U), + 0, alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc + * jcp.ic / jcp.nb_ic; + array_offset_calculatordiff_weights_prv( + scratchpad.get(key_wino_U) + U_sz, + 0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + + array_offset_calculator M(scratchpad.get(key_wino_M), + 0, alpha, alpha, + jcp.oc_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur, + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator V(scratchpad.get(key_wino_V), + 0, alpha, alpha, + jcp.ic_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur, + jcp.ic_simd_block); + + array_offset_calculator diff_bias_prv( + scratchpad.get(key_conv_bia_reduction), nthreads, jcp.oc); + + auto trans_ker_p = jit_wino_transform_call_s(); + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f, + 0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f, + 1.13777777777778f}; + float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f}; + +PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T)) +{ + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc / simd_w, + [&](int ithr, int ofm){ + float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + pdbias[v] = 0.0f; + } + }); + } + + int ithr = mkldnn_get_thread_num(); + for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) { + int first_tblk = 0; +PRAGMA_OMP(for) + for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) { + int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur; + int img = tile_index / (jcp.itiles * jcp.jtiles); + trans_ker_p.ti = tile_index % jcp.itiles; + trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles; + trans_ker_p.M = I; + trans_ker_p.T = T; + trans_ker_p.G = G_I_3x3_4x4; + for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) { + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0)); + kernel_->src_transform(&trans_ker_p); + } + + for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) { + trans_ker_p.G = G_W_3x3_4x4; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) { + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block; + trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0)); + if (jcp.with_bias && ifm1 == 0) { + trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w)); + kernel_->diff_dst_transform_wbias(&trans_ker_p); + } else { + kernel_->diff_dst_transform(&trans_ker_p); + } + } + + for (int oj = 0; oj < alpha; ++oj) { + for (int oi = 0; oi < alpha; ++oi) { + kernel_->gemm_loop_ker_first_iter( + &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)), + &(M(ithr, oj, oi, 0, 0, 0, 0, 0)), + &(V(ithr, oj, oi, 0, 0, 0, 0))); + } + } + trans_ker_p.G = G_O_3x3_4x4; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) { + for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) { + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block + + ofm3; + for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) { + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(Us(ithr, 0, 0, + ofm2, ifm2, 0, ofm3, 0)); + trans_ker_p.dst = (float *)&(diff_weights_prv(ithr, + ofm, ifm, 0, 0, 0, 0)); + if (first_tblk == 0) { + kernel_->diff_weights_transform(&trans_ker_p); + } else { + kernel_->diff_weights_transform_accum(&trans_ker_p); + } + } + } + } + } + ++first_tblk; + } + } +} + + // Reduce diff-weights + { + float *output = ptr_diff_weights; + float *input_base = scratchpad.get(key_wino_U) + U_sz; + int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw; + float *input_ptrs[max_threads_number]; + for (int i = 0; i < nthreads; ++i) { + input_ptrs[i] = input_base + nelems * i; + } + array_sum(nthreads, output, nelems, input_ptrs, false); + + if (jcp.with_bias) { + output = ptr_diff_bias; + input_base = scratchpad.get(key_conv_bia_reduction); + for (int i = 0; i < nthreads; ++i) { + input_ptrs[i] = input_base + jcp.oc * i; + } + array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs, + false); + } + } +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t:: +_execute_backward_weights_S_D_Giot_W(const float *ptr_src, + const float *ptr_diff_dst, float *ptr_diff_weights, + float *ptr_diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + array_offset_calculator src((float *)ptr_src, + jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator diff_weights((float *)ptr_diff_weights, + jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + array_offset_calculator diff_bias((float *)ptr_diff_bias, jcp.oc); + + array_offset_calculator U(scratchpad.get(key_wino_U), + jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + const int U_size = jcp.oc * jcp.ic * alpha * alpha; + array_offset_calculator Us( + scratchpad.get(key_wino_U) + U_size, + 0, jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator M(scratchpad.get(key_wino_M), + jcp.nb_oc, + jcp.tile_block, + alpha, alpha, + jcp.oc_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur , + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator V(scratchpad.get(key_wino_V), + jcp.nb_ic, + jcp.tile_block, + alpha, alpha, + jcp.ic_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur, + jcp.ic_simd_block); + + array_offset_calculator diff_bias_prv( + scratchpad.get(key_conv_bia_reduction), nthreads, jcp.oc); + + size_t input_starts[max_threads_number] = {0}; + size_t input_ends[max_threads_number] = {0}; + size_t first_tblk = 0; + + auto trans_ker_p = jit_wino_transform_call_s(); + float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, + 0.119514472455649f, 0.430252100840336f, 0.168067226890756f, + 0.179271708683473f, 0.403361344537815f, 1.13777777777778f}; + float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f}; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + +PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T)) +{ + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) { + diff_bias_prv(ithr, ofm) = 0.0f; + }); + } + + trans_ker_p.G = G_I_3x3_4x4; + trans_ker_p.M = I; + trans_ker_p.T = T; + + parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb, + [&](int ifm1, int ifm2, int img){ + size_t ifm = ifm1 * jcp.ic_block + ifm2; + size_t tile_base_index = img * (jcp.itiles * jcp.jtiles); + size_t tblk3 = tile_base_index % jcp.tile_block_ur; + size_t tblk2 = (tile_base_index / jcp.tile_block_ur) + % jcp.nb_tile_block_ur; + size_t tblk1 = (tile_base_index / jcp.tile_block_ur) + / jcp.nb_tile_block_ur; + trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3; + trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0)); + kernel_->src_transform(&trans_ker_p); + }); + + int ithr = mkldnn_get_thread_num(); + trans_ker_p.G = G_W_3x3_4x4; + parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb, + [&](int ofm1, int ofm2, int img){ + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block; + size_t tile_base_index = img * (jcp.itiles * jcp.jtiles); + size_t tblk3 = tile_base_index % jcp.tile_block_ur; + size_t tblk2 = (tile_base_index / jcp.tile_block_ur) + % jcp.nb_tile_block_ur; + size_t tblk1 = (tile_base_index / jcp.tile_block_ur) + / jcp.nb_tile_block_ur; + trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3; + trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0)); + if (jcp.with_bias) { + trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w)); + kernel_->diff_dst_transform_wbias(&trans_ker_p); + } else { + kernel_->diff_dst_transform(&trans_ker_p); + } + }); + + PRAGMA_OMP(barrier) + + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block, + [&](int ifm1, int ofm1, int oj, int oi, int tblk1){ + if (first_tblk == 0) { + input_starts[ithr] = + (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, + 0, 0)) + - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0, + 0, 0, 0)); + input_ends[ithr] = input_starts[ithr] + + jcp.oc_block * jcp.ic_block + * jcp.ic_simd_block * jcp.oc_reg_block + * jcp.oc_simd_block; + } + else if (tblk1 == 0) { + input_ends[ithr] += jcp.oc_block * jcp.ic_block + * jcp.ic_simd_block * jcp.oc_reg_block + * jcp.oc_simd_block; + } + + if (first_tblk == 0 || tblk1 == 0) { + kernel_->gemm_loop_ker_first_iter( + &(Us(ithr, ifm1, ofm1, oj, oi, + 0, 0, 0, 0, 0)), + &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)), + &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0))); + } else { + kernel_->gemm_loop_ker( + &(Us(ithr, ifm1, ofm1, oj, oi, + 0, 0, 0, 0, 0)), + &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)), + &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0))); + } + ++first_tblk; + }); +} + + // Reduce diff-weights + { + float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0)); + size_t nelems = jcp.ic * jcp.oc * alpha * alpha; + float *input_ptrs[max_threads_number]; + for (int i = 0; i < nthreads; ++i) + input_ptrs[i] = output + nelems * (i + 1); + subarray_sum(nthreads, output, nelems, input_ptrs, + input_starts, input_ends); + } + + trans_ker_p.G = G_O_3x3_4x4; +PRAGMA_OMP(parallel firstprivate(trans_ker_p)) + { + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block, + [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){ + int ofm = (ofm1 * jcp.oc_block + ofm2) + * jcp.oc_reg_block + ofm3; + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0, + ofm2, ifm2, 0, ofm3, 0)); + trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm, + 0, 0, 0, 0)); + kernel_->diff_weights_transform(&trans_ker_p); + }); + } + + if (jcp.with_bias) { + parallel_nd(jcp.oc / simd_w, [&](int ofm1) { + float* pbias = &(diff_bias(ofm1 * simd_w)); + float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w)); + + const int blk_sz = ofm1 == jcp.oc / simd_w - 1 + ? jcp.oc_without_padding - ofm1 * simd_w : simd_w; + + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) { + pbias[ofm2] = pbias_prv[ofm2]; + } + + for (int ithr = 1; ithr < nthreads; ++ithr) { + pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w)); + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) { + pbias[ofm2] += pbias_prv[ofm2]; + } + } + }); + } +} + +} +} +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp new file mode 100644 index 0000000000..f1a56aac70 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp @@ -0,0 +1,386 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP +#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace winograd_avx512_core { +inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_winograd_conf_t &jcp) { + using namespace utils; + using namespace memory_tracking::names; + + size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; + size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic * jcp.itiles + * jcp.jtiles; + size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc * jcp.itiles + * jcp.jtiles; + + switch (jcp.sched_policy) { + case WSCHED_DATA_W_SGD: + V_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur + * jcp.tile_block_ur * jcp.ic; + M_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur + * jcp.tile_block_ur * jcp.oc; + break; + case WSCHED_WEI_SDGtWo: + U_sz = (size_t)jcp.nthr * (alpha * alpha * jcp.oc + * (jcp.ic / jcp.nb_ic) + jcp.ic * jcp.oc * jcp.kh * jcp.kw); + M_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block) + * (jcp.oc / jcp.nb_oc); + V_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block) + * (jcp.ic / jcp.nb_ic); + break; + case WSCHED_WEI_S_D_Giot_W: + U_sz = (size_t)(jcp.nthr + 1) * alpha * alpha * jcp.ic * jcp.oc; + M_sz = (size_t)alpha * alpha * jcp.oc * jcp.ntiles; + V_sz = (size_t)alpha * alpha * jcp.ic * jcp.ntiles; + break; + default: break; + } + + scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); + + if (one_of(jcp.sched_policy, WSCHED_WEI_SDGtWo, WSCHED_WEI_S_D_Giot_W)) { + size_t br_sz = (size_t)jcp.nthr * jcp.oc; + scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); + } +} +} + +template +struct _jit_avx512_core_fp32_wino_conv_4x3_t { + + _jit_avx512_core_fp32_wino_conv_4x3_t( + const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) + : kernel_(nullptr), attr_(attr) { + kernel_ = new _jit_avx512_core_fp32_wino_conv_4x3_data_kernel(jcp); + } + + ~_jit_avx512_core_fp32_wino_conv_4x3_t() { delete kernel_; } + + protected: + void weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) const; + void input_transform_data(int image, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const; + void input_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const; + void output_transform_data(int image, + const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, + float *bias) const; + void output_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, + float *toutp, float *outp, float *bias) const; + void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + void _execute_data_W_SGD(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel *kernel_; + const primitive_attr_t *attr_; +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_fwd_t + : _jit_avx512_core_fp32_wino_conv_4x3_t + , public cpu_primitive_t + { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel:: + init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, + *attr()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = desc()->prop_kind == prop_kind::forward_training + ? (with_groups() ? gOIhw16i16o : OIhw16i16o) : any; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_fwd_t(const pd_t *apd) + : _jit_avx512_core_fp32_wino_conv_4x3_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) + {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + switch ((pd()->jcp_).sched_policy) { + case WSCHED_DATA_W_S_G_D: + this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, + (float *)bias, scratchpad); + break; + case WSCHED_DATA_W_SGD: + this->_execute_data_W_SGD((float *)src, dst, (float *)weights, + (float *)bias, scratchpad); + break; + default: + break; + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t + : _jit_avx512_core_fp32_wino_conv_4x3_t, + public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t); + + status_t init() { + bool ok = true + && mkldnn_thr_syncable() + && desc()->prop_kind == prop_kind::backward_data + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel + ::init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t(const pd_t *apd) + : _jit_avx512_core_fp32_wino_conv_4x3_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) + {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); + + auto scratchpad = this->scratchpad(ctx); + + switch ((pd()->jcp_).sched_policy) { + case WSCHED_DATA_W_S_G_D: + this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, + (float *)weights, NULL, scratchpad); + break; + + case WSCHED_DATA_W_SGD: + this->_execute_data_W_SGD((float *)diff_dst, diff_src, + (float *)weights, NULL, scratchpad); + break; + + default: + break; + } + + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t + : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t); + + status_t init() { + bool ok = true + && mkldnn_thr_syncable() + && desc()->prop_kind == prop_kind::backward_weights + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) + return status::unimplemented; + + status_t status = + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: + init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), + *diff_weights_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + , kernel_(nullptr) + { + kernel_ = new jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel( + pd()->jcp_); + } + + ~jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t() + { + delete kernel_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + switch (kernel_->jcp.sched_policy) { + case WSCHED_WEI_SDGtWo: + _execute_backward_weights_SDGtWo(src, diff_dst, diff_weights, + diff_bias, scratchpad(ctx)); + break; + case WSCHED_WEI_S_D_Giot_W: + _execute_backward_weights_S_D_Giot_W(src, diff_dst, diff_weights, + diff_bias, scratchpad(ctx)); + break; + default: + assert(kernel_->jcp.sched_policy != WSCHED_INVALID); + break; + } + return status::success; + } + +private: + void _execute_backward_weights_SDGtWo(const float *src, + const float *diff_dst, float *diff_weights, float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + void _execute_backward_weights_S_D_Giot_W(const float *src, + const float *diff_dst, float *diff_weights, float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp new file mode 100644 index 0000000000..0d64a2d13a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp @@ -0,0 +1,2596 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include + +#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_wino_transform_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace mkldnn::impl::utils; + +unsigned int L1_cache_size = get_cache_size(1, true); +unsigned int L2_cache_size = get_cache_size(2, true); +unsigned int LLC_data_size = get_cache_size(3, false); + +// the test funtion takes jcp, the candidate and the current best. +// it returns true if the new candidate is better +int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, + int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) +{ + int best_divisor = default_best; + auto test_num + = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { + if (test(jcp, num, best_divisor)) { + best_divisor = num; + } + }; + + for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { + if (number % divisor == 0) { + test_num(jcp, divisor); + test_num(jcp, number / divisor); + } + } + + return best_divisor; +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { + /* Determines if current winograd implementation is faster than direct. + Following conditions are empirical and based on performance data */ + unsigned int ncores_per_socket = + cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); + unsigned int nthreads = mkldnn_get_max_threads(); + + if (jcp.prop_kind == prop_kind::forward_inference) { + return jcp.mb >= 4; + } else if (nthreads > ncores_per_socket) { + double src_dst_transforms_per_core = alpha * alpha + * (jcp.ic + jcp.oc) + * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size) + * ((jcp.ow + tile_size - 1) / tile_size) + * sizeof(float) / 1024. / 1024. / nthreads; + double wei_transform = alpha * alpha + * jcp.ic * jcp.oc * sizeof(float) /1024. / 1024.; + + if (jcp.prop_kind == prop_kind::backward_weights) { + if (src_dst_transforms_per_core < 0.3 + || (src_dst_transforms_per_core <= 28 && wei_transform < 4)) + return false; + else + return true; + } else { + if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02) + return false; + } + } + + return jcp.mb > 8; +} +} + +/* assumes 512 bits registers */ +/* TODO: add support for strides */ +/* TODO: handle the prefetch distance automatically */ +typedef enum cache_t_ { L1, L2, L3 } cache_t; + +template +struct prefetcher_t { + prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr, + cache_t cache_type, size_t block_size, /* in number of elements*/ + int nb_instructions_in_block, int fma_ipc) + : cg_(generator) + , reg_base_addr_(reg_base_addr) + , cache_type_(cache_type) + , cache_block_size_(block_size) + { + nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t)); + prefetch_spread_ + = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_); + prefetch_blk_ + = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block); + + /* assumption: when fetch in Li, data is already in L(i+1) */ + int cache_latency; + switch (cache_type_) { + case L1: cache_latency = 14; break; + case L2: cache_latency = 250; break; + case L3: cache_latency = 250; break; + } + + prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_); + } + + void prefetch(int instruction_number) + { + if (instruction_number % prefetch_spread_ == 0) { + for (int i = 0; (i < prefetch_blk_) + && (prefetches_issued_ < nb_cache_lines_to_prefetch_); + i++, prefetches_issued_++) { + prefetch_inst_(cg_->EVEX_compress_addr( + reg_base_addr_, (cache_block_size_ * prefetch_distance_) + * sizeof(data_t) + + (prefetches_issued_ * 64))); + } + } + } + +private: + void prefetch_inst_(const Xbyak::Address &addr) + { + switch (cache_type_) { + case L1: cg_->prefetcht0(addr); break; + case L2: cg_->prefetcht1(addr); break; + case L3: cg_->prefetcht2(addr); break; + default: + break; // TODO: raise an exception or put an assert + } + } + + jit_generator *cg_; + Xbyak::Reg64 reg_base_addr_; + cache_t cache_type_; + int cache_block_size_ = 0; + int nb_cache_lines_to_prefetch_ = 0; + int prefetches_issued_ = 0; + int prefetch_spread_ = 0; + int prefetch_blk_ = 0; + int prefetch_distance_ = 0; +}; + +// utilities to support kernel parameter selection +bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp, + int dimN_block, float C2_min, float C2_max) { + float block_size = alpha * alpha * (2*(jcp.oc + jcp.ic) + * dimN_block * jcp.dimN_reg_block + + div_up(jcp.ic * jcp.oc,mkldnn_get_max_threads())) * (float)sizeof(float); + float L2_lb = C2_min * L2_cache_size; + float L2_ub = C2_max * L2_cache_size; + return (block_size > L2_lb && block_size < L2_ub); +} + +bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block, + int dimM_block, float C1_min, float C1_max) { + float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block + * jcp.dimK_reg_block * jcp.dimM_reg_block + + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block + + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block) + * (float)sizeof(float); + float L1_lb = C1_min * L1_cache_size; + float L1_ub = C1_max * L1_cache_size; + return (gemm_block_size > L1_lb && gemm_block_size < L1_ub); +} +bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block + + dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block * dimM_reg_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} +bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block + * dimM_simd_block + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} +bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, + int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block, + int dimM_simd_block, float C) +{ + float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block + * dimM_simd_block * dimM_reg_block + + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block * dimM_reg_block + + nb_dimN_reg_block * dimK_nb_block * dimK_block + * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L2_cache_size; + return (lhs < rhs); +} + +bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block, + int dimN_block, int dimN_reg_block, int dimK, float C1, float C2) +{ + float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK + * (float)sizeof(float); + float B_size = dimN_block * dimN_reg_block * dimK + * (float)sizeof(float); + return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size); +} +} + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::gemm_loop_generate() +{ + // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++) + // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block; + // dimM_reg_block++) // unrolled + // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++) + // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block; + // dimK_reg_block++) // unrolled + // for (int tile =0; tile < jcp.dimN_reg_block; tile++) + // C[dimM_block][dimM_reg_block][tile] += + // A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block] + // * broadcast(B[dimK_block][tile][dimK_reg_block]); + // Notes: + // jcp.kernel_kind defines embedded or explicit broadcast + // dimM_reg_block=1 for embedded bcast kernel + + auto zmm_srcA = [=]() { + return Xbyak::Zmm(0); + }; + auto zmm_srcB = [=](int tile) { + int idx = 1 + tile; + assert(idx < 1 + jcp.dimN_reg_block); + return Xbyak::Zmm(idx); + }; + auto zmm_dstC = [=](int dimM_reg_block, int tile) { + int idx{0}; + if (jcp.kernel_kind == embd_bcast) + idx = 1 + tile; + else + idx = 1 + jcp.dimN_reg_block + + dimM_reg_block * jcp.dimN_reg_block + tile; + assert(idx < 32); + return Xbyak::Zmm(idx); + }; + + auto prepare_output = [=]() { + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block, tile); + vpxord(zmm, zmm, zmm); + } + } + }; + auto store_output = [=](bool output_is_aligned) { + Label save; + cmp(reg_is_beta_zero, 0); + je(save, T_NEAR); + + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block,tile); + int output_offset + = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64; + vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset)); + } + } + + L(save); + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block,tile); + int output_offset + = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64; + + // In W_SGD, output will be reused. + if (output_is_aligned + && jcp.dimK_nb_block == 1 + && jcp.sched_policy == WSCHED_DATA_W_S_G_D + && (jcp.dimN * jcp.dimM * alpha * alpha + * sizeof(float) > 2 * LLC_data_size)) + vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm); + else vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm); + } + } + }; + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop; + + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + + prepare_output(); + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block ++) { + + if (jcp.kernel_kind == expl_bcast) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + vbroadcastss(zmm_srcB(tile), + ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]); + } + } + + /* Performing the fmas */ + + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + + vmovups(zmm_srcA(), + zword[reg_srcA + + jcp.dimK_reg_block * jcp.dimK_block * 64 + * dimM_reg_block + + dimK_reg_block * 64] + ); + + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(), + zmm_srcB(tile)); + else + vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4, true)); + } + } + } + add(reg_srcA, jcp.dimK_reg_block * 64); + add(reg_srcB, jcp.dimN_reg_block * 64); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + Label unaligned_store, end_store; + test(reg_dstC, cpu_isa_traits::vlen - 1); + jnz(unaligned_store, T_NEAR); + store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64); + add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64); + if (jcp.kernel_kind == expl_bcast) { + add(reg_srcA, + (jcp.dimM_reg_block-1) * jcp.dimK_reg_block * 64 + * jcp.dimK_block); + } + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + }; + + /* Preamble */ + preamble(); + + /* kernel */ + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::weights_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int kh = jcp.kh; + int kw = jcp.kw; + + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_zero = Xbyak::Zmm(30); + + auto zmm_M = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_MT = [=](int i) { + return Xbyak::Zmm(i + simd_w); + }; + + auto zmm_G = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_F = [=](int i) { + return Xbyak::Zmm(alpha + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(alpha + 3 + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(2 * alpha + 3 + i); + }; + + auto zmm_load = [=](int i) { + return Xbyak::Zmm(i); + }; + + auto init_G = [=]() { + mov(wreg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < alpha; i++) { + vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]); + } + vpxord(zmm_zero, zmm_zero, zmm_zero); + }; + + auto trans16x16 = [=]() { + for (int i = 0; i < simd_w; i+=2 ) { + vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]); + vmovups(zmm_M(i+1), ptr[wreg_M + (i + 1) * simd_w * 4]); + vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i+1)); + vunpckhps(zmm_MT(i+1), zmm_M(i), zmm_M(i+1)); + } + for (int i = 0; i < simd_w; i+=4 ) { + vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i+2)); + vunpckhpd(zmm_M(i+1), zmm_MT(i), zmm_MT(i+2)); + vunpcklpd(zmm_M(i+2), zmm_MT(i+1), zmm_MT(i+3)); + vunpckhpd(zmm_M(i+3), zmm_MT(i+1), zmm_MT(i+3)); + } + for (int i = 0; i < simd_w; i += 8) { + vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88); + vshuff32x4(zmm_MT(i+1), zmm_M(i+1), zmm_M(i + 5), 0x88); + vshuff32x4(zmm_MT(i+2), zmm_M(i+2), zmm_M(i + 6), 0x88); + vshuff32x4(zmm_MT(i+3), zmm_M(i+3), zmm_M(i + 7), 0x88); + vshuff32x4(zmm_MT(i+4), zmm_M(i), zmm_M(i + 4), 0xdd); + vshuff32x4(zmm_MT(i+5), zmm_M(i+1), zmm_M(i + 5), 0xdd); + vshuff32x4(zmm_MT(i+6), zmm_M(i+2), zmm_M(i + 6), 0xdd); + vshuff32x4(zmm_MT(i+7), zmm_M(i+3), zmm_M(i + 7), 0xdd); + } + { + int i = 0; + int mask = 0x88; + vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask); + vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0)); + vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask); + vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1)); + vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask); + vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2)); + vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask); + vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3)); + vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask); + vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4)); + vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask); + vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5)); + vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask); + vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6)); + vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask); + vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7)); + mask = 0xdd; + vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask); + vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8)); + vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask); + vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9)); + vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask); + vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10)); + vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask); + vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11)); + vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask); + vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12)); + vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask); + vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13)); + vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask); + vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14)); + vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask); + vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15)); + } + }; + + auto load_src = [=]() { + mov(wreg_src, ptr[param1 + GET_OFF(src)]); + mov(wreg_F, ptr[param1 + GET_OFF(M)]); + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + if (is_fwd) { + for (int v1 = 0; v1 < simd_w; v1++) { + int offset_src = (j * kw * simd_w * simd_w + + i * simd_w * simd_w + v1 * simd_w) * typesize; + int offset_F = (j * kw * simd_w * simd_w + + i * simd_w * simd_w + v1 * simd_w) * typesize; + vmovups(zmm_temp, ptr[wreg_src + offset_src]); + vmovups(ptr[wreg_F + offset_F], zmm_temp); + } + } else { + int offset_src = ((2 - j) * kw * simd_w * simd_w + + (2 - i) * simd_w * simd_w) * typesize; + int offset_F = (j * kw * simd_w * simd_w + + i * simd_w * simd_w) * typesize; + lea(wreg_M, ptr[wreg_src + offset_src]); + lea(wreg_MT, ptr[wreg_F + offset_F]); + trans16x16(); + } + } + } + }; + + auto store_dst = [=]() { + mov(wreg_dst, ptr[param1 + GET_OFF(dst)]); + mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]); + + Label Loop_j; + mov(wreg_cnt_j, 0); + mov(wreg_dst_aux, wreg_dst); + mov(wreg_Fw_aux, wreg_Fw); + + int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block) + * jcp.dimK_block * simd_w * simd_w; + + L(Loop_j); + { + for (int i = 0; i < alpha; i++) { + // touch pages + vmovups(zmm_load(0), ptr[wreg_Fw_aux + + (i * simd_w * simd_w) * typesize]); + mov(wreg_dst_idx, i * dim5 * typesize); + vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0)); + } + for (int i = 0; i < alpha; i++) { + for (int v1 = 1; v1 < simd_w; v1++) { + int offset_Fw = (i * simd_w * simd_w + v1 * simd_w) + * typesize; + vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]); + } + mov(wreg_dst_idx, i * dim5 * typesize); + for (int v1 = 1; v1 < simd_w; v1++) { + int offset_dst = v1 * simd_w * typesize; + vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst], + zmm_load(v1)); + } + } + add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize); + add(wreg_dst_aux, alpha * dim5 * typesize); + add(wreg_cnt_j, 1); + cmp(wreg_cnt_j, alpha); + jl(Loop_j, T_NEAR); + } + }; + + auto trans_W_4x4_3x3 = [=]() { + auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmovups(dst, a); + vfmadd231ps(dst, b, c); + }; + auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmulps(zmm_temp, b, c); + vsubps(dst, a, zmm_temp); + }; + auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vsubps(dst, zmm_zero, a); + vfnmadd231ps(dst, b, c); + }; + + mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]); + mov(wreg_F, ptr[param1 + GET_OFF(M)]); + mov(wreg_T, ptr[param1 + GET_OFF(T)]); + + Label Loop_j; + mov(wreg_cnt_j, 0); + L(Loop_j); + mov(wreg_F_aux, wreg_F); + mov(wreg_Fw_aux, wreg_Fw); + mov(wreg_temp, wreg_cnt_j); + shl(wreg_temp, 4 + 2); + lea(wreg_F_aux, ptr[wreg_F + wreg_temp]); + lea(wreg_Fw_aux, ptr[wreg_Fw + wreg_temp]); + + for (int i = 0; i < 3; i++) { + for (int idx = 0; idx < 3; idx ++) { + vmovups(zmm_F(idx), ptr[wreg_F_aux + (idx * 3 * simd_w + * simd_w + i * simd_w * simd_w) * typesize]); + } + vmulps(zmm_t(0), zmm_G(0), zmm_F(2)); + fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0)); + fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0)); + + vmulps(zmm_T(0), zmm_G(3), zmm_F(0)); + fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1)); + fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1)); + fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1)); + fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1)); + vmovaps(zmm_T(5), zmm_F(2)); + + for (int idx = 0; idx < 6; idx ++) { + vmovups(ptr[wreg_T + (idx * 3 * simd_w + i * simd_w) + * typesize], zmm_T(idx)); + } + } + for (int i = 0; i < 6; i++) { + + for (int idx = 0; idx < 3; idx ++) { + vmovups(zmm_T(idx), ptr[wreg_T + + (i * 3 * simd_w + idx * simd_w) * typesize]); + } + vmulps(zmm_t(0), zmm_G(0), zmm_T(2)); + fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0)); + fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0)); + + vmulps(zmm_F(0), zmm_G(3), zmm_T(0)); + fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1)); + fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1)); + fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1)); + fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1)); + vmovaps(zmm_F(5), zmm_T(2)); + + for (int l = 0; l < 6; l++) { + vmovups(ptr[wreg_Fw_aux + (i * 6 * simd_w * simd_w + + l * simd_w * simd_w) * typesize], zmm_F(l)); + } + } + add(wreg_cnt_j, 1); + cmp(wreg_cnt_j, 16); + jl(Loop_j, T_NEAR); + }; + + auto inner_loops = [=]() { + load_src(); + init_G(); + trans_W_4x4_3x3(); + store_dst(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::output_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? jcp.oh : jcp.ih; + bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D; + bool with_bias = jcp.with_bias; + bool with_relu = jcp.with_eltwise; + bool with_relu_postsum = jcp.with_relu_postsum; + bool with_sum = jcp.with_sum; + + auto zmm_zero = Xbyak::Zmm(0); + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_G = [=](int i) { + return Xbyak::Zmm(1 + i); + }; + auto zmm_O = [=](int i) { + return Xbyak::Zmm(1 + alpha + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(1 + 2 * alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(1 + 3 * alpha + i); + }; + + auto init_G = [=]() { + mov(oreg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < 6; i++) { + vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]); + mov(oreg_src, ptr[param1 + GET_OFF(src)]); + + mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]); + imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur, + (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block + * jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_nb_tile_block_ur); + + mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]); + imul(oreg_tile_block_ur, oreg_tile_block_ur, + jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_tile_block_ur); + + if (not_tiled) { + mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]); + imul(oreg_tile_block, oreg_tile_block, + jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block + * (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block + * jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_tile_block); + } + + int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block) + * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize; + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + int j_base_offset = j * alpha * last4dim; + int i_base_offset = i * last4dim; + vmovups(zmm_temp, ptr[oreg_src + j_base_offset + i_base_offset]); + vmovups(ptr[oreg_Ow + (j * alpha * simd_w + i * simd_w) + * typesize], zmm_temp); + } + } + }; + + auto store_dst = [=]() { + vpxord(zmm_zero, zmm_zero, zmm_zero); + mov(oreg_dst, ptr[param1 + GET_OFF(dst)]); + mov(oreg_O, ptr[param1 + GET_OFF(M)]); + mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]); + shl(oreg_ydim, 2); // tj * tile_size (==4) + mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]); + shl(oreg_xdim, 2); // ti * tilesize (==4) + + if (with_bias) + mov(oreg_bias, ptr[param1 + GET_OFF(bias)]); + + auto store_one = [=](int j, int i, bool is_aligned) { + auto zmm_O = Xbyak::Zmm(31); + auto zmm_relu_ns = Xbyak::Zmm(30); + auto xmm_relu_ns = Xbyak::Xmm(30); + int offset = (j * tile_size * simd_w + i * simd_w) * typesize; + + vmovups(zmm_O, ptr[oreg_O + offset]); + if (is_fwd) { + if (with_bias) { + vaddps(zmm_O, zmm_O, ptr[oreg_bias]); + } + if (with_relu) { + if (jcp.eltwise.alpha == 0) { + vmaxps(zmm_O, zmm_O, zmm_zero); + } else { + Opmask kmask = Opmask(7); + mov(imm_addr64, float2int(jcp.eltwise.alpha)); + vmovq(xmm_relu_ns, imm_addr64); + vbroadcastss(zmm_relu_ns, xmm_relu_ns); + vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os); + vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns); + } + } + } + if (with_sum) { + vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]); + if (with_relu_postsum) // orig: with_relu_postsum + vmaxps(zmm_O, zmm_O, zmm_zero); + } + if (is_aligned) + vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O); + else + vmovups(ptr[oreg_out_j + oreg_temp], zmm_O); + }; + + auto i_loop = [=](int j, bool is_aligned) { + for (int i = 0; i < tile_size; i++) { + Label next; + mov(oreg_temp, oreg_xdim); + add(oreg_temp, i); + cmp(oreg_temp, outw); + jge(next, T_NEAR); + shl(oreg_temp, 4 + 2); // * 16 * 4 + + store_one(j, i, is_aligned); + + L(next); + } + }; + + + for (int j = 0; j < tile_size; j++) { + Label next, unaligned; + mov(oreg_temp, oreg_ydim); + add(oreg_temp, j); + cmp(oreg_temp, outh); + jge(next, T_NEAR); + + mov(oreg_out_j, oreg_dst); + imul(oreg_temp, oreg_temp, outw * simd_w * typesize); + add(oreg_out_j, oreg_temp); + + test(oreg_dst, 63); + jnz(unaligned, T_NEAR); + + i_loop(j, true); + jmp(next, T_NEAR); + + L(unaligned); + i_loop(j, false); + + L(next); + } + }; + + auto trans_O_4x4_3x3 = [=]() { + auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2){ + vmulps(dst, v1, u1); + vfmadd231ps(dst, v2, u2); + }; + mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]); + mov(oreg_T, ptr[param1 + GET_OFF(T)]); + mov(oreg_O, ptr[param1 + GET_OFF(M)]); + + for (int i = 0; i < alpha; i++) { + for (int j = 0; j < alpha; j++) { + vmovups(zmm_O(j), ptr[oreg_Ow + (j * alpha * simd_w + + i * simd_w) * typesize]); + } + + vaddps(zmm_t(0), zmm_O(1), zmm_O(2)); + vaddps(zmm_t(1), zmm_O(3), zmm_O(4)); + vsubps(zmm_t(2), zmm_O(1), zmm_O(2)); + vsubps(zmm_t(3), zmm_O(3), zmm_O(4)); + + vaddps(zmm_T(0), zmm_t(0), zmm_t(1)); + vaddps(zmm_T(0), zmm_T(0), zmm_O(0)); + fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1)); + fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3)); + fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5)); + vaddps(zmm_T(3), zmm_T(3), zmm_O(5)); + + for (int j = 0; j < tile_size; j++) { + vmovups(ptr[oreg_T + (j * alpha * simd_w + + i * simd_w) * typesize], zmm_T(j)); + } + } + for (int j = 0; j < tile_size; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_T(i), ptr[oreg_T + (j * alpha * simd_w + + i * simd_w) * typesize]); + } + vaddps(zmm_t(0), zmm_T(1), zmm_T(2)); + vaddps(zmm_t(1), zmm_T(3), zmm_T(4)); + vsubps(zmm_t(2), zmm_T(1), zmm_T(2)); + vsubps(zmm_t(3), zmm_T(3), zmm_T(4)); + + vaddps(zmm_O(0), zmm_t(0), zmm_t(1)); + vaddps(zmm_O(0), zmm_O(0), zmm_T(0)); + fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1)); + fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3)); + fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5)); + vaddps(zmm_O(3), zmm_O(3), zmm_T(5)); + + for (int i = 0; i < tile_size; i++) { + vmovups(ptr[oreg_O + (j * tile_size * simd_w + + i * simd_w) * typesize], zmm_O(i)); + } + } + }; + + auto inner_loops = [=]() { + init_G(); + load_src(); + trans_O_4x4_3x3(); + store_dst(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::input_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int inpw = is_fwd ? jcp.iw : jcp.ow; + int inph = is_fwd ? jcp.ih : jcp.oh; + int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; + int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh; + int wp_max = inpw + l_pad; + int hp_max = inph + t_pad; + bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D; + int G_size = 9; + + auto zmm_zero = Xbyak::Zmm(0); + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_G = [=](int i) { + return Xbyak::Zmm(1 + i); + }; + auto zmm_I = [=](int i) { + return Xbyak::Zmm(1 + G_size + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(1 + G_size + alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(1 + G_size + 2 * alpha + i); + }; + + auto init_G = [=]() { + mov(ireg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) { + vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp + mov(ireg_I, ptr[param1 + GET_OFF(M)]); + + xor_(ireg_zero, ireg_zero); + vpxord(zmm_zero, zmm_zero, zmm_zero); + + mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]); + shl(ireg_ydim, 2); // tj * tile_size (==4) + mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]); + shl(ireg_xdim, 2); // ti * tilesize (==4) + + for (int j = 0; j < alpha; j++) { + mov(ireg_temp, ireg_ydim); + add(ireg_temp, j); + + mov(ireg_mask_j, 0xffff); + cmp(ireg_temp, t_pad); + cmovl(ireg_mask_j, ireg_zero); + cmp(ireg_temp, hp_max); + cmovge(ireg_mask_j, ireg_zero); + + sub(ireg_temp, t_pad); + imul(ireg_temp, ireg_temp, inpw * simd_w * typesize); + mov(ireg_inp_j, ireg_src); + add(ireg_inp_j, ireg_temp); + + for (int i = 0; i < alpha; i++) { + + mov(ireg_temp, ireg_xdim); + add(ireg_temp, i); + + mov(ireg_mask, 0xffff); + cmp(ireg_temp, l_pad); + cmovl(ireg_mask, ireg_zero); + cmp(ireg_temp, wp_max); + cmovge(ireg_mask, ireg_zero); + and_(ireg_mask, ireg_mask_j); + + sub(ireg_temp, l_pad); + shl(ireg_temp, 4 + 2); + + vpxord(zmm_temp, zmm_temp, zmm_temp); + Opmask kmask = Opmask(7); + kmovw(kmask, ireg_mask_32); + vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]); + vmovups(ptr[ireg_I + (j * alpha * simd_w + i * simd_w) + * typesize], zmm_temp); + } + } + }; + + auto store_Iw = [=]() { + + mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); + mov(ireg_output, ptr[param1 + GET_OFF(dst)]); + + bool streamout + = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float) + > 2 * LLC_data_size + ? true : false; + + if (not_tiled) { + mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]); + imul(ireg_tile_block, ireg_tile_block, + alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize); + } + + mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]); + imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur, + jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block + * jcp.dimK_reg_block * typesize); + + mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]); + imul(ireg_tile_block_ur, ireg_tile_block_ur, + jcp.dimK_reg_block * typesize); + + add(ireg_output, ireg_nb_tile_block_ur); + add(ireg_output, ireg_tile_block_ur); + if (not_tiled) + add(ireg_output, ireg_tile_block); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_temp,ptr[ireg_Iw + (j * alpha * simd_w + + i * simd_w) * typesize]); + + int j_base_offset = + j * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize; + int i_base_offset = + i * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block + * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; + + if (not_tiled && streamout) + vmovntps(ptr[ireg_output + j_base_offset + i_base_offset], + zmm_temp); + else + vmovups(ptr[ireg_output + j_base_offset + i_base_offset], + zmm_temp); + } + } + }; + + auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmulps(zmm_temp, a, b); + vaddps(dst, zmm_temp, c); + }; + + auto trans_I_4x4_3x3 = [=]() { + mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); + mov(ireg_T, ptr[param1 + GET_OFF(T)]); + mov(ireg_I, ptr[param1 + GET_OFF(M)]); + + mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch + for (int i = 0; i < alpha; i++) { + for (int idx = 0; idx < alpha; idx++) { + vmovups(zmm_I(idx), ptr[ireg_I + (idx * alpha * simd_w + + i * simd_w) * typesize]); + int j_base_offset = + i * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize; + int idx_base_offset = + idx * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block + * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; + prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]); + } + + fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4)); + fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3)); + fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4)); + fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3)); + fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4)); + fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5)); + + fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4)); + fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5)); + + for (int idx = 0; idx < alpha; idx++) { + vmovups(ptr[ireg_T + (idx * alpha * simd_w + i * simd_w) + * typesize],zmm_T(idx)); + } + } + for (int i = 0; i < alpha; i++) { + for (int idx = 0; idx < alpha; idx++) { + vmovups(zmm_T(idx), ptr[ireg_T + (i * alpha * simd_w + idx + * simd_w) * typesize]); + } + + fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4)); + fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3)); + fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4)); + fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3)); + fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4)); + fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5)); + + fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4)); + fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5)); + + for (int idx = 0; idx < alpha; idx++) { + vmovups(ptr[ireg_Iw + (i * alpha * simd_w + idx * simd_w) + * typesize],zmm_I(idx)); + } + } + }; + + auto inner_loops = [=]() { + init_G(); + load_src(); + trans_I_4x4_3x3(); + store_Iw(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_common( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) +{ + if (!mayiuse(avx512_core)) { + return status::unimplemented; + } + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ver = ver_avx512_core; + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + // Checking conditions not supported by these kernels + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + if (!one_of(weights_d.format_kind(), format_kind::any, format_kind::wino)) { + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + } + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && (one_of(weights_d.format_kind(), + format_kind::any, format_kind::wino) + || (jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0])); + if (!layout_consistency) + return status::unimplemented; + + return status::success; +} + +void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) { + + /* ----------- dimM reg block ---------------------*/ + auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimM_reg_block, int current_best) { + int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4; + return (dimM_reg_block >= 1) + && (dimM_reg_block <= max_dimM_reg_block ) + && (dimM_reg_block > current_best); + }; + jcp.dimM_reg_block = get_divisor_satisfying_cond(jcp, + jcp.dimM/jcp.dimM_simd_block, 1, test_cond_dimM_reg_block); + + /* ----------- dimN reg block ---------------------*/ + + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return jcp.kernel_kind == embd_bcast + ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best + : dimN_reg_block >= 1 + && (dimN_reg_block * jcp.dimM_reg_block + dimN_reg_block) + < jcp.nb_reg + && dimN_reg_block > current_best; + }; + jcp.dimN_reg_block = get_divisor_satisfying_cond(jcp, + jcp.dimN, 1, test_cond_dimN_reg_block); +} + +status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) { + if (jcp.ver != ver_avx512_core) + return status::unimplemented; + + jcp.kernel_kind = embd_bcast; + + set_kernel_dims_reg_block(jcp); + + /*-------------- L2 blocking for dimN block ---------*/ + + auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_block, int current_best) { + return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0) + && (dimN_block > current_best) + && ((jcp.dimN / dimN_block / jcp.dimN_reg_block) + >= 1.5 * mkldnn_get_max_threads()); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block); + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; + + if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2) + && (jcp.dimN_nb_block >= 1.5 * mkldnn_get_max_threads())) { + + /* ------------------- L1 blocking for GEMM --------------*/ + /* -------------------- Choose dimK block ----------------*/ + + auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp, + int dimK_block, int current_best) { + return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block); + + if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) { + jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; + + /* -------------- Choose dimM block -------------------*/ + auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp, + int dimM_block, int current_best) { + return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block, + 0.2, 0.5) && (dimM_block > current_best); + }; + + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1, + test_cond_dimM_block); + jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block + / jcp.dimM_simd_block; + + jcp.sched_policy = WSCHED_DATA_W_SGD; + return status::success; + } + + } + return status::unimplemented; +} + +void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) { + + set_kernel_dims_reg_block(jcp); + + //********************* Choosing dimK_block **********************// + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, + 1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f) + && (dimK_block > current_best); + }; + + auto test_cond1_bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, dimK_block, + jcp.dimK_reg_block, 1, jcp.dimM_reg_block, + jcp.dimM_simd_block, .9f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); + // If we are not able to use streams, we fall back to condition [1] + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); + jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; + + //********************* Choosing dimM_block **********************// + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, .5f) + && (dimM_block > current_best); + }; + + auto test_cond1_bis_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, .3f) + && (dimM_block > current_best); + }; + + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1, + test_cond1_dimM_block); + else + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1, + test_cond1_bis_dimM_block); + jcp.dimM_nb_block = jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_block + * jcp.dimM_reg_block); + + //******************* Choosing dimN_block *******************// + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, + jcp.dimM_reg_block, jcp.dimM_simd_block, .9f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); +} + +status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) { + + jcp.kernel_kind = expl_bcast; + set_kernel_blocking_DATA_W_S_G_D(jcp); + if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, jcp.dimK, + .1f, .35f))) { + jcp.kernel_kind = embd_bcast; + set_kernel_blocking_DATA_W_S_G_D(jcp); + } + jcp.sched_policy = WSCHED_DATA_W_S_G_D; + return status::success; +} + +status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) +{ + jcp.nb_reg = 32; + jcp.dimN = dimN; + jcp.dimK = dimK; + jcp.dimM = dimM; + jcp.sched_policy = WSCHED_INVALID; + + jcp.dimK_reg_block = 16; + jcp.dimM_simd_block = 16; + + if (jcp.kernel_kind == embd_bcast) { + jcp.dimM_reg_block = 1; + } + + if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success)) + set_wsched_DATA_W_S_G_D_avx512_core(jcp); + + assert(jcp.sched_policy != WSCHED_INVALID); + return status::success; +} + +bool jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_relu(0) || is_sum(0); // relu or sum + case 2: return (is_sum(0) && is_relu(1)) + || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum + case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu + default: return false; + } + + return false; +} + +status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_t &src_md, memory_desc_t &weights_md, + const memory_desc_t &dst_md, const primitive_attr_t &attr) { + + status_t st = init_conf_common(jcp, cd, src_md, weights_md, dst_md); + + if (st != status::success) + return st; + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.with_sum = p.find(primitive_kind::sum, 0) != -1; + jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1; + + status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic); + + jcp.ic_simd_block = jcp.dimK_reg_block; + jcp.ic_block = jcp.dimK_block; + jcp.nb_ic = jcp.dimK_nb_block; + jcp.oc_simd_block = jcp.dimM_simd_block; + jcp.oc_block = jcp.dimM_block; + jcp.oc_reg_block = jcp.dimM_reg_block; + jcp.ic_reg_block = 1; + jcp.nb_oc = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + if (cd.prop_kind == mkldnn_forward_inference) { + memory_desc_t expect_wei_md = weights_md; + + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::f32; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format = mkldnn_wino_wei_OBaaIBOIio; + wd.r = 3; + wd.alpha = 6; + + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.dimK_reg_block; + wd.oc_block = jcp.dimM_simd_block; + wd.ic2_block = jcp.dimK_block; + wd.oc2_block = jcp.dimM_block * jcp.dimM_reg_block; + size_t max_size = sizeof(float) * wd.alpha * wd.alpha * jcp.ic * jcp.oc; + wd.size = max_size; + wd.adj_scale = 1.f; + + if (weights_md.format_kind == format_kind::any) + weights_md = expect_wei_md; + if (weights_md != expect_wei_md) + return status::unimplemented; + } + + return res; +} + +status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); + + if (st != status::success) + return st; + + jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); + + jcp.oc_simd_block = jcp.dimK_reg_block; + jcp.oc_block = jcp.dimK_block; + jcp.nb_oc = jcp.dimK_nb_block; + jcp.ic_simd_block = jcp.dimM_simd_block; + jcp.ic_block = jcp.dimM_block; + jcp.ic_reg_block = jcp.dimM_reg_block; + jcp.oc_reg_block = 1; + jcp.nb_ic = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + + return res; +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +src_transform_generate() { + constexpr int G_size = 9; + const size_t ifwp = jcp.iw + jcp.l_pad; + const size_t ifhp = jcp.ih + jcp.t_pad; + + auto zmm_G = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_I = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(G_size + alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 2 * alpha + i); + }; + + auto init_G = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) { + vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(reg_I, ptr[reg_transp + GET_OFF(M)]); + xor_(reg_zero, reg_zero); + + mov(reg_ydim, reg_tj); + shl(reg_ydim, 2); //tj * tile_size(=4) + + for (int j = 0; j < alpha; j++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maskj, 0xffff); + cmp(reg_ydim, jcp.t_pad); + cmovl(reg_maskj, reg_zero); + cmp(reg_ydim, ifhp); + cmovge(reg_maskj, reg_zero); + + /*address offset for tile in src*/ + mov(reg_src_offset, reg_ydim); + sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad + imul(reg_src_offset, reg_src_offset, jcp.iw); + + mov(reg_xdim, reg_ti); + shl(reg_xdim, 2); // xdim = ti * tile_size + + add(reg_src_offset, reg_xdim); + sub(reg_src_offset, jcp.l_pad); + imul(reg_src_offset, reg_src_offset, simd_w * typesize); + for (int i = 0; i < alpha; i++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maski, 0xffff); + cmp(reg_xdim, jcp.l_pad); + cmovl(reg_maski, reg_zero); + cmp(reg_xdim, ifwp); + cmovge(reg_maski, reg_zero); + and_(reg_maski, reg_maskj); + + Opmask kmask_src = Xbyak::Opmask(7); + auto zmm_src = Xbyak::Zmm(31); + kmovw(kmask_src, reg_maski_32); + vpxord(zmm_src, zmm_src, zmm_src); + vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]); + vmovups(ptr[reg_I], zmm_src); + + add(reg_xdim, 1); //xdim = ti * tile_size + i + add(reg_src_offset, simd_w * typesize); + add(reg_I, simd_w * typesize); + } + add(reg_ydim, 1); + } + }; + + auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) { + vmovups(dst, c); + vfmadd231ps(dst, a, b); + }; + + auto trans_I_3x3_4x4 = [=]() { + //Use 24 registers + mov(reg_I, ptr[reg_transp + GET_OFF(M)]); + mov(reg_T, ptr[reg_transp + GET_OFF(T)]); + for (int i = 0; i < alpha; i++) { + for (int j = 0; j < alpha; j++) { + size_t I_off = (j * alpha + i) * simd_w * typesize; + vmovups(zmm_I(j), ptr[reg_I + I_off]); + } + + fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4)); + fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3)); + fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4)); + fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3)); + fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4)); + fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5)); + + fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4)); + fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5)); + + for (int j = 0; j < alpha; j++) { + vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize], + zmm_T(j)); + } + + } + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_T(i), ptr[reg_T + (j * alpha + i) * simd_w * typesize]); + } + + fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4)); + fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3)); + fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4)); + fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3)); + fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4)); + fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5)); + + fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4)); + fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5)); + + for (int i = 0; i < alpha; i++) { + size_t dst_off = (j * alpha * jcp.ic_block + * jcp.nb_tile_block_ur * jcp.tile_block_ur + + i * jcp.ic_block * jcp.nb_tile_block_ur * jcp.tile_block_ur) + * simd_w * typesize; + vmovups(ptr[reg_dst + dst_off], zmm_I(i)); + } + } + }; + + auto compute_transform_SDGtWo = [=]() { + mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]); + mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + xor_(reg_tile_count, reg_tile_count); + Label loop_mb, loop_jtiles, loop_itiles, done; + L(loop_mb); + { + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_I_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(done); + + add(reg_dst, simd_w * typesize); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + xor_(reg_tj, reg_tj); + add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize); + jmp(loop_mb); + } + L(done); + }; + + auto compute_transform = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + xor_(reg_ti, reg_ti); + xor_(reg_tj, reg_tj); + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, simd_w * typesize); + add(reg_dst, reg_temp); + + Label loop_jtiles, loop_itiles, next_tile_block, next_tile; + L(loop_jtiles); + + { + L(loop_itiles); + { + load_src(); + + trans_I_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(next_tile_block); + add(reg_dst, simd_w * typesize); + jmp(next_tile); + + L(next_tile_block); + sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) + * simd_w * typesize); + size_t tblk_off = alpha * alpha * jcp.ic_block + * jcp.nb_tile_block_ur * jcp.tile_block_ur + * simd_w * typesize; + add(reg_dst, tblk_off); + xor_(reg_tile_count, reg_tile_count); + + L(next_tile); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + }; + + preamble(); + init_G(); + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) + compute_transform_SDGtWo(); + else + compute_transform(); + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +diff_dst_transform_generate(bool with_bias) { + + constexpr int G_size = 8; + auto zmm_G = [](int i) { + return Xbyak::Zmm(31); + }; + + auto zmm_src = [=](int j, int i) { + return Xbyak::Zmm(G_size + j * 4 + i); + }; + + auto zmm_bias = Xbyak::Zmm(31); + + auto load_src = [=]() { + if (with_bias) vmovups(zmm_bias, ptr[reg_bias]); + mov(reg_ydim, reg_tj); + shl(reg_ydim, 2); //tj * tile_size(=4) + for (int j = 0; j < tile_size; j++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maskj, 0xffff); + cmp(reg_ydim, jcp.oh); + cmovge(reg_maskj, reg_zero); + + /*address offset for tile in src*/ + mov(reg_src_offset, reg_ydim); + imul(reg_src_offset, reg_src_offset, jcp.ow); + + mov(reg_xdim, reg_ti); + shl(reg_xdim, 2); // xdim = ti * tile_size + + add(reg_src_offset, reg_xdim); + imul(reg_src_offset, reg_src_offset, simd_w * typesize); + for (int i = 0; i < tile_size; i++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maski, 0xffff); + cmp(reg_xdim, jcp.ow); + cmovge(reg_maski, reg_zero); + and_(reg_maski, reg_maskj); + + Opmask kmask_src = Xbyak::Opmask(7); + kmovw(kmask_src, reg_maski_32); + vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i)); + vmovups(zmm_src(j, i) | kmask_src, ptr[reg_src + reg_src_offset]); + if (with_bias) vaddps(zmm_bias | kmask_src, zmm_bias, + ptr[reg_src + reg_src_offset]); + + add(reg_xdim, 1); //xdim = ti * tile_size + i + add(reg_src_offset, simd_w * typesize); + } + add(reg_ydim, 1); + } + if(with_bias) vmovups(ptr[reg_bias], zmm_bias); + }; + + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 16 + i); + }; + + auto zmm_T = [=](int j, int i) { + return Xbyak::Zmm(j * 4 + i); + }; + + auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) { + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) + vmovups(ptr[reg_dst + dst_off], a); + else + vmovntps(ptr[reg_dst + dst_off], a); + }; + + auto trans_W_3x3_4x4 = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < tile_size; i++) { + vbroadcastss(zmm_G(0), ptr[reg_G]); + vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0)); + + vbroadcastss(zmm_G(1), ptr[reg_G + typesize]); + vmovups(zmm_t(1), zmm_t(0)); + vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1)); + + vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]); + vmovups(zmm_t(2), zmm_t(0)); + vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2)); + + vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]); + vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3)); + + vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]); + vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4)); + + vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]); + vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5)); + + vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]); + vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6)); + + vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]); + vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7)); + vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3)); + vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3)); + vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4)); + vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4)); + vmovups(zmm_T(5, i), zmm_src(3, i)); + } + + for (int j = 0; j < alpha; j++) { + vbroadcastss(zmm_G(0), ptr[reg_G]); + vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0)); + + vbroadcastss(zmm_G(1), ptr[reg_G + typesize]); + vmovups(zmm_t(1), zmm_t(0)); + vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1)); + + vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]); + vmovups(zmm_t(2), zmm_t(0)); + vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2)); + + vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]); + vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3)); + + vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]); + vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4)); + + vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]); + vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5)); + + vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]); + vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6)); + + vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]); + vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7)); + vsubps(zmm_t(5), zmm_t(1), zmm_t(3)); + vaddps(zmm_t(1), zmm_t(1), zmm_t(3)); + vaddps(zmm_t(6), zmm_t(2), zmm_t(4)); + vsubps(zmm_t(2), zmm_t(2), zmm_t(4)); + vmovups(zmm_t(3), zmm_T(j, 3)); + + int alpha_offset = (jcp.oc / jcp.nb_oc) + * (jcp.ntiles / jcp.tile_block) * typesize; + int dst_off = j * alpha * alpha_offset; + movps(reg_dst, dst_off, zmm_t(0)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(5)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(1)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(6)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(2)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(3)); + } + + }; + auto compute_transform_SDGtWo = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]); + + xor_(reg_zero, reg_zero); + xor_(reg_oc_ur, reg_oc_ur); + Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done; + + L(loop_oc_ur); + { + mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]); + mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]); + xor_(reg_tile_count, reg_tile_count); + L(loop_mb); + { + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_W_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(tiles_done); + + add(reg_dst, jcp.oc_reg_block * simd_w * typesize); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + xor_(reg_tj, reg_tj); + add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize); + jmp(loop_mb); + } + + L(tiles_done); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + add(reg_dst, simd_w * typesize); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + add(reg_src, jcp.oh * jcp.ow * simd_w * typesize); + + if (with_bias) add(reg_bias, simd_w * typesize); + add(reg_oc_ur, 1); + cmp(reg_oc_ur, jcp.oc_reg_block); + jl(loop_oc_ur); + } + }; + + auto compute_transform = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]); + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, reg_temp); + + xor_(reg_zero, reg_zero); + xor_(reg_oc_ur, reg_oc_ur); + Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block, next_tile; + + L(loop_oc_ur); + { + xor_(reg_ti, reg_ti); + xor_(reg_tj, reg_tj); + + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_W_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(next_tile_block); + add(reg_dst, jcp.oc_reg_block * simd_w * typesize); + jmp(next_tile); + + L(next_tile_block); + sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) + * jcp.oc_reg_block * simd_w * typesize); + int tblk_off = alpha * alpha * (jcp.oc/jcp.nb_oc) + * (jcp.ntiles/jcp.tile_block) * typesize; + add(reg_dst, tblk_off); + xor_(reg_tile_count, reg_tile_count); + + L(next_tile); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, reg_temp); + add(reg_dst, simd_w * typesize); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + add(reg_src, jcp.oh * jcp.ow * simd_w * typesize); + + if (with_bias) add(reg_bias, simd_w * typesize); + add(reg_oc_ur, 1); + cmp(reg_oc_ur, jcp.oc_reg_block); + jl(loop_oc_ur); + } + }; + + preamble(); + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) { + compute_transform_SDGtWo(); + } else { + compute_transform(); + } + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +diff_weights_transform_generate(bool first_tile) { + int G_size = 4; + + auto zmm_G = [](int i) { + return Xbyak::Zmm(i); + }; + + auto init_G = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) + vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]); + }; + + auto zmm_src = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + + auto load_src = [=](int i) { + for (int j = 0; j < alpha; j++) { + size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block + * jcp.ic_block * simd_w * simd_w * typesize; + size_t src_off = (j * alpha + i) * alpha_offset; + vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off)); + } + }; + + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 6 + i); + }; + + auto zmm_T = [=](int j, int i) { + return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i); + }; + + auto zmm_dst = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + + auto zmm_temp = Xbyak::Zmm(31); + + auto store_dst = [=](int j) { + for (int i = 0; i < jcp.kw; i++) { + size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize; + + if (!first_tile) { + vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off)); + vaddps(zmm_dst(i), zmm_dst(i), zmm_temp); + } + vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i)); + } + }; + + auto compute_transform = [=] () { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + + xor_(reg_ic_simd, reg_ic_simd); + Label loop_ic_simd; + L(loop_ic_simd); + { + for (int i = 0; i < alpha; i++) { + load_src(i); + + vaddps(zmm_t(0), zmm_src(1), zmm_src(2)); + vaddps(zmm_t(1), zmm_src(3), zmm_src(4)); + vmovups(zmm_t(2), zmm_src(5)); + vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0)); + + vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0)); + vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1)); + vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2)); + vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1)); + vsubps(zmm_temp, zmm_src(3), zmm_src(4)); + vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2)); + vmovups(zmm_T(2, i), zmm_t(2)); + vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3)); + } + + for (int j = 0; j < jcp.kh; j++) { + vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2)); + vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4)); + vmovups(zmm_t(2), zmm_T(j, 5)); + vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0)); + + vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0)); + vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1)); + vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2)); + vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1)); + vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4)); + vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2)); + vmovups(zmm_dst(2), zmm_t(2)); + vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3)); + + store_dst(j); + } + + add(reg_src, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, simd_w * typesize); + add(reg_ic_simd, 1); + cmp(reg_ic_simd, simd_w); + jl(loop_ic_simd); + } + }; + preamble(); + push(reg_EVEX_max_8b_offt); + mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); + init_G(); + compute_transform(); + pop(reg_EVEX_max_8b_offt); + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::gemm_loop_generate( + bool is_first_tile) +{ + auto zmm_srcA = [=]() { + return Xbyak::Zmm(0); + }; + + auto zmm_srcB = [=] (size_t N_ur){ + return Xbyak::Zmm(N_ur + 1); + }; + + auto broadcastB = [=](size_t K_ur) { + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) { + size_t srcB_off = (K_ur * jcp.dimN_reg_block + N_bcast) + * sizeof(float); + vbroadcastss(zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off)); + } + }; + + auto load_srcA = [=] (size_t K_ur, int M_ur) { + size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block + + M_ur * jcp.dimM_simd_block) * sizeof(float); + vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off)); + }; + + auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast){ + size_t idx = 1 // zmm_srcA + + jcp.dimN_bcast_ur // zmm_srcB + + M_reg_ur * jcp.dimN_bcast_ur + N_bcast; + assert(idx < 32); + return Xbyak::Zmm(idx); + }; + auto prepare_accumm = [=](){ + for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) { + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) { + Zmm zmm = zmm_dstC(M_reg_ur, N_bcast); + vpxord(zmm, zmm, zmm); + } + } + }; + + auto store_dstC = [=](){ + /******** Write C back to memory *******/ + for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) { + for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) { + Zmm zmm = zmm_dstC(M_reg, N_ur); + size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block + + M_reg * jcp.dimM_simd_block) * sizeof(float); + if (!is_first_tile) { + vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off)); + vaddps(zmm, zmm, Xbyak::Zmm(0)); + } + vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm); + } + } + }; + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur; + + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + { /************* OC_block (M) loop ***********/ + mov(reg_dimN_block_loop_cnt, jcp.dimN_block); + L(dimN_block_loop); + { /*************** IC_block (N) loop *********/ + + mov(reg_nb_dimN_bcast_ur, jcp.dimN_reg_block/jcp.dimN_bcast_ur); + L(dimN_bcast_ur); + { + prepare_accumm(); + + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + { + /************* nb_tile_ur(K) loop ********/ + for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) { + + broadcastB(K_ur); + + for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) { + load_srcA(K_ur, M_reg_ur); + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; ++N_bcast) { + vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast), zmm_srcA(), + zmm_srcB(N_bcast)); + } + } + } + add(reg_srcA, jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + add(reg_srcB, jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + store_dstC(); + + sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_srcB, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); + add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float)); + add(reg_dstC, jcp.dimN_bcast_ur + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_nb_dimN_bcast_ur, 1); + jnz(dimN_bcast_ur); + } + + sub(reg_srcB, jcp.dimN_reg_block * sizeof(float)); + add(reg_srcB, jcp.dimK_block + * jcp.dimK_reg_block + * jcp.dimN_reg_block * sizeof(float)); + sub(reg_dimN_block_loop_cnt, 1); + jnz(dimN_block_loop); + } + + sub(reg_srcB, jcp.dimN_block + * jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); + add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + }; + + /* Preamble */ + preamble(); + + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +namespace { + +void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) { +/*M params*/ + jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block + / jcp.dimM_simd_block; + jcp.oc_reg_block = jcp.dimM_reg_block; + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + /*N params*/ + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; + jcp.ic_block = jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + /*K params*/ + jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; +} + +status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) { + + size_t K_blk_ur, N_blk, M_blk; + /* IS this strategy feasible? */ + auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) { + size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float); + size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float); + size_t nthreads = mkldnn_get_max_threads(); + return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size) + && (jcp.dimK / nthreads >= 1.0); + }; + + auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur, + int max_block=1) { + size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * dimK_block_ur * sizeof(float); + size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float); + size_t M_L2_block = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float); + size_t nthreads = mkldnn_get_max_threads(); + bool load_balance=true; + if (!(jcp.dimK % nthreads)) { + load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0); + } + return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size) + && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size) + && load_balance + && (M_L2_block < L2_cache_size); + }; + + auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur, + int useless_arg=0) { + return (dimK_ur >= 2) && (dimK_ur <= 8); + }; + + auto blocking_ok = [&](){ + size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block + * K_blk_ur * sizeof(float); + size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block + * K_blk_ur * sizeof(float); + size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block + * N_blk * jcp.dimN_reg_block * sizeof(float); + size_t L2_block = M_L2_block + V_L2_block + U_L2_block; + /*Replace 2.375 with L2+L3 cache size*/ + return (L2_block > 0.1 * L2_cache_size) && (L2_block <= 1.2 * L2_cache_size); + }; + + if (test_MV_large_enough(jcp)) { + if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) { + jcp.dimM_reg_block = 2; + } else { + jcp.dimM_reg_block = 1; + } + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.dimN_bcast_ur = 8; + /*dimK_block and dimK_ur*/ + size_t min_dimK_block_ur = get_divisor_satisfying_cond(jcp, jcp.dimK, 1, test_min_dimK_L1); + + jcp.dimM_block = jcp.dimM/jcp.dimM_reg_block/jcp.dimM_simd_block; + jcp.dimN_block = jcp.dimN/jcp.dimN_reg_block; + for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) { + if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) { + for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) { + if (!(jcp.dimN_block % N_blk)) { + for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) { + if (!(jcp.dimM_block % M_blk) && blocking_ok()) { + jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur); + if (!test_dimK_ur(jcp, jcp.dimK_reg_block)) return status::unimplemented; + jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block; + jcp.dimN_block = N_blk; + jcp.dimM_block = M_blk; + jcp.sched_policy = WSCHED_WEI_SDGtWo; + set_jcp_WEI_params(jcp); + jcp.nthr = nstl::min(mkldnn_get_max_threads(), + jcp.tile_block); + return status::success; + } + } + } + } + } + } + } + return status::unimplemented; +} + +status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) { + if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) { + jcp.dimM_reg_block = 2; + } else { + jcp.dimM_reg_block = 1; + } + jcp.dimN_bcast_ur = 8; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block; + jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block; + float C1 = 0.0, C2 = 0.0; + float C1_max = 0.5, C2_max = 1.4; + int N_blk, M_blk, K_blk_ur; + + auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur, + int useless_arg=0) { + return (dimK_ur >= 2) && (dimK_ur <= 8); + }; + + auto blocking_ok = [&]() -> bool { + size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur * sizeof(float); + size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float); + bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size) + && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size); + + size_t nb_N_blk = jcp.dimN/N_blk/jcp.dimN_reg_block; + size_t nb_M_blk = jcp.dimM/M_blk/jcp.dimM_reg_block/jcp.dimM_simd_block; + size_t nb_K_blk = jcp.dimK / K_blk_ur; + size_t nthreads = mkldnn_get_max_threads(); + bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads; + if (!(nb_K_blk % nthreads)) { + load_balance = load_balance && (nb_K_blk % nthreads == 0); + } + + size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block * K_blk_ur * sizeof(float); + + size_t L2_block = V_L2_block; + /*Replace 2.375 with L2+L3 cache size*/ + bool L2_cond = (L2_block >= C2 * L2_cache_size) && (L2_block <= C2_max * L2_cache_size); + return L1_cond && load_balance && L2_cond; + }; + + for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) { + if (jcp.dimK % K_blk_ur == 0) { + for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) { + if (jcp.dimN_block % N_blk == 0) { + for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) { + if (jcp.dimM_block % M_blk == 0) { + if (blocking_ok()) { + jcp.dimN_block = N_blk; + jcp.dimM_block = M_blk; + jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur); + jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block; + jcp.sched_policy = WSCHED_WEI_S_D_Giot_W; + set_jcp_WEI_params(jcp); + return status::success; + } + } + } + } + } + } + } + jcp.dimK_reg_block = 1; + jcp.dimK_block = 1; + jcp.sched_policy = WSCHED_WEI_S_D_Giot_W; + set_jcp_WEI_params(jcp); + return status::success; +} +} // namespace +status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d) { + if (!mayiuse(avx512_core)) + return status::unimplemented; + else + jcp.ver = ver_avx512_core; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.prop_kind = cd.prop_kind; + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + jcp.mb = src_d.dims()[0]; + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + jcp.kh = diff_weights_d.dims()[with_groups + 2]; + jcp.kw = diff_weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef); + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + // Winograd kernel works only for 3x3 convolution with stride 1 + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + /******************Kernel blocking Parameters ***********/ + jcp.ic_simd_block = simd_w; + jcp.oc_simd_block = simd_w; + + jcp.dimK = jcp.ntiles; + jcp.dimN = jcp.ic; + jcp.dimM = jcp.oc; + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.sched_policy = WSCHED_INVALID; + status_t res = set_wsched_WEI_SDGtWo(jcp); + if (res == status::unimplemented) { + res = set_wsched_WEI_S_D_Giot_W(jcp); + assert(res == status::success); + } + return res; +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp new file mode 100644 index 0000000000..025a554d92 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp @@ -0,0 +1,291 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP +#define JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP + +#include "c_types_map.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + : public jit_generator { + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) { + { + this->weights_transform_data_ker_generate(); + weights_transform_data_ker + = (decltype(weights_transform_data_ker)) this->getCode(); + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->input_transform_data_ker_generate(); + input_transform_data_ker = (decltype(input_transform_data_ker))addr; + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->output_transform_data_ker_generate(); + output_transform_data_ker + = (decltype(output_transform_data_ker))addr; + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_fp32_wino_conv_4x3_data_kernel) + + static status_t init_conf_common(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d); + + static status_t init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *, const int); + void (*input_transform_data_ker)(jit_wino_transform_call_s *); + void (*output_transform_data_ker)(jit_wino_transform_call_s *); + void (*weights_transform_data_ker)(jit_wino_transform_call_s *); + +protected: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(); + void input_transform_data_ker_generate(); + void output_transform_data_ker_generate(); + void weights_transform_data_ker_generate(); + + /* registers used for GEMM */ + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + reg64_t reg_is_beta_zero = abi_param4; + + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r11; + + /* registers used for transforms*/ + reg64_t param = abi_param1; + + /* registers used for output_transform_data_ker */ + reg64_t oreg_temp = abi_not_param1; + reg64_t oreg_Ow = r9; + reg64_t oreg_src = r11; + reg64_t oreg_tile_block = r12; + reg64_t oreg_tile_block_ur = r13; + reg64_t oreg_nb_tile_block_ur = r14; + reg64_t oreg_O = r8; + reg64_t oreg_T = r10; + reg64_t oreg_dst = r11; + reg64_t oreg_ydim = r14; + reg64_t oreg_xdim = r15; + reg64_t oreg_out_j = r12; + reg64_t oreg_bias = rbx; + reg64_t imm_addr64 = rax; + + /* registers used for input_transform_data_ker */ + reg64_t ireg_temp = abi_not_param1; + reg64_t ireg_jtiles = rax; + reg64_t ireg_itiles = rbx; + reg64_t ireg_I = r8; + reg64_t ireg_src = r13; + reg64_t ireg_ydim = r14; + reg64_t ireg_xdim = r15; + reg64_t ireg_inp_j = r12; + reg64_t ireg_inp_i = rdx; + reg64_t ireg_mask_j = r11; + reg64_t ireg_mask = rsi; + reg32_t ireg_mask_32 = esi; + reg64_t ireg_zero = r9; + reg64_t ireg_Iw = r9; + reg64_t ireg_T = r10; + reg64_t ireg_tile_block = r12; + reg64_t ireg_tile_block_ur = r13; + reg64_t ireg_nb_tile_block_ur = r14; + reg64_t ireg_output = r15; + + /* registers used for wei transform */ + reg64_t wreg_temp = abi_not_param1; + reg64_t wreg_F = r8; + reg64_t wreg_src = r9; + reg64_t wreg_MT = r15; + reg64_t wreg_M = r14; + reg64_t wreg_dst = r10; + reg64_t wreg_dst_aux = r9; + reg64_t wreg_dst_idx = r8; + reg64_t wreg_Fw = r11; + reg64_t wreg_T = r12; + reg64_t wreg_cnt_j = rdx; + reg64_t wreg_F_aux = r14; + reg64_t wreg_Fw_aux = r15; +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel + : _jit_avx512_core_fp32_wino_conv_4x3_data_kernel { + using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel:: + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel; + + static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_t &src_md, + memory_desc_t &weights_md, const memory_desc_t &dst_md, + const primitive_attr_t &attr); +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel + : public _jit_avx512_core_fp32_wino_conv_4x3_data_kernel { + using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel:: + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel; + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel + : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32) + + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + //******************* First iter kernel ********************// + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))this->getCode(); + + align(); + const Xbyak::uint8 *addr = getCurr(); + this->src_transform_generate(); + src_transform = (decltype(src_transform))addr; + + if (jcp.with_bias) { + align(); + addr = getCurr(); + this->diff_dst_transform_generate(true); + diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr; + } + + align(); + addr = getCurr(); + this->diff_dst_transform_generate(false); + diff_dst_transform = (decltype(diff_dst_transform))addr; + + if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) { + align(); + addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + + align(); + addr = getCurr(); + this->diff_weights_transform_generate(true); + diff_weights_transform = (decltype(diff_weights_transform))addr; + + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) { + align(); + addr = getCurr(); + this->diff_weights_transform_generate(false); + diff_weights_transform_accum = + (decltype(diff_weights_transform_accum))addr; + }; + } + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + void (*src_transform)(jit_wino_transform_call_s *); + void (*diff_dst_transform)(jit_wino_transform_call_s *); + void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *); + void (*diff_weights_transform)(jit_wino_transform_call_s *); + void (*diff_weights_transform_accum)(jit_wino_transform_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + enum { typesize = sizeof(float) }; + + void src_transform_generate(); + void diff_dst_transform_generate(bool with_bias); + void diff_weights_transform_generate(bool first_tile); + + /*registers common to transforms*/ + reg64_t reg_transp = abi_param1; + reg64_t reg_ti = rbx; + reg64_t reg_tj = abi_not_param1; + reg64_t reg_src = r8; + reg64_t reg_dst = r9; + reg64_t reg_G = rsi; /*TODO: check if this is ok*/ + reg64_t reg_temp = rsi; + + /*registers common to src/diff_dst transform*/ + reg64_t reg_I = r10; + reg64_t reg_ydim = r11; + reg64_t reg_xdim = r12; + reg64_t reg_src_offset = r13; + reg64_t reg_zero = r14; + reg64_t reg_tile_count = r15; + reg64_t reg_maski = rsi; + reg32_t reg_maski_32 = esi; + reg64_t reg_maskj = rdx; + + reg64_t reg_T = rax; + reg64_t reg_oc_ur = rax; + reg64_t reg_ic_simd = r14; + reg64_t reg_bias = r10; + + void gemm_loop_generate(bool is_first_tile); + + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_dimM_block_loop_cnt = r9; + reg64_t reg_dimN_block_loop_cnt = r10; + reg64_t reg_nb_dimN_bcast_ur = r11; + reg64_t reg_dimK_block_loop_cnt = r12; +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp new file mode 100644 index 0000000000..002010ffa2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp @@ -0,0 +1,1284 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_u8s8s32x_wino_convolution.hpp" +#include "jit_generator.hpp" + +#include + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { + // Below scales are applied to source and weights data accordingly + // because this winograd implementation + // transforms source which may increase values up to 4x + // and transforms weights which may increase values up to 9/4x + const float adj_src_scale = 1.f / 4.f; + const float adj_wei_scale = 4.f / 9.f; + // Winograd transforms need ic and oc to be multiples of 16 + const int load_block = 16; +} + +/// SRC TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *src; + const void *wino_src; + const void *v_y_masks; + const void *v_x_masks; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), unsign_val_in_wino_domain(5) { + generate(); + ker_ = reinterpret_cast(const_cast(getCode())); + } + void generate(); + + int reg_inp_ind(int i) { + assert(i < jcp.alpha * jcp.alpha); + return (31 - i); + } + + Xmm vreg_inp(int i) { + return Xmm(reg_inp_ind(i)); + } + + Zmm zmm_inp(int i) { + return Zmm(reg_inp_ind(i)); + } + + Xmm vreg_tmp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Xmm(15 - i); + } + Xmm vreg_out(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Xmm(31 - i); + } + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert(id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_src = r14; + Reg64 reg_ptr_dst = r13; + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_ic_block = r8; + + int unsign_val_in_wino_domain; + + Reg64 reg_scratch_src_alpha = rdx; + Xmm xmm_src_alpha = Xmm(0); + Zmm zmm_src_alpha = Zmm(0); + + Reg64 reg_shift = rax; + Xmm xmm_shift = Xmm(1); + Xmm xmm_zero = Xmm(0); + + Reg64 reg_maskx = rbx; + Reg64 reg_masky = rsi; + Reg64 reg_nomask = reg_maskx; +}; + +void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() { + Label ic_block_label; + Label end_label; + Label mask_label; + Label nomask_label; + + auto load_src = [=](bool mask) { + for (int y = 0; y < jcp.alpha; y++) { + if (mask) + kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]); + for (int x = 0; x < jcp.alpha; x++) { + Zmm zmm_i = zmm_inp(y * jcp.alpha + x); + Xmm vreg_i = vreg_inp(y * jcp.alpha + x); + int inp_offset = sizeof(uint8_t) + * ((-jcp.t_pad + y) * jcp.iw * jcp.ic + + (-jcp.l_pad + x) * jcp.ic); + if (mask) { + kandw(r_mask, y_mask, x_mask(x)); + vmovdqu8(vreg_i | r_mask | T_z, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } else { + vmovdqu8(vreg_i, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } + vpmovzxbd(zmm_i, vreg_i); // to int32 + vcvtdq2ps(zmm_i, zmm_i); // to fp32 + vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha + vcvtps2dq(zmm_i, zmm_i); // to int32 + vpmovusdb(vreg_i, zmm_i); // to u8 + } + } + }; + + preamble(); + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_ptr_dst, wino_src); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); +# undef READ_PARAM + + mov(reg_maskx, ptr[reg_ptr_v_x_masks]); + mov(reg_masky, ptr[reg_ptr_v_y_masks]); + test(reg_maskx, reg_maskx); + jz(end_label, T_NEAR); // skip kernel if x mask is all 0's + test(reg_masky, reg_masky); + jz(end_label, T_NEAR); // skip kernel if y mask is all 0's + and_(reg_maskx, reg_masky); + mov(reg_nomask, reg_maskx); + not_(reg_nomask); // zero if x and y masks are all 1's + + xor_(reg_shift, reg_shift); + mov(reg_shift.cvt8(), (int8_t)-128); + + mov(reg_aux_ptr_src, reg_ptr_src); + mov(reg_aux_ptr_dst, reg_ptr_dst); + + for (int i = 0; i < jcp.alpha; i++) { + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]); + } + + mov(reg_scratch_src_alpha, float2int(adj_src_scale)); + + mov(reg_ic_block, jcp.ic / load_block); + L(ic_block_label); + { + vmovq(xmm_src_alpha, reg_scratch_src_alpha); + vbroadcastss(zmm_src_alpha, xmm_src_alpha); + + test(reg_nomask, reg_nomask); + jz(nomask_label, T_NEAR); + load_src(true); + jmp(mask_label, T_NEAR); + L(nomask_label); + load_src(false); + L(mask_label); + + for(int y = 0; y < 4; y++) { + vpsubb(vreg_tmp(y*4+0), vreg_inp(y*4+0), vreg_inp(y*4+2)); + vpaddb(vreg_tmp(y*4+1), vreg_inp(y*4+1), vreg_inp(y*4+2)); + vpsubb(vreg_tmp(y*4+2), vreg_inp(y*4+2), vreg_inp(y*4+1)); + vpsubb(vreg_tmp(y*4+3), vreg_inp(y*4+1), vreg_inp(y*4+3)); + } + for(int x = 0;x < 4; x++) { + vpsubb(vreg_out(x+0*4), vreg_tmp(x+4*0), vreg_tmp(x+4*2)); + vpaddb(vreg_out(x+1*4), vreg_tmp(x+4*1), vreg_tmp(x+4*2)); + vpsubb(vreg_out(x+2*4), vreg_tmp(x+4*2), vreg_tmp(x+4*1)); + vpsubb(vreg_out(x+3*4), vreg_tmp(x+4*1), vreg_tmp(x+4*3)); + } + + vmovd(xmm_shift, reg_shift.cvt32()); + vpxor(xmm_zero, xmm_zero, xmm_zero); + vpshufb(xmm_shift, xmm_shift, xmm_zero); + + for (int i = 0; i < 16; i++) { + int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i); + if (i != unsign_val_in_wino_domain) + vpsubb(vreg_out(i), vreg_out(i), Xmm(1)); + vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), vreg_out(i)); + } + + add(reg_aux_ptr_src, sizeof(uint8_t) * load_block); + add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block); + } + dec(reg_ic_block); + jnz(ic_block_label, T_NEAR); + + L(end_label); + postamble(); +} + +/// DST TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *wino_dst; + const void *dst; + const void *v_y_masks; + const void *v_x_masks; + + const void *bias; + const void *scales; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) { + generate(); + ker_ = reinterpret_cast(const_cast(getCode())); + } + + void generate(); + bool maybe_relu(int position); + + Zmm vreg_inp(int i) { // 16 + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + Zmm vreg_stg(int id) { // 8 + const int id_reg_stg = jcp.alpha * jcp.alpha + id; + assert(id < 8); + return Zmm(31 - id_reg_stg); + } + Zmm vreg_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id < 4); + return Zmm(31 - id_reg_out); + } + Xmm xmm_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id < 4); + return Xmm(31 - id_reg_out); + } + Zmm vreg_tmp(int id) { // 2 + const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; + assert(id < 2); + return Zmm(31 - id_reg_tmp); + } + + Zmm vreg_zero = Zmm(0); + Zmm vreg_bias = Zmm(1); + Zmm vreg_prev_dst = Zmm(2); + Zmm zmm_bias_alpha = Zmm(2); + Xmm xmm_bias_alpha = Xmm(2); + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert(id < 4); + return Opmask(3 + id); + } + + Reg64 reg_scratch_bias_alpha = r15; + + Reg64 reg_ptr_src = r14; + Reg64 reg_ptr_dst = r13; + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_oc_block = r8; + + Reg64 reg_ptr_bias = rbx; + Reg64 reg_ptr_scales = abi_not_param1; + Reg64 reg_ptr_sum_scale = rdx; +}; + +bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* relu before sum */ + return false + || p.contain(eltwise, 0) + || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0)); + } else if (position == 1) { + /* relu after sum */ + const int sum_idx = p.contain(sum, 0) + ? 0 : (p.contain(sum, 1) ? 1 : -1); + if (sum_idx == -1) + return false; + + return false + || p.contain(eltwise, sum_idx + 1) + || jcp.dst_dt == data_type::u8; + } + + return false; +} + +void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() { + Label oc_block_label; + + auto loop_body = [=]() { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? &p.entry_[sum_idx].sum.scale + : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + for(int i = 0; i < 16; i++) { + int internal_offset = sizeof(int32_t) * jcp.out_stride * i; + vmovups(vreg_inp(i), + EVEX_compress_addr(reg_aux_ptr_src, internal_offset)); + } + for(int y = 0; y < jcp.alpha; y++) { + vpaddd(vreg_tmp(0), vreg_inp(y*4 + 0), vreg_inp(y*4 + 1)); + vpaddd(vreg_stg(y*2), vreg_tmp(0), vreg_inp(y*4 + 2)); + + vpsubd(vreg_tmp(1), vreg_inp(y*4 + 1), vreg_inp(y*4 + 2)); + vpsubd(vreg_stg(y*2+1), vreg_tmp(1), vreg_inp(y*4 + 3)); + } + for(int x = 0; x < jcp.m; x++) { + vpaddd(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2*1)); + vpaddd(vreg_out(x), vreg_tmp(0), vreg_stg(x+2*2)); + + vpsubd(vreg_tmp(1), vreg_stg(x+2*1), vreg_stg(x+2*2)); + vpsubd(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2*3)); + } + + + if (jcp.with_bias) { + vmovq(xmm_bias_alpha, reg_scratch_bias_alpha); + vbroadcastss(zmm_bias_alpha, xmm_bias_alpha); + + auto bias_addr = ptr [ reg_ptr_bias ]; + switch (jcp.bia_dt) { + case data_type::f32: + case data_type::s32: vmovups(vreg_bias, bias_addr); break; + case data_type::s8: vpmovsxbd(vreg_bias, bias_addr); break; + case data_type::u8: vpmovzxbd(vreg_bias, bias_addr); break; + default: assert(!"unsupported dst data type"); + } + if (jcp.bia_dt != data_type::f32) + vcvtdq2ps(vreg_bias, vreg_bias); + vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha + } + for(int y = 0; y < jcp.m; y++) { + kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(uint16_t) * y ]); + for(int x = 0; x < jcp.m; x++) { + kandw(r_mask, y_mask, x_mask(x)); + + int i = y * jcp.m + x; + int offset = jcp.typesize_out * + (y * jcp.ow * jcp.oc + x * jcp.oc); + Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset); + + Zmm zmm = vreg_out(i); + Xmm xmm = xmm_out(i); + vcvtdq2ps(zmm, zmm); + if (jcp.with_bias) + vaddps(zmm, zmm, vreg_bias); + vmulps(zmm, zmm, ptr [reg_ptr_scales]); + if (maybe_relu(0)) + vmaxps(zmm, vreg_zero, zmm); + if (p_sum_scale) { // post_op: sum + vpxord(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst); + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(vreg_prev_dst | r_mask, addr); break; + case data_type::s8: + vpmovsxbd(vreg_prev_dst | r_mask, addr); break; + case data_type::u8: + vpmovzxbd(vreg_prev_dst | r_mask, addr); break; + default: assert(!"unknown dst_dt"); + } + if (jcp.dst_dt != data_type::f32) + vcvtdq2ps(vreg_prev_dst, vreg_prev_dst); + if (*p_sum_scale == 1.f) + vaddps(zmm, vreg_prev_dst); + else + vfmadd231ps(zmm, vreg_prev_dst, + zword_b[reg_ptr_sum_scale]); + } + if (maybe_relu(1)) + vmaxps(zmm, vreg_zero, zmm); + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(zmm, zmm); + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(addr, zmm | r_mask); break; + case data_type::s8: + vpmovsdb(xmm, zmm); vmovups(addr, xmm | r_mask); break; + case data_type::u8: + vpmovusdb(xmm, zmm); vmovups(addr, xmm | r_mask); break; + default: assert(!"unknown dst_dt"); + } + } + } + }; + + preamble(); + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, wino_dst); + READ_PARAM(reg_ptr_dst, dst); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); + READ_PARAM(reg_ptr_bias, bias); + READ_PARAM(reg_ptr_scales, scales); +# undef READ_PARAM + + if (jcp.with_bias) + mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale)); + + mov(reg_aux_ptr_src, reg_ptr_src); + mov(reg_aux_ptr_dst, reg_ptr_dst); + + vpxord(vreg_zero, vreg_zero, vreg_zero); + + for (int i = 0; i < jcp.m; i++) + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]); + + int oc_blocks = jcp.oc / load_block; + mov(reg_oc_block, oc_blocks); + L(oc_block_label); { + loop_body(); + add(reg_aux_ptr_src, sizeof(int32_t) * load_block); + add(reg_aux_ptr_dst, jcp.typesize_out * load_block); + + add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + add(reg_ptr_bias, sizeof(jcp.typesize_bia) * load_block); + } + dec(reg_oc_block); + jnz(oc_block_label, T_NEAR); + + postamble(); + +} + +/// GEMM kernel //////////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t) + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *src; + const void *dst; + const void *wei; + const void *dst_b; + }; + void (*ker_)(const call_params_t *); + + void generate(); + static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp, + const primitive_attr_t &attr); + + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) + { + generate(); + ker_ = reinterpret_cast(const_cast(getCode())); + } + + static status_t init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr); + + Zmm vreg_out(int n, int m) { + const int id_reg_out = n * jcp.m_block + m; + assert(id_reg_out < jcp.n2_block * jcp.m_block); + return Zmm(31 - id_reg_out); + } + Zmm vreg_wei(int i) { + assert(31 - jcp.n2_block * jcp.m_block - i + > (jcp.ver == ver_vnni ? 0 : 2)); + return Zmm(31 - jcp.n2_block * jcp.m_block - i); + } + + Zmm vreg_src = Zmm(0); + Zmm vreg_one = Zmm(1); + Zmm vreg_tmp = Zmm(2); + + Reg64 reg_ptr_src = r15; + + Reg64 reg_aux_dst_b = r13; + Reg64 reg_aux_dst = r12; + Reg64 reg_aux_dst2 = r11; + Reg64 reg_aux_wei = r10; + Reg64 reg_aux_wei2 = r9; + Reg64 reg_aux_src = r8; + Reg64 reg_aux_src2 = rax; + Reg64 reg_mb = rbx; + Reg64 reg_nnb = abi_not_param1; + Reg64 reg_scratch = rdx; + Reg64 reg_K = rsi; +}; + +bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok( + jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_relu(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_relu(1)) || + (p.contain(sum, 1) && is_relu(0)); + case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2); + default: return false; + } + + return false; +} + +void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() { + Label nnb_loop_label, K_loop_label, mb_loop_label; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(vreg_tmp, vreg_src, vreg_wei); + vpmaddwd(vreg_tmp, vreg_tmp, vreg_one); + vpaddd(vreg_acc, vreg_acc, vreg_tmp); + } + }; + + preamble(); +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_aux_dst, dst); + READ_PARAM(reg_aux_wei, wei); + READ_PARAM(reg_aux_dst_b, dst_b); +# undef READ_PARAM + + if (jcp.ver != ver_vnni) { + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(vreg_one, _t); + } + + if (!jcp.small_mb) { + mov(reg_nnb, jcp.n_chunks); + L(nnb_loop_label); + } + mov(reg_aux_dst2, reg_aux_dst); + mov(reg_aux_src, reg_ptr_src); + mov(reg_mb, jcp.M / jcp.m_block); + L(mb_loop_label); + { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + for (int m = 0; m < jcp.m_block; m++) { + int offset = jcp.typesize_acc * nb2 * jcp.n_block; + vmovups(vreg_out(nb2, m), + EVEX_compress_addr(reg_aux_dst_b, offset)); + } + } + mov(reg_aux_src2, reg_aux_src); + mov(reg_aux_wei2, reg_aux_wei); + mov(reg_K, jcp.k_chunks); + L(K_loop_label); + { + for (int k = 0; k < jcp.k2_block; k += 4) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int wei_offset + = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K); + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, wei_offset)); + } + for (int m = 0; m < jcp.m_block; m++) { + int inp_offset = jcp.typesize_in * m * jcp.K; + vpbroadcastd(vreg_src, + EVEX_compress_addr(reg_aux_src2, inp_offset)); + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) + compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src); + } + add(reg_aux_src2, jcp.typesize_in * 4); + add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block); + } + } + dec(reg_K); + jnz(K_loop_label, T_NEAR); + + for (int m = 0; m < jcp.m_block; m++) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block); + vmovups(EVEX_compress_addr(reg_aux_dst2, offset), + vreg_out(nb2, m)); + } + } + add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K); + add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N); + } + dec(reg_mb); + jnz(mb_loop_label, T_NEAR); + + if (!jcp.small_mb) { + add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block * jcp.n_block); + add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block); + add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K); + + dec(reg_nnb); + jnz(nnb_loop_label, T_NEAR); + } + + postamble(); +} +namespace { +bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) { + if (jcp.ver == ver_vnni) { + return (jcp.mb <= mkldnn_get_max_threads() + && (jcp.mb > 4 + && jcp.ic > 64 + && !(jcp.oc > 128 && jcp.ih < 14))) + || jcp.mb > mkldnn_get_max_threads(); + } + return true; +} +} + +status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t +::init_conf(jit_conv_conf_2x3_wino_t &jcp, + const convolution_desc_t &cd, memory_desc_t &src_md, + memory_desc_t &wei_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const primitive_attr_t &attr) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper wei_d(&wei_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = wei_d.ndims() == src_d.ndims() + 1; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ngroups = with_groups ? wei_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = wei_d.dims()[with_groups + 2]; + jcp.kw = wei_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ver = ver_avx512_core; + if (!(mayiuse(avx512_core) && + src_d.data_type() == data_type::u8 + && wei_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + // block sizes needed for GEMM kernel + jcp.ic_block = 4; + jcp.oc_block = 16; + + bool ok = true + && jcp.ngroups == 1 + && jcp.oc % load_block == 0 && jcp.ic % load_block == 0 + && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 + && everyone_is(3, jcp.kh, jcp.kw) + && everyone_is(1, jcp.stride_h, jcp.stride_w) + && everyone_is(0, jcp.dilate_h, jcp.dilate_w) + && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad + && one_of(jcp.t_pad, 0, 1) + && one_of(jcp.l_pad, 0, 1); + if (!ok) return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_acc = sizeof(int32_t); + jcp.typesize_bia = jcp.with_bias + ? types::data_type_size(bias_d.data_type()) + : 0; + + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.m = 2; + jcp.r = 3; + jcp.alpha = jcp.m + jcp.r - 1; + + int aa = jcp.alpha * jcp.alpha; + int L1_cap = get_cache_size(1, true); + int L2_cap = get_cache_size(2, true); + // need 1 extra reg for bcast, and 2 tmp regs for non-vnni + int free_regs = jcp.ver == ver_vnni ? 31 : 29; + + auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) { + float thr_eff; + float Z = (float)jcp.ic + jcp.oc; + float Y = (float)jcp.ic * jcp.oc; + if (small_mb == 0) { // outer par + int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix); + thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr); + } else { // inner par + int tranw = iy * ix / jcp.alpha; + int gemmw = aa * (jcp.nb_oc / n2_b); + int tranw_r = rnd_up(tranw, jcp.nthr); + int gemmw_r = rnd_up(gemmw, jcp.nthr); + thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y); + } + return thr_eff; + }; + + auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) { + float mem_eff, req_mem; + int M = ix * iy / jcp.alpha; + if (small_mb == 0) { // outer parallelization strategy + // memory for wino transforms (other memory has poor reuse) + req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc); + mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f; + } else { // inner parallelization strategy + // memory used during gemm + int N = jcp.oc_block * n2_b; + req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N; + mem_eff = nstl::min(1.f, L2_cap / req_mem); + // memory used during wino transforms + int M_per_thr = div_up(M, jcp.nthr); + req_mem = (float)aa * M_per_thr + * (jcp.ic + jcp.typesize_acc * jcp.oc); + if (req_mem > L2_cap) + mem_eff = 0.1f; + } + return mem_eff; + }; + + auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff, + float mem_eff, float reg_eff) { + // these coefficients are chosen empirically + float mem_fac = 0.1f, reg_fac = 0.2f; + // normalized overhead relative to memory and register components + float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff; + // thread and work components affect all others + tot_eff *= thr_eff * work_eff; + return tot_eff; + }; + + auto find_m_n2_blocks = [&](bool small_mb, int ix, int iy, float work_eff, + int &m_block, int &n2_block, float &tot_eff) { + int M = (ix * iy) / jcp.alpha; + int max_m_block = nstl::min(M, free_regs); + int max_n2_block = nstl::min(jcp.nb_oc, free_regs); + tot_eff = 0.f; + for (int im = max_m_block; im > 0; im--) { + if (M % im) + continue; + for (int in2 = max_n2_block; in2 > 0; in2--) { + int used_regs = (im + 1) * in2; + float mem_eff = get_mem_eff(small_mb, ix, iy, in2); + float reg_eff = (float)(im * in2) / (im + in2); + float thr_eff = get_thr_eff(small_mb, ix, iy, in2); + float cur_tot_eff = get_tot_eff( + small_mb, thr_eff, work_eff, mem_eff, reg_eff); + if (jcp.nb_oc % in2 || used_regs > free_regs + || cur_tot_eff <= tot_eff) + continue; + tot_eff = cur_tot_eff; + m_block = im; + n2_block = in2; + } + } + }; + + /* Selecting xb and yb blocking */ + int min_yb = jcp.m; + int min_xb = jcp.m; + int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2)); + int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2)); + float best_eff = 0.f; + for (int ix = min_xb; ix <= max_xb; ix += 2) { + assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2); + for (int iy = max_yb; iy >= min_yb; iy -= 2) { + assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2); + + int m_b[2]; + int n2_b[2]; + bool small_mb; + float inner_eff, outer_eff, work_eff; + + int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix); + work_eff = (float)jcp.oh * jcp.ow / tiled_area; + if (best_eff > 0.f && work_eff < 4.f / 9.f) + continue; // no gain from Winograd transformation + + /* outer parallelization */ + find_m_n2_blocks(0, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff); + + /* inner parallelization */ + find_m_n2_blocks(1, ix, iy, work_eff, m_b[1], n2_b[1], inner_eff); + + small_mb = inner_eff > outer_eff; + float eff = small_mb ? inner_eff : outer_eff; + if (eff > best_eff) { + best_eff = eff; + jcp.yb = iy; + jcp.xb = ix; + jcp.m_block = m_b[small_mb]; + jcp.n2_block = n2_b[small_mb]; + jcp.small_mb = small_mb; + } + } + } + + assert((jcp.m_block + 1) * jcp.n2_block <= free_regs); + assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); + + jcp.mb_block = 1; + if (jcp.small_mb) { + // For small mb harness, set mb_block as large as possible subject to + // the constraint that winograd activations fit into available L3 cache + int L3_cap = get_cache_size(3, true); + int M = jcp.xb * jcp.yb / 4; + int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in; + int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc; + int max_mb_block = nstl::min( + jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size)); + for (int i = max_mb_block; i > 1; i--) { + if (jcp.mb % i == 0) { + jcp.mb_block = i; + break; + } + } + } + jcp.nb_mb = jcp.mb / jcp.mb_block; + + jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4; + jcp.N = jcp.oc; + jcp.K = jcp.ic; + + jcp.inp_stride = jcp.M * jcp.ic; + jcp.out_stride = jcp.M * jcp.oc; + jcp.wei_stride = jcp.ic * jcp.oc; + jcp.bia_stride = jcp.oc; + + jcp.n_block = jcp.oc_block; + jcp.k_block = jcp.ic_block; + + jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block; + + // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4 + // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is + // a multiple of load_block = 16, we just use that for now. + jcp.k2_block = load_block; + jcp.k_chunks = jcp.K / jcp.k2_block; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + memory_desc_t expect_wei_md = wei_md; + + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::s8; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format = mkldnn_wino_wei_aaOIoi; + wd.r = jcp.r; + wd.alpha = jcp.alpha; + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.ic_block; + wd.oc_block = jcp.oc_block; + wd.oc2_block = jcp.n2_block; + wd.ic2_block = 1; + wd.adj_scale = adj_wei_scale; + + size_t max_size = types::data_type_size(data_type::s8) * + jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; + max_size += types::data_type_size(data_type::s32) * + jcp.alpha * jcp.alpha * jcp.oc; + wd.size = max_size; + + if (wei_md.format_kind == format_kind::any) + wei_md = expect_wei_md; + if (wei_md != expect_wei_md) + return status::unimplemented; + + const int tilesize = jcp.alpha * jcp.alpha; + const int numtiles = jcp.M; + const int alltiles = numtiles * tilesize; + + jcp.size_wino_src + = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K) + / jcp.typesize_in; + jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic; + jcp.size_wino_dst = alltiles * jcp.oc; + + return status::success; +} +//////////////////////////////////////////////////////////////////////////////// + +template +status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: + pd_t::jit_conf() { + return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf( + jcp_, *this->desc(), this->src_md_, this->weights_md_, + this->dst_md_,this->bias_md_, *this->attr()); +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::pd_t:: +init_scratchpad() { + auto scratchpad = this->scratchpad_registry().registrar(); + + int nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr; + scratchpad.book(key_wino_V, + sizeof(src_data_t) * jcp_.size_wino_src * nthr_multiplier, PAGE_4K); + scratchpad.book(key_wino_M, + sizeof(acc_data_t) * jcp_.size_wino_dst * nthr_multiplier, PAGE_4K); + + dim_t scale_count = attr()->output_scales_.count_; + scratchpad.book(key_conv_adjusted_scales, + sizeof(float) * nstl::max(scale_count, 16)); +} + +template +jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) +{ + kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t( + pd()->jcp_, *pd()->attr()); + src_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t( + pd()->jcp_, *pd()->attr()); + dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t( + pd()->jcp_, *pd()->attr()); +} + +template +jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: + ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() { + delete kernel_; + delete src_trans_; + delete dst_trans_; +} + +template +const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +adjust_oscales(const memory_tracking::grantor_t &scratchpad) const { + const float *oscales = pd()->attr()->output_scales_.scales_; + auto loc_scales = scratchpad.template get(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / (adj_src_scale * adj_wei_scale); + if (count == 1) + utils::array_set(loc_scales, oscales[0] * factor, 16); + else + for (size_t c = 0; c < count; c++) loc_scales[c] = oscales[c] * factor; + return loc_scales; +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const auto &jcp = kernel_->jcp; + if (jcp.small_mb) + execute_forward_small_mb(src, weights, bias, dst, this->scratchpad(ctx)); + else + execute_forward_mbN(src, weights, bias, dst, this->scratchpad(ctx)); +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +execute_forward_mbN(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const float *oscales = adjust_oscales(scratchpad); + + auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei); + auto wino_src_base = scratchpad.template get(key_wino_V); + auto wino_dst_base = scratchpad.template get(key_wino_M); + + parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb), + [&](int mb, int tile_y_b, int tile_x_b) { + + int tile_y = tile_y_b * jcp.yb; + int tile_x = tile_x_b * jcp.xb; + + int ithr = mkldnn_get_thread_num(); + auto wino_src = wino_src_base + jcp.size_wino_src * ithr; + auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr; + + auto src_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); + auto dst_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); + auto gemm_p = + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::call_params_t(); + + /* transformation of input tensor to winograd domain */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { + uint16_t v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min(jcp.alpha, + nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min(jcp.alpha, + nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff); + v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff); + } + auto local_s = src + + mb * jcp.ih * jcp.iw * jcp.ic + + y * jcp.iw * jcp.ic + x * jcp.ic; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + } + } + /* gemms */ + for (int tile_ij = 0; tile_ij < 16; tile_ij++) { + // start threads at different GEMMs to help bring weights into LLC + int offset = (tile_ij + ithr) % 16; + gemm_p.src = wino_src + jcp.inp_stride * offset; + gemm_p.dst = wino_dst + jcp.out_stride * offset; + gemm_p.wei = wei + jcp.wei_stride * offset; + gemm_p.dst_b = dst_bias + jcp.bia_stride * offset; + + kernel_->ker_(&gemm_p); + } + + /* transformation from winograd domain to output tensor */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { + uint16_t v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0); + v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0); + } + auto local_d = dst + + mb * jcp.oh * jcp.ow * jcp.oc + + y * jcp.ow * jcp.oc + x * jcp.oc; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + } + } + }); +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const float *oscales = adjust_oscales(scratchpad); + + auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei); + auto wino_src = scratchpad.template get(key_wino_V); + auto wino_dst = scratchpad.template get(key_wino_M); + + for (int mbb = 0; mbb < jcp.nb_mb; mbb++) { + for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) { + for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { + /* transformation of input tensor to winograd domain */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block, + [&](int y_in_block_b, int x_in_block_b, int mb) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto src_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); + + uint16_t v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2) + + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min( + jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min( + jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff); + v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff); + } + auto local_s = src + + (mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw * jcp.ic + + y * jcp.iw * jcp.ic + x * jcp.ic; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + }); + + /* gemms */ + parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { + auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t:: + call_params_t(); + + gemm_p.src = wino_src + jcp.inp_stride * tile_ij; + gemm_p.dst = wino_dst + jcp.out_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + gemm_p.wei = wei + jcp.wei_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block * jcp.K; + gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + + kernel_->ker_(&gemm_p); + }); + + /* transformation from winograd domain to output tensor */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block, + [&](int y_in_block_b, int x_in_block_b, int mb) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto dst_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); + + uint16_t v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2) + + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0); + v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0); + } + auto local_d = dst + + (mbb * jcp.mb_block + mb) * jcp.oh * jcp.ow * jcp.oc + + y * jcp.ow * jcp.oc + x * jcp.oc; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + }); + }}} +} + +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp new file mode 100644 index 0000000000..9e6e57b051 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp @@ -0,0 +1,128 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t; +struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t; +struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t; + +template +struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8_wino:", avx512_core, ""), + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::u8, data_type::s8, + data_type::undef, dst_data_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && !has_zero_dim_memory() + && set_default_formats(); + + if (!ok) return status::unimplemented; + + status_t status = jit_conf(); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + init_scratchpad(); + + return status; + } + + jit_conv_conf_2x3_wino_t jcp_; + + protected: + status_t jit_conf(); + void init_scratchpad(); + + bool set_default_formats() { + using namespace format_tag; + return set_default_formats_common(nhwc, any, nhwc); + } + }; + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type acc_data_t; + typedef typename prec_traits::type dst_data_t; + + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd); + ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad) + const; + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_mbN(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t *kernel_; + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t *src_trans_; + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t *dst_trans_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp new file mode 100644 index 0000000000..f4ec29ab00 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp @@ -0,0 +1,820 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::maybe_eltwise(int position) +{ + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + + return false; +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_bcast_data, reg_bcast_data); + + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_off)); + + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + int output_offset = jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep; + + add(aux_reg_output_data, output_offset); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + L(bcast_loop_tail_out); + } +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::cvt2ps(data_type_t type_in, + zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) { + zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in; + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(zmm, op); break; + case data_type::s8: vpmovsxbd(zmm, op); break; + case data_type::u8: vpmovzxbd(zmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(zmm_in, zmm_in); +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk, + int ur, int substep, bool wraparound) +{ + auto vreg_load = [=](int i_load) { + return Zmm(ur * load_loop_blk + i_load); + }; + + auto vreg_accum = [=](int i_load, int i_ur) { + return Zmm(i_ur * load_loop_blk + i_load); + }; + + auto zmm_bias_alpha = [=]() { + return Zmm(ur * load_loop_blk); + }; + + auto xmm_bias_alpha = [=]() { + return Xmm(ur * load_loop_blk); + }; + auto bias_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_bias_data, + jcp.typesize_bia * jcp.oc_block * i_load); + }; + + auto comp_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_comp_data, + sizeof(int32_t) * jcp.oc_block * i_load); + }; + + auto scale_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_ptr_scales, + jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load)); + }; + + auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) { + assert(i_ur < jcp.ur); + assert(i_reduce <= jcp.reduce_loop_unroll); + assert(jcp.reduce_loop_unroll == jcp.reduce_block); + + int offt = (jcp.ic_without_padding * i_ur + i_reduce); + + return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt, + bcast); + }; + + auto load_ptr = [=](int i_reduce, int i_load) { + int u0 = i_reduce % jcp.reduce_loop_unroll; + int u1 = i_reduce / jcp.reduce_loop_unroll; + + int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block; + + return EVEX_compress_addr(aux_reg_load_data, + u1 * jcp.reduce_loop_load_step + + jcp.typesize_in * offt); + }; + + auto output_ptr = [=](int i_load, int i_ur) { + return EVEX_compress_addr(aux_reg_output_data, + jcp.typesize_out * (jcp.oc_without_padding * i_ur + + i_load * jcp.load_block)); + }; + + auto init = [=]() { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vpxord(r, r, r); + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)-128); + vpbroadcastb(zmm_shift, _t8); + } + }; + + auto store = [=](const bool mask_flag_in) { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? &p.entry_[sum_idx].sum.scale + : nullptr; + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off)); + if (p_sum_scale && *p_sum_scale != 1.f) { + mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data); + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + } + if (jcp.signed_input && jcp.ver != ver_vnni) { + mov(reg_scratch, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_scratch); + vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha()); + } + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1; + auto zmm_bias = zmm_tmp; + auto zmm_comp = zmm_bcast; + if (jcp.with_bias) { + if (jcp.signed_input) + mov(reg_bias_data, + EVEX_compress_addr(rsp,reg_bias_data_off)); + cvt2ps(jcp.bia_dt, zmm_bias, bias_ptr(i_load), mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + vmulps(zmm_bias, zmm_bias, zmm_bias_alpha()); + } + if (jcp.signed_input) { + mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off)); + cvt2ps(data_type::s32, zmm_comp, comp_ptr(i_load), mask_flag); + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vcvtdq2ps(r, r); + if (jcp.signed_input) + vaddps(r, r, zmm_comp); + if (jcp.with_bias) + vaddps(r, r, zmm_bias); + + zmm_t mask_zmm = mask_flag ? r | ktail_mask | T_z : r; + vmulps(mask_zmm, r, scale_ptr(i_load)); + } + } + + if (maybe_eltwise(0)) + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + if (p_sum_scale) { // post_op: sum + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && + i_load == load_loop_blk - 1; + for (int i_ur = 0; i_ur < ur; ++i_ur) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto zmm_prev_dst = zmm_zero; + + auto r = vreg_accum(i_load, i_ur); + cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur), + mask_flag); + + if (*p_sum_scale == 1.f) + vaddps(r, zmm_prev_dst); + else + vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + + if (maybe_eltwise(1)) + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && + i_load == load_loop_blk - 1; + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + if (jcp.dst_dt == data_type::u8) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + vmaxps(r, zmm_zero, r); + } + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(r, r); + } + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + zmm_t r_zmm = mask_flag ? r | ktail_mask : r; + + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(output_ptr(i_load, i_ur), r_zmm); break; + case data_type::s8: + vpmovsdb(output_ptr(i_load, i_ur), r_zmm); break; + case data_type::u8: + vpmovusdb(output_ptr(i_load, i_ur), r_zmm); break; + default: assert(!"unknown dst_dt"); + } + } + } + mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off)); + }; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); + vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + auto fma_block = [=](bool last_block) { + int reduce_step = 4; + int tail_size = jcp.ic_without_padding % reduce_step; + int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding + ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step) + : jcp.reduce_loop_unroll; + for (int i_reduce = 0; i_reduce < loop_unroll; + i_reduce += reduce_step) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load)); + for (int i_ur = 0; i_ur < ur; ++i_ur) { + if (last_block && tail_size != 0 + && i_reduce == loop_unroll - reduce_step) { + Xmm xmm_bcast = Xmm(zmm_bcast.getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data + + jcp.ic_without_padding * i_ur + i_reduce + r], r); + vpbroadcastd(zmm_bcast, xmm_bcast); + } else { + vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false)); + } + if (jcp.signed_input) + vpsubb(zmm_bcast, zmm_bcast, zmm_shift); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + compute(vreg_accum(i_load, i_ur), + vreg_load(i_load), zmm_bcast); + } + } + } + }; + + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + if (jcp.ic != jcp.ic_without_padding) { + fma_block(true); + } else { + fma_block(false); + } + + if (jcp.oc_without_padding != jcp.oc) { + Label end_store, common_store; + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + + /*Check if it is the last load_loop_blk*/ + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + cmp(reg_load_loop_work, 0); + jg(common_store, T_NEAR); + + /*Check if it is the last ocb*/ + test(reg_reduce_pos_flag, FLAG_OC_LAST); + jz(common_store, T_NEAR); + + store(true); + jmp(end_store, T_NEAR); + + L(common_store); + store(false); + + L(end_store); + + add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + } else { + store(false); + } +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate() +{ + preamble(); + + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(zmm_one, _t); + + sub(rsp, stack_space_needed); + + if (jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + Reg32 regw_tmp = reg_last_load.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + if (jcp.with_bias) + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + if (jcp.signed_input) { + mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data); + mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]); + mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data); + } + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales); + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + + + auto load_loop_body = [=](int load_loop_blk) { + bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + if (jcp.with_bias) { + if (jcp.signed_input) + mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off)); + add(reg_bias_data, + load_loop_blk * jcp.load_block * jcp.typesize_bia); + if (jcp.signed_input) + mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data); + } + if (jcp.signed_input) { + mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off)); + add(reg_comp_data, + load_loop_blk * jcp.load_block * sizeof(int32_t)); + mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data); + } + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off)); + add(reg_ptr_scales, + jcp.is_oc_scale * load_loop_blk * jcp.load_block * sizeof(float)); + mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales); + mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); + add(reg_output_data, + load_loop_blk * jcp.load_block * jcp.typesize_out); + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + const int simd_w = 16; + + Label load_loop_blk[7]; + + static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 }; + const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast); + const int *ur_cases_fma = ur_cases_fma_expl_bcast; + const int *ur_cases = ur_cases_fma; + const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases); + + for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { + int label_idx = num_ur_cases - ur_idx - 1; + if (jcp.ur <= ur_cases[ur_idx]) { + cmp(reg_load_loop_work, simd_w * (label_idx + 1)); + jle(load_loop_blk[label_idx], T_NEAR); + } + } + + for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { + if (jcp.ur <= ur_cases[ur_idx]) { + int label_idx = num_ur_cases - ur_idx - 1; + L(load_loop_blk[label_idx]); + { + if (label_idx == 0) { + cmp(reg_load_loop_work, 0); + je(load_loop_blk[num_ur_cases], T_NEAR); + } + + for (int _i = 1; _i <= label_idx + 1; _i++) { + prefetcht0(ptr [ reg_load_data + _i * jcp.ic * jcp.oc_block ]); + prefetcht1(ptr [ reg_output_data + _i * jcp.oc_block ]); + } + + load_loop_body(label_idx + 1); + if (label_idx - 1 > 0) { + cmp(reg_load_loop_work, 2 * label_idx * simd_w); + je(load_loop_blk[label_idx - 1], T_NEAR); + } + cmp(reg_load_loop_work, (label_idx + 1) * simd_w); + jge(load_loop_blk[label_idx]); + } + for (int idx = label_idx - 1; idx > 0; --idx) { + cmp(reg_load_loop_work, simd_w * (idx + 1)); + je(load_loop_blk[idx], T_NEAR); + } + if (ur_idx < num_ur_cases - 2) { + cmp(reg_load_loop_work, simd_w); + jle(load_loop_blk[0], T_NEAR); + } + } + } + L(load_loop_blk[num_ur_cases]); + + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_eltwise(1)) + || (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf( + jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d, + const primitive_attr_t &attr, int nthreads, bool reduce_src) { + if (!mayiuse(avx512_core)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + if (!one_of(src_d.data_type(), data_type::u8, data_type::s8) + || weights_d.data_type() != data_type::s8 + || !one_of(dst_d.data_type(), + data_type::f32, data_type::s32, data_type::s8, data_type::u8)) + return status::unimplemented; + jcp.ver = ver_avx512_core; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = jcp.ic; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + jcp.tr_is = rnd_up(jcp.is, 4); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + format_tag_t dat_tag = format_tag::nhwc; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + const int simd_w = 16; + + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + + args_ok = true + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.ic_block = jcp.oc_block = simd_w; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_bia = jcp.with_bias + ? types::data_type_size(bias_d.data_type()) + : 0; + + const int SMALL_SPATIAL = 7 * 7; + const int BIG_REDUCE_DIM = 1024; + + int load_blocking = 0; + int load_blocking_max = 0; + int bcast_blocking = 0; + int bcast_blocking_max = 0; + int reduce_blocking = 0; + int reduce_blocking_max = 0; + jcp.load_grp_count = 1; + jcp.use_vmovntps = false; + + const int L2_size = get_cache_size(2, true) / sizeof(jcp.typesize_in); + const int L2_capacity = (L2_size * 3) / 4; + + int size_treshold = 28; + int max_regs = 0; + int min_regs = 6; + if (jcp.ver == ver_vnni) + max_regs = ((jcp.oh > size_treshold && jcp.ow > size_treshold) + && (jcp.oc < 128 || jcp.ic < 128)) ? min_regs : 9; + else + max_regs = 8; + jcp.expl_bcast = true; + + if (jcp.mb == 1 && jcp.ic > 128 + && (jcp.oh <= size_treshold && jcp.ow <= size_treshold)) { + if (jcp.os <= SMALL_SPATIAL && jcp.oc * jcp.ic < L2_size) + max_regs = min_regs; // mobilenet_v2 performance improvement + jcp.ur = nstl::min(max_regs, jcp.os); + } else { + const int spatial = jcp.oh; + jcp.ur = 1; + for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) { + if ((spatial >= size_treshold && spatial % ur_w == 0) + || (spatial < size_treshold && jcp.os % ur_w == 0)) { + jcp.ur = ur_w; + break; + } + } + if (jcp.ur == 1) { + jcp.ur = nstl::min(max_regs, jcp.os); + int os_tail = jcp.os % max_regs; + for (int i = max_regs; i >= min_regs; i--) { + int i_tail = jcp.os % i; + if (i_tail > os_tail || i_tail == 0) { + jcp.ur = i; + os_tail = i_tail; + if (i_tail == 0) + break; + } + } + } + } + + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.typesize_in; + + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_without_padding * jcp.typesize_out; + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_without_padding * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step + = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; + + jcp.load_loop_iter_step = jcp.load_block; + + jcp.loop_order = reduce_src ? loop_blr : loop_lbr; + + int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + reduce_blocking = nb_reduce; + if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 64; + else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 16; + reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + + bool cmp_reduce = reduce_blocking <= jcp.reduce_dim; + if (cmp_reduce) + jcp.loop_order = reduce_src ? loop_rbl : loop_rlb; + load_blocking = jcp.load_dim; + + jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast); + jcp.load_grp_count = best_divider( + nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false); + + if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.load_dim * jcp.reduce_dim >= L2_size) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); + } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= nthreads + && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); // + load_blocking = jcp.load_block; + } + + bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, + div_up(nthreads, jcp.load_grp_count)) * jcp.bcast_block; + bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); + bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); + + int space_for_bcast + = (L2_capacity - /* kernel_size - */ + 2 * jcp.load_block * reduce_blocking + - jcp.ur * reduce_blocking - 3 * 1024); + if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) + space_for_bcast /= 2; + + int bcast_in_cache + = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); + bcast_blocking = nstl::min( + bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); + + load_blocking_max = load_blocking; + bcast_blocking_max = bcast_blocking * 3 / 2; + reduce_blocking_max = reduce_blocking; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + assert(reduce_blocking_max); + assert(load_blocking % jcp.load_block == 0); + assert(reduce_blocking % jcp.reduce_block == 0); + assert(load_blocking_max % jcp.load_block == 0); + assert(reduce_blocking_max % jcp.reduce_block == 0); + + assert(jcp.reduce_loop_unroll % 4 == 0); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + + assert(jcp.bcast_block % jcp.ur == 0); + assert(jcp.reduce_dim % jcp.reduce_block == 0); + + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + // miniumum size of load dim chunk for work distribution within threads + jcp.nb_load_chunk = 1; + // peformance improvements for googlenet_v3, mb=1; + // TODO: generalize this condition and rewrite it in appropriate manner + if (jcp.mb == 1 && jcp.nb_load % 4 == 0 && jcp.ic / jcp.oc >= 4 + && jcp.ic * jcp.oc <= L2_size) { + jcp.nb_load_chunk = 4; + jcp.load_grp_count = nstl::max(jcp.nb_load / 4, jcp.load_grp_count); + } + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? weights_d.extra().scale_adjust : 1.f; + + return status::success; +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max(attr.output_scales_.count_, 16); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp new file mode 100644 index 0000000000..22e9732a1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP +#define JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_x8s8s32x_1x1_conv_kernel: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_1x1_conv_fwd_ker_t) + jit_avx512_core_x8s8s32x_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) : jcp(ajcp), attr_(attr), + eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode(); + } + + ~jit_avx512_core_x8s8s32x_1x1_conv_kernel() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const memory_desc_wrapper &bias_d, + const primitive_attr_t &attr, + int nthreads, bool reduce_src); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr); + + bool maybe_eltwise(int position); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + + private: + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + using mask_t = const Xbyak::Opmask; + + reg64_t reg_bcast_data = r8; + reg64_t reg_ptr_scales = r8; + reg64_t reg_output_data = r9; + reg64_t reg_load_data = r10; + reg64_t reg_ptr_sum_scale = r10; + reg64_t reg_reduce_loop_work = r11; + reg64_t reg_bias_data = r12; + reg64_t reg_comp_data = r12; + reg64_t reg_scratch = r13; + reg64_t aux_reg_bcast_data = r14; + reg64_t aux_reg_load_data = r15; + reg64_t imm_addr64 = r15; + reg64_t reg_reduce_pos_flag = rax; + reg64_t aux1_reg_bcast_data = rbx; + reg64_t reg_bcast_loop_work = rbx; + reg64_t bcast_loop_iter = rdx; // Note: Fix me + reg64_t reg_load_loop_work = rsi; + reg64_t aux_reg_output_data = abi_not_param1; + reg64_t reduce_loop_iter = abi_param1; + + reg64_t reg_last_load = r8; + mask_t ktail_mask = k6; + + mask_t vmask = k7; + + Xbyak::Zmm zmm_tmp = Xbyak::Zmm(28); + Xbyak::Zmm zmm_one = Xbyak::Zmm(29); + Xbyak::Zmm zmm_zero = Xbyak::Zmm(30); + Xbyak::Zmm zmm_bcast = Xbyak::Zmm(31); + Xbyak::Zmm zmm_shift = Xbyak::Zmm(30); + + Xbyak::Zmm zmm_bias_alpha = Xbyak::Zmm(31); + Xbyak::Xmm xmm_bias_alpha = Xbyak::Xmm(31); + + int bcast_loop_work_off = 0; + int reg_bias_data_off = 8; + int reg_bcast_data_off = 16; + int reg_load_data_off = 24; + int reg_ptr_sum_scale_off = 32; + int reg_comp_data_off = 40; + int stack_space_needed = 48; + + void bcast_loop(int load_loop_blk); + void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); + + void generate(); + void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op, + bool mask_flag); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp new file mode 100644 index 0000000000..0bf09fc677 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp @@ -0,0 +1,292 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +namespace { +template +void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, + T nx, T &nx_start, T &nx_end, T nx_divider) +{ + const T grp_size = utils::div_up(nthr, nx_divider); + const T grp_count = utils::div_up(nthr, grp_size); + + T grp = ithr / grp_size; + T grp_ithr = ithr % grp_size; + T grp_nthr = grp_size; + T first_grps = nthr % grp_count; + if (first_grps > 0 && grp >= first_grps) { + ithr -= first_grps * grp_size; + grp_nthr--; + grp = ithr / grp_nthr + first_grps; + grp_ithr = ithr % grp_nthr; + } + balance211(nx, grp_count, grp, nx_start, nx_end); + balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end); +} +} + +/* convolution forward */ +template +void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) { + auto local_scales = scratchpad.template get( + key_conv_adjusted_scales); + auto scales = pd()->attr()->output_scales_.scales_; + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, scales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = scales[c] * factor; + } + } + + parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad); + }); +} + +template +void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t +::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src, + const wei_data_t *weights, const char *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get(key_conv_rtus_space); + auto local_scales = scratchpad.get(key_conv_adjusted_scales); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + const int stride_h = pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[1]; + const int pad_t = pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][1]; + + const auto &oscales = pd()->attr()->output_scales_; + + int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block) + * jcp.oc_block * jcp.ic_block; + wei_data_t *w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast(w + offset) : 0; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto p = jit_1x1_conv_call_s(); + + auto rp = rtus_driver_t::call_params_t(); + const int nb_oc = jcp.nb_load; + const int os_block = jcp.bcast_block; + + int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end, + jcp.load_grp_count); + if (jcp.nb_load_chunk > 1) { + ocb_start *= jcp.nb_load_chunk; + ocb_end *= jcp.nb_load_chunk; + } + + auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, + int &oh, int &ow, int &ih, int &iw) + { + int osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + oh = os / jcp.ow; + ow = os % jcp.ow; + + ih = nstl::max(oh * stride_h - pad_t, 0); + iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + }; + + auto init_load = [&](int ocb, int &load_step) + { + load_step = step(jcp.nb_load_blocking, ocb_end - ocb, + jcp.nb_load_blocking_max); + p.load_dim = this_block_size(ocb * jcp.oc_block, + ocb_end * jcp.oc_block, load_step * jcp.oc_block); + + if (ocb + load_step >= nb_oc) + p.first_last_flag |= FLAG_OC_LAST; + else + p.first_last_flag &= ~FLAG_OC_LAST; + + }; + + auto init_reduce = [&]() + { + p.reduce_dim = this_block_size(0, jcp.ic, jcp.ic); + rp.icb = p.reduce_dim / jcp.reduce_block; + }; + + auto inner_ker = [&](int ocb, int n, int g, int oh, int ow, + int ih, int iw) + { + const int icb = 0; // Start from the first IC block + const int _ocb = g * nb_oc + ocb; + const int _icb = g; + + const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow); + + p.output_data = &dst[dst_off]; + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size]; + p.compensation = (jcp.signed_input) + ? &compensation[_ocb * jcp.oc_block] : 0; + p.scales = (jcp.signed_input && jcp.ver != ver_vnni) + ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block] + : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block]; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + if (ocb == ocb_start) { + rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw); + rtus_driver_->ker_(&rp); + } + p.bcast_data = rp.ws; + } else + p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw); + + kernel_->jit_ker(&p); + }; + + if (jcp.loop_order == loop_rlb) { + init_reduce(); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + inner_ker(ocb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_lbr) { + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + init_reduce(); + inner_ker(ocb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_rbl) { + init_reduce(); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + inner_ker(ocb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } else if (jcp.loop_order == loop_blr) { + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + init_reduce(); + inner_ker(ocb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } else { + assert(!"unsupported loop order"); + } +} + +using namespace data_type; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp new file mode 100644 index 0000000000..ad9027ac17 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp @@ -0,0 +1,159 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" +#include "jit_uni_1x1_conv_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8_1x1:", avx512_core, ""), + jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t< + src_type, dst_type>); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, data_type::s8, data_type::undef, + dst_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), format_tag::any, + dat_tag()) + && set_or_check_wei_format(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx512_core_x8s8s32x_1x1_conv_kernel:: + init_conf(jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), + *weights_md(1), *attr(), mkldnn_get_max_threads(), + rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( + scratchpad, jcp_, *attr()); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + bool set_or_check_wei_format() { + using namespace format_tag; + + const bool is_src_s8 = src_md_.data_type == data_type::s8; + format_tag_t wei_tag = with_groups() ? gOIhw4i16o4i : OIhw4i16o4i; + + memory_desc_t want_wei_md = weights_md_; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (is_src_s8) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups() ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md_.format_kind == format_kind::any) { + weights_md_ = want_wei_md; + return true; + } + + return weights_md_ == want_wei_md; + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx512_core_x8s8s32x_1x1_conv_kernel(pd()->jcp_, + *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + + private: + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src, const wei_data_t *weights, + const char *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_x8s8s32x_1x1_conv_kernel *kernel_; + rtus_driver_t *rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp new file mode 100644 index 0000000000..e89d068302 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp @@ -0,0 +1,140 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "primitive_iterator.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_deconvolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t + : public cpu_primitive_t { + struct pd_t : public cpu_deconvolution_fwd_pd_t { + pd_t(engine_t *engine, const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) {} + + pd_t(const pd_t &other) + : cpu_deconvolution_fwd_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), + jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t); + + status_t init_convolution() { + convolution_desc_t cd; + status_t status; + + auto dd = desc(); + status = conv_desc_init(&cd, prop_kind::forward_training, + alg_kind::convolution_direct, &(dd->src_desc), + &(dd->weights_desc), &(dd->bias_desc), &(dd->dst_desc), + dd->strides, dd->dilates, dd->padding[0], dd->padding[1], + dd->padding_kind); + + if (status == status::success) { + status = mkldnn_primitive_desc::create( + &conv_pd_, (op_desc_t *)&cd, &attr_, engine_, nullptr); + } + + if (status == status::success) + status = set_default_params(); + + return status; + }; + + status_t init() { + bool ok = true + && is_fwd() + && desc()->alg_kind == alg_kind::deconvolution_direct + && !has_zero_dim_memory() + && desc()->src_desc.data_type == src_type + && desc()->dst_desc.data_type == dst_type + && desc()->weights_desc.data_type == data_type::s8 + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && desc()->accum_data_type == data_type::s32; + if (!ok) return status::unimplemented; + + CHECK(init_convolution()); + + return status::success; + } + + virtual void init_scratchpad_md() override { + const auto conv_1x1_pd = static_cast(conv_pd_); + scratchpad_md_ = *conv_1x1_pd->scratchpad_md(); + } + + protected: + status_t set_default_params() { + auto conv_1x1_pd_ = static_cast(conv_pd_); + src_md_ = *conv_1x1_pd_->src_md(); + dst_md_ = *conv_1x1_pd_->dst_md(); + weights_md_ = *conv_1x1_pd_->weights_md(); + if (with_bias()) + bias_md_ = *conv_1x1_pd_->weights_md(1); + return status::success; + } + + using conv_pd_t = typename jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t + ::pd_t; + friend jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t; + primitive_desc_t *conv_pd_; + }; + + jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + + ~jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t() + { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + return conv_p_->execute(ctx); + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +} +} +} + +#endif /* CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp new file mode 100644 index 0000000000..10e98a00c4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp @@ -0,0 +1,1182 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { +void pick_loop_order(jit_conv_conf_t &jcp, int nthr) +{ + jcp.loop_order = loop_cwgn; + if (jcp.ngroups > 1) { + jcp.loop_order = loop_ngcw; + if (jcp.mb < nthr) + jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg; + } +} +} + +template +bool _jit_avx512_core_x8s8s32x_fwd_kernel::maybe_eltwise(int position) +{ + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + + return false; +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::prepare_output(int ur_w) +{ + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + for (int k = 0; k < nb_oc_block; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vpxord(vmm, vmm, vmm); + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + if (jcp.is_depthwise && !jcp.is_fast_depthwise) { + Reg32 _t32 = reg_scratch.cvt32(); + mov(_t32, (uint32_t)128); + vpbroadcastd(vmm_shift, _t32); + } else { + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)128); + vpbroadcastb(vmm_shift, _t8); + } + } +} + +template +const Vmm _jit_avx512_core_x8s8s32x_fwd_kernel:: + vmm_mask(const Vmm vmm_in, bool mask_flag, bool store) { + return vmm_in; +} + +template<> +const Zmm _jit_avx512_core_x8s8s32x_fwd_kernel:: + vmm_mask(const Zmm zmm_in, bool mask_flag, bool store) { + return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z) + : zmm_in; +} + + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::cvt2ps(data_type_t type_in, + const Vmm vmm_in, const Operand &op, bool mask_flag) { + //const Vmm vmm = mask_flag ? vmm_in | ktail_mask | T_z : vmm_in; + const Vmm vmm = vmm_mask(vmm_in, mask_flag); + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(vmm, op); break; + case data_type::s8: vpmovsxbd(vmm, op); break; + case data_type::u8: vpmovzxbd(vmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(vmm_in, vmm_in); +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_eltwise(int ur_w) { + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + if (ur_w == jcp.ur_w) + eltwise_injector_->compute_vector_range(0, nb_oc_block * jcp.ur_w); + else + for (int k = 0; k < nb_oc_block; k++) + eltwise_injector_->compute_vector_range(k * jcp.ur_w, + k * jcp.ur_w + ur_w); +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::store_output( + int ur_w, bool last_oc_block_flag) { + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block; + + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + if (jcp.signed_input) + mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); + + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = nullptr; + if (sum_idx != -1) { + const auto &p_entry = p.entry_[sum_idx]; + p_sum_scale = &p_entry.sum.scale; + } + + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + if (jcp.signed_input && jcp.ver != ver_vnni) { + /* put 'wei_adj_scale = 0.5' for bias calculation */ + mov(reg_bias_alpha, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_bias_alpha); + vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha()); + } + + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block); + if (jcp.with_bias) { + int bias_offset = jcp.typesize_bia * k * oc_block; + auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); + + cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + /* bias *= 0.5 */ + vmulps(vmm_bias, vmm_bias, vmm_bias_alpha()); + } + if (jcp.signed_input) { + int comp_offset = sizeof(int32_t) * k * oc_block; + auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset); + + cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag); + } + /* add to zmm_accum: compensation, bias and permute */ + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + if (jcp.is_fast_depthwise) + vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k)); + vcvtdq2ps(vmm, vmm); + if (jcp.signed_input) + vaddps(vmm, vmm, vmm_comp); + if (jcp.with_bias) + vaddps(vmm, vmm, vmm_bias); + + const Vmm vmm_k = vmm_mask(vmm, mask_flag); + vmulps(vmm_k, vmm, + EVEX_compress_addr(reg_ptr_scales, scale_offset)); + } + } + + /* Do post-ops */ + if (maybe_eltwise(0)) compute_eltwise(ur_w); + if (p_sum_scale) { // post_op: sum + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + for (int j = 0; j < ur_w; j++) { + int aux_output_offset + = jcp.typesize_out + * (k * oc_block + + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_out, aux_output_offset); + Vmm vmm = vmm_out(j, k); + cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag); + if (*p_sum_scale == 1.f) + vaddps(vmm, vmm_prev_dst); + else + vfmadd231ps(vmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + if (maybe_eltwise(1)) compute_eltwise(ur_w); + + /* write out register to output_addr */ + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + if (jcp.dst_dt == data_type::u8) { + vpxord(vmm_zero, vmm_zero, vmm_zero); + vmaxps(vmm, vmm_zero, vmm); + } + + if (jcp.dst_dt != data_type::f32) { + /* Note: using Zmm for rounding in Xmm/Ymm kernel + because there is no instruction to do rounding + from Xmm/Ymm -> Xmm/Ymm. + Embedded rounding is not supported for Xmm. + TODO: maybe avoid Zmm if it helps performance.*/ + Zmm zmm = zmm_out(j, k); + vcvtps2dq(zmm, zmm); + } + } + + for (int j = 0; j < ur_w; j++) { + int aux_output_offset = jcp.typesize_out + * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_out, aux_output_offset); + + Vmm vmm = vmm_out(j, k); + const Vmm r_vmm = vmm_mask(vmm, mask_flag, true); + + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: vmovups(addr, r_vmm); break; + case data_type::s8: vpmovsdb(addr, r_vmm); break; + case data_type::u8: vpmovusdb(addr, r_vmm); break; + default: assert(!"unknown dst_dt"); + } + } + } + +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + assert(!"invalid group blocking for depthwise convolution"); +} + +template <> +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + + auto input_spatial_index = [=](int oi, int ki) { + return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l); + }; + + auto input_offset2 = [=](int ii, int ci) { + return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block); + }; + + auto input_offset3 = [=](int oi, int ci, int ki) { + return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci); + }; + + auto kernel_offset = [=](int ci, int ki) { + return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block); + }; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + // okay for depthwise since src is zero-extended + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddwd(zmm_tmp, vreg_src, vreg_wei); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + int ii_start = 0; + int ii_end = -1; + if (jcp.is_resrc_depthwise && !h_padded) { + // find bounds of input spatial indices + bool first = true; + for (int ki = 0; ki < jcp.kw; ki++) { + int oi_start = get_ow_start(ki, pad_l); + int oi_end = get_ow_end(ur_w, ki, pad_r); + for (int oi = oi_start; oi < oi_end; oi++) { + int ii = input_spatial_index(oi, ki); + if (first || ii < ii_start) + ii_start = ii; + if (first || ii > ii_end) + ii_end = ii; + first = false; + } + } + } + + if (jcp.signed_input) { + vpxord(zmm_shifted_zero, zmm_shifted_zero, zmm_shifted_zero); + vpaddb(zmm_shifted_zero, zmm_shifted_zero, vmm_shift); + } + for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) { + const bool mask_flag = last_ic_block_flag != no_last_block + && ci == jcp.nb_ch_blocking - 1; + if (jcp.is_resrc_depthwise && !h_padded) { + // now we can load input once and reuse up to jcp.kw times + for (int ii = ii_start; ii <= ii_end; ii++) { + int aux_input_offset = input_offset2(ii, ci); + const Zmm zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking); + const Zmm zmm_inp_msk = mask_flag + ? zmm_inp_tmp | ktail_mask | T_z + : zmm_inp_tmp; + if (jcp.is_fast_depthwise) { + assert(!mask_flag); + vbroadcasti32x4(zmm_inp_msk, + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + } else { + vpmovzxbd(zmm_inp_msk, + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(zmm_inp_tmp, zmm_inp_tmp, vmm_shift); + } + } + for (int ki = 0; ki < jcp.kw; ki++) { + int aux_kernel_offset = kernel_offset(ci, ki); + if (jcp.is_fast_depthwise) { + vbroadcasti32x4(zmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + vmovdqu8(zmm_wei | kblend_mask | T_z, zmm_wei); + } else { + vpmovsxbd(zmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + if (h_padded) { + assert(jcp.signed_input); + for (int oi = 0; oi < ur_w; oi++) + compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero); + } else { + const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src; + int oi_start = get_ow_start(ki, pad_l); + int oi_end = get_ow_end(ur_w, ki, pad_r); + int start_ = jcp.signed_input ? 0 : oi_start; + int end_ = jcp.signed_input ? ur_w : oi_end; + for (int oi = start_; oi < end_; oi++) { + if (oi >= oi_start && oi < oi_end) { + if (jcp.is_resrc_depthwise) { + int ii = input_spatial_index(oi, ki); + zmm_src = zmm_inp(ii, jcp.nb_ch_blocking); + } else { + int aux_input_offset = input_offset3(oi, ci, ki); + if (jcp.is_fast_depthwise) { + assert(!mask_flag); + vbroadcasti32x4(r_zmm_src, + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + } else { + vpmovzxbd(r_zmm_src, + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(zmm_src, zmm_src, vmm_shift); + } + } else if (jcp.signed_input) { + zmm_src = zmm_shifted_zero; + } + compute(zmm_out(oi, ci), zmm_wei, zmm_src); + } + } + } + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l, + int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + if (jcp.is_depthwise) + return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded); + + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int ch_block_all = jcp.ch_block * ic_block * oc_block; + + int nb_oc_block = jcp.nb_oc_blocking; + + auto input_offset = [=](int oi, int ic, int ki) { + return jcp.typesize_in + * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) + * jcp.ic_without_padding * jcp.ngroups + 4 * ic); + }; + auto kernel_offset = [=](int ii, int ic, int ki) { + return jcp.typesize_in + * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all + + 4 * ic * oc_block); + }; + auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(vmm_tmp, vreg_src, vreg_wei); + vpmaddwd(vmm_tmp, vmm_tmp, vmm_one); + vpaddd(vreg_acc, vreg_acc, vmm_tmp); + } + }; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_ow_start(ki, pad_l); + int jj_end = get_ow_end(ur_w, ki, pad_r); + int tail_size = jcp.ic_without_padding % 4; + int _start = (jcp.signed_input) ? 0 : jj_start; + int _end = (jcp.signed_input) ? ur_w : jj_end; + /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */ + int icb = (last_ic_block_flag != no_last_block) + ? div_up((jcp.ic_without_padding % ic_block), 4) + : ic_block / 4; + for (int ic = 0; ic < icb; ic++) { + if (h_padded == true) { + /* fill padded area with shifted values */ + Vmm inp = vmm_inp(0,nb_oc_block); + vpxord(inp, inp, inp); + vpaddb(inp, inp, vmm_shift); + } else { + for (int jj = _start; jj < _end; jj++) { + int aux_input_offset = input_offset(jj, ic, ki); + if (jj >= jj_start && jj < jj_end) { + if (last_ic_block_flag == last_sp_block + && tail_size != 0 && ic == icb - 1) { + Xmm xmm_tmp = Xmm(vmm_inp(jj, nb_oc_block).getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_tmp, xmm_tmp, + ptr[aux_reg_inp + aux_input_offset + r], r); + vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp); + } else { + vpbroadcastd(vmm_inp(jj, nb_oc_block), + EVEX_compress_addr( + aux_reg_inp, aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(vmm_inp(jj, nb_oc_block), + vmm_inp(jj, nb_oc_block), vmm_shift); + } else { + /* fill padded area with shifted values */ + if (jcp.signed_input) { + Vmm inp = vmm_inp(jj, nb_oc_block); + vpxord(inp, inp, inp); + vpaddb(inp, inp, vmm_shift); + } + } + } + } + for (int ii = 0; ii < nb_oc_block; ii++) { + int aux_kernel_offset = kernel_offset(ii, ic, ki); + vmovups(vmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + for (int jj = _start; jj < _end; jj++) { + Vmm inp = (h_padded == true) + ? vmm_inp(0,nb_oc_block) : vmm_inp(jj, nb_oc_block); + compute(vmm_out(jj, ii), vmm_wei, inp); + } + } + } + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::kh_loop( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) { + Label kh_label, skip_kh_loop; + Label t_overflow_label, no_t_overflow_label, + b_overflow_label, no_b_overflow_label; + + int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all; + int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * jcp.ic_without_padding * jcp.ngroups; + + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); + cmp(reg_overflow, 0); + je(no_t_overflow_label, T_NEAR); + L(t_overflow_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); + + add(aux_reg_ker, shift_kernel_ptr); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(t_overflow_label, T_NEAR); + } + L(no_t_overflow_label); + } + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + if ((jcp.signed_input) || (!jcp.signed_input && + (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) { + cmp(reg_kj, 0); + je(skip_kh_loop, T_NEAR); + } + L(kh_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false); + + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + L(skip_kh_loop); + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); + cmp(reg_overflow, 0); + je(no_b_overflow_label, T_NEAR); + L(b_overflow_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); + + add(aux_reg_ker, shift_kernel_ptr); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(b_overflow_label, T_NEAR); + } + L(no_b_overflow_label); + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::icb_loop( + int ur_w, int pad_l, int pad_r, bool is_last_sp_block) +{ + prepare_output(ur_w); + + // IC loop + Label icb_label; + mov(reg_icb, jcp.nb_ic); + L(icb_label); + if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) { + Label common_ker, end_ker; + + cmp(reg_icb, 1); // The last IC block + jne(common_ker, T_NEAR); + + kh_loop(ur_w, pad_l, pad_r, + is_last_sp_block ? last_sp_block : last_ic_block); + jmp(end_ker, T_NEAR); + + L(common_ker); + kh_loop(ur_w, pad_l, pad_r, no_last_block); + + L(end_ker); + } else { + kh_loop(ur_w, pad_l, pad_r, no_last_block); + } + // End of IC Loop + int inp_step = jcp.ic_block; + int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block; + add(reg_inp, jcp.typesize_in * inp_step); + add(reg_ker, jcp.typesize_in * ker_step); + + dec(reg_icb); + cmp(reg_icb, 0); + jg(icb_label, T_NEAR); + + sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic); + sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + Label common_store, end_store; + + if (jcp.is_depthwise) + cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking); + else + cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); + + jne(common_store, T_NEAR); + + store_output(ur_w, true); // last oc block + jmp(end_store, T_NEAR); + + L(common_store); + store_output(ur_w, false); + + L(end_store); + } else { + store_output(ur_w, false); + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::generate() +{ + Label permute_index_table; + int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad) + * jcp.ic_without_padding * jcp.ngroups; + int inp_shift_pad_second_block = -1 * jcp.typesize_in * jcp.l_pad + * jcp.ic_without_padding * jcp.ngroups; + int inp_shift = jcp.typesize_in * + (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding + * jcp.ngroups); + int out_shift = jcp.typesize_out * + (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups); + preamble(); + + if (jcp.is_depthwise) { + int idx = jcp.max_regs_ur - 1; + if (!jcp.is_resrc_depthwise) + zmm_src = Zmm(++idx); + if (jcp.ver != ver_vnni) + zmm_tmp = Zmm(++idx); + if (jcp.is_fast_depthwise) + zmm_permute = Zmm(++idx); + if (jcp.signed_input) { + zmm_shifted_zero = Zmm(++idx); + ++idx; // due to extra register used for shifts and compensations + } + assert(idx == ker_dw_reg_base_idx); + } + + if (!jcp.is_depthwise && jcp.ver != ver_vnni) { + xor_(reg_scratch, reg_scratch); + Reg16 _t16 = reg_scratch.cvt16(); + mov(_t16, 0x1); + vpbroadcastw(vmm_one, _t16); + } + + mov(reg_inp, ptr[param1 + GET_OFF(src)]); + mov(reg_out, ptr[param1 + GET_OFF(dst)]); + mov(reg_ker, ptr[param1 + GET_OFF(filt)]); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.is_depthwise + ? jcp.ngroups % jcp.ch_block + : jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); + Reg32 regw_tmp = reg_oi.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + if (jcp.is_fast_depthwise) { + // prepare mask register for blending weights + mov(reg_scratch, 0x8888444422221111); + kmovq(kblend_mask, reg_scratch); + // load permute indices from data section + mov(reg_scratch, permute_index_table); + vmovdqu32(zmm_permute, ptr[reg_scratch]); + } + + int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + int n_oi = jcp.ow / jcp.ur_w; + int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); + + if (jcp.nb_ow == 1) { + if (r_pad1 > 0 || jcp.ur_w_tail == 0) + n_oi--; + + xor_(reg_oi, reg_oi); + if (jcp.ow == jcp.ur_w) { + icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true); + } else { + if (n_oi == 0) { + icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + } else { + if (jcp.l_pad > 0) { + icb_loop(jcp.ur_w, jcp.l_pad, 0, false); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + + inc(reg_oi); + } + if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)) + { + Label ow_loop_label; + L(ow_loop_label); { + icb_loop(jcp.ur_w, 0, 0, false); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_pad1 > 0 || jcp.ur_w_tail == 0) { + icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + } + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + } + } + } else { + // ow block is only processed. + // Number of block is passed as parameter owb, + // and padding processing depends on this number. + Label end_label, last_oi_label, middle_ow_blocks_label, tail_label, + oi_loop_label, oi_loop_end_label; + + assert(jcp.ow_block % jcp.ur_w == 0); + int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w; + // to simplify code (and general regs usage), + // size of ow block must be >= 2 * ur_w + assert(n_oi_not_last_ow_block > 1); + int n_oi_next_last_ow_block = n_oi_not_last_ow_block; + int n_oi_first_ow_block = n_oi_not_last_ow_block; + int n_oi_last_ow_block + = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w; + // prepare right padding + bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0; + bool first_ow_block_padded + = next_last_ow_block_padded && jcp.nb_ow == 2; + bool last_ow_block_padded + = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0; + + if (last_ow_block_padded) n_oi_last_ow_block--; + else if (first_ow_block_padded) n_oi_first_ow_block--; + else if (next_last_ow_block_padded) n_oi_next_last_ow_block--; + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // is that the first ow-block ? + jg(middle_ow_blocks_label, T_NEAR); + + // the first ow block, compute left padding + mov(reg_oi, n_oi_first_ow_block); + if (jcp.l_pad > 0) { + icb_loop(jcp.ur_w, jcp.l_pad, 0, false); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + + dec(reg_oi); + } + jmp(oi_loop_label, T_NEAR); + + // middle or last ow block entry + L(middle_ow_blocks_label); + + if (jcp.l_pad > 0) { + // just to consider left padding, not compute + add(reg_inp, inp_shift_pad_second_block); + } + + // set number of iteration for oi-loop + if (n_oi_last_ow_block != n_oi_not_last_ow_block) { + cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ? + mov(reg_oi, n_oi_last_ow_block); + je(oi_loop_label, T_NEAR); + } + + if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) { + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + + mov(reg_oi, n_oi_next_last_ow_block); + je(oi_loop_label, T_NEAR); + } + mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks + + // oi loop w/o padding + L(oi_loop_label); { + cmp(reg_oi, 0); + jle(oi_loop_end_label, T_NEAR); + + icb_loop(jcp.ur_w, 0, 0, false); + + add(reg_inp, inp_shift); + add(reg_out, out_shift); + dec(reg_oi); + + jmp(oi_loop_label, T_NEAR); + } + L(oi_loop_end_label); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // first ow-block ? + if (first_ow_block_padded) + je(last_oi_label, T_NEAR); + else + je(end_label, T_NEAR); + + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + jl(end_label, T_NEAR); + if (next_last_ow_block_padded) + je(last_oi_label, T_NEAR); + else + je(end_label, T_NEAR); + + // that is last block + if (!last_ow_block_padded) + jmp(tail_label, T_NEAR); + + // last oi block with right padding + L(last_oi_label); + icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? + jl(end_label, T_NEAR); + + // ur_w tail + L(tail_label); + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + L(end_label); + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); + + if (jcp.is_fast_depthwise) { + align(64); + L(permute_index_table); + const uint32_t _idx[] + = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i) + dd(_idx[i]); + } +} + +bool jit_avx512_core_x8s8s32x_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) +{ + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_eltwise(1)) || + (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const primitive_attr_t &attr, + int nthreads) +{ + using namespace prop_kind; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + bool is_1d = ndims == 3; + + if (!(mayiuse(avx512_core) + && one_of(src_d.data_type(), data_type::u8, data_type::s8) + && weights_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + + jcp = zero(); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = jcp.ic; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false; + jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc); + + if (jcp.is_depthwise) { + jcp.ch_block = 16; + jcp.ic_block = 1; + jcp.oc_block = 1; + } else { + jcp.ch_block = 1; + jcp.ic_block = 16; + jcp.oc_block = 16; + + if (jcp.ngroups == 1) { + /* For non grouped convolutions, pad channels by 16 if needed */ + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } else if (!is_1d && jcp.ngroups != 1 && jcp.ic % jcp.ic_block != 0) { + /* For grouped convolutions, MKL-DNN doesn't support padding. + Use Ymm when channels per group is multiple of 8, + Xmm when channels per group is multiple of 4 */ + jcp.ic_block = jcp.ic % 8 == 0 ? 8 : 4; + jcp.oc_block = jcp.ic_block; + } + if (jcp.ic % jcp.ic_block !=0 || jcp.oc % jcp.oc_block != 0) + return status::unimplemented; + } + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.ver = mayiuse(avx512_core_vnni) ? ver_vnni : ver_avx512_core; + jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.ver == ver_vnni + && jcp.ngroups % jcp.ch_block == 0; // for groups not multiple of 16 + // would require byte masking + // for load from src + jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw + && jcp.kw < 4 && jcp.dilate_w == 0; + if (jcp.is_depthwise) { + jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise + - 2 * jcp.signed_input - (jcp.ver != ver_vnni); + } else { + jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28; + } + + auto set_or_check_wei_format = [&]() { + using namespace format_tag; + format_tag_t wei_tag; + if (jcp.ic_block == 16 || jcp.ch_block == 16) { + if (is_1d) { + wei_tag = with_groups + ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i + : OIw4i16o4i; + } else { + wei_tag = with_groups + ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i + : OIhw4i16o4i; + } + } else if (with_groups && jcp.ic_block == 8) { + wei_tag = gOIhw2i8o4i; + } else + wei_tag = gOIhw4o4i; + + memory_desc_t want_wei_md = weights_md; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (jcp.signed_input) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md.format_kind == format_kind::any) { + weights_md = want_wei_md; + return true; + } + + return weights_md == want_wei_md; + }; + + if (!set_or_check_wei_format()) + return status::unimplemented; + + format_tag_t dat_tag = utils::pick(ndims - 3, + format_tag::nwc, format_tag::nhwc); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, dat_tag)); + jcp.src_tag = dat_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + } + if (jcp.src_tag != dat_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); + jcp.dst_tag = dat_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + } + if (jcp.dst_tag != dat_tag) + return status::unimplemented; + + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + } + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_bia = jcp.with_bias + ? types::data_type_size(bias_d.data_type()) + : 0; + + jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + // Try to use 4 channel-groups at a time to avoid false sharing (depthwise) + int nb_ch_blocking = 4; + for ( /* init above */ ; nb_ch_blocking > 1; nb_ch_blocking--) + if (jcp.nb_ch % nb_ch_blocking == 0) + break; + jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1; + + // If OC blocking is incommensurate with the number of OC blocks (general + // requirement for all convolutions), or if it results in an unrolling + // factor smaller than the left padding (special requirement for SSD:fc6), + // then search for a smaller OC blocking that satisfies both constraints. + auto is_oc_blocking_ok = [&](int block) { + int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1)); + return jcp.nb_oc % block == 0 + && jcp.l_pad <= ur_w && jcp.ow % ur_w != 1; + }; + + // choose nb_oc work chunk size for distribution within threads + int max_threading_nb_oc_chunk = 4; + // Performance improvements for googlenet_v3 and resnet_50 with mb = 1; + // TODO: generalize this condition and rewrite it in appropriate manner + if (jcp.ver == ver_vnni && jcp.mb == 1 && jcp.kh == 3 && jcp.kw == 3 + && jcp.stride_w == 1 && jcp.ic % 64 == 0) + max_threading_nb_oc_chunk = 2; + jcp.nb_oc_blocking_thr_chunk = + nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc); + for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) { + if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk)) + break; + } + + // choose oc blocking for computational kernel + jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk; + // Performance improvements for googlenet_v3 with mb = 1; + // TODO: generalize this condition and rewrite it in appropriate manner + const int size_treshold_for_nb_oc_blocking_reduction = 17; + if (jcp.mb == 1 && jcp.ow <= size_treshold_for_nb_oc_blocking_reduction + && jcp.stride_w == 1 + && !(jcp.kh == 1 && jcp.kw == 3) + && !(jcp.kh >= 7 && jcp.oc % 64 == 0)) { + const int max_nb_oc_blocking = 2; + jcp.nb_oc_blocking = nstl::min(max_nb_oc_blocking, jcp.nb_oc); + for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) + if (jcp.nb_oc_blocking_thr_chunk % jcp.nb_oc_blocking == 0 + && is_oc_blocking_ok(jcp.nb_oc_blocking)) + break; + } + + if (jcp.is_resrc_depthwise) + jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w) + / (jcp.nb_ch_blocking + jcp.stride_w); + else + jcp.ur_w + = jcp.max_regs_ur / (jcp.is_depthwise ? jcp.nb_ch_blocking + : jcp.nb_oc_blocking + 1); + if (jcp.ow < jcp.ur_w) + jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + jcp.ow_block = jcp.ow; + int base_work_amount = jcp.mb * jcp.nb_ch * jcp.oh + * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk); + float best_thr_eff + = (float)base_work_amount / rnd_up(base_work_amount, nthreads); + int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w); + for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) { + int ow_block + = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow); + if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block + && best_thr_eff > 0.8f) + break; + if (div_up(jcp.ow, ow_block) != nb_ow) + continue; + auto work_amount = base_work_amount * nb_ow; + float thr_eff = (float)work_amount / rnd_up(work_amount, nthreads); + if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) { + jcp.ow_block = ow_block; + best_thr_eff = thr_eff; + } + if (best_thr_eff > 0.9f) + break; + } + jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); + + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0); + if (!args_ok) + return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + if (r_pad_no_tail > jcp.ur_w) + return status::unimplemented; + + pick_loop_order(jcp, nthreads); + + jcp.nb_ic_L2 = jcp.nb_ic; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? weights_d.extra().scale_adjust : 1.f; + + return status::success; +} + +void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, + const primitive_attr_t &attr) { + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max(attr.output_scales_.count_, (dim_t)jcp.ic_block); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +template struct _jit_avx512_core_x8s8s32x_fwd_kernel; +template struct _jit_avx512_core_x8s8s32x_fwd_kernel; +template struct _jit_avx512_core_x8s8s32x_fwd_kernel; +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp new file mode 100644 index 0000000000..d8a05ad53e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp @@ -0,0 +1,239 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t) + + enum { STATE_FIRST_DST_LOAD = 0x1U }; + + _jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : jcp(ajcp), attr_(attr), + eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + generate(); + jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); + } + + ~_jit_avx512_core_x8s8s32x_fwd_kernel() { + delete eltwise_injector_; + } + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker_)(jit_conv_call_s *); + +private: + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + ker_dw_reg_base_idx = 30, + }; + typedef enum { + no_last_block, + last_ic_block, + last_sp_block, + } ic_block_t; + + /* data regs */ + const Xbyak::Reg64 reg_ptr_scales = rax; + const Xbyak::Reg64 reg_inp = r8; + const Xbyak::Reg64 reg_ker = r9; + const Xbyak::Reg64 reg_out = r10; + const Xbyak::Reg64 aux_reg_inp = r11; + const Xbyak::Reg64 reg_ptr_sum_scale = r11; + const Xbyak::Reg64 aux_reg_ker = r12; + const Xbyak::Reg64 reg_compensation = r14; + /* counter regs */ + const Xbyak::Reg64 reg_bias_alpha = abi_not_param1; + const Xbyak::Reg64 reg_oi = rbx; + const Xbyak::Reg64 reg_bias = rdx; + const Xbyak::Reg64 reg_oc_blocks = rsi; + const Xbyak::Reg64 reg_owb = aux_reg_ker; + const Xbyak::Reg64 reg_scratch = reg_compensation; + const Xbyak::Reg64 reg_kj = reg_ptr_scales; + const Xbyak::Reg64 reg_overflow = reg_ptr_scales; + const Xbyak::Reg64 reg_icb = reg_bias; + + const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); + const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3); + + const Vmm vmm_wei = Vmm(31); + /* used during bias section of store_output */ + const Vmm vmm_comp = Vmm(30); // only for signed input + const Vmm vmm_bias = Vmm(31); + /* used during post_op sum section of store_output */ + const Vmm vmm_prev_dst = Vmm(31); + /* used during write-out section of store_output */ + const Vmm vmm_zero = Vmm(31); + + /* used in compute_ker (but set during prepare_output) */ + const Vmm vmm_shift = vmm_comp; // only for signed input + /* used in compute_ker (but only for pre-VNNI machines) */ + const Vmm vmm_tmp = Vmm(28); // not used for depthwise + const Vmm vmm_one = Vmm(29); // set at start of kernel, not used for depthwise. + + /* registers use only for depthwise + groups are always blocked by 16(padded if needed), + hence use only Zmm registers */ + const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); + Xbyak::Zmm zmm_tmp; + Xbyak::Zmm zmm_src; + Xbyak::Zmm zmm_shifted_zero; + Xbyak::Zmm zmm_permute; + + Vmm vmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < (jcp.is_depthwise + ? ker_dw_reg_base_idx : ker_reg_base_idx)); + return Vmm(idx); + } + Xbyak::Zmm zmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < (jcp.is_depthwise + ? ker_dw_reg_base_idx : ker_reg_base_idx)); + return Xbyak::Zmm(idx); + } + Vmm vmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Vmm(idx); + } + Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Xbyak::Zmm(idx); + } + Vmm vmm_bias_alpha() { + int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + return Vmm(nb_c_block * jcp.ur_w); + } + Xbyak::Xmm xmm_bias_alpha() { + int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + return Xbyak::Xmm(nb_c_block * jcp.ur_w); + } + int get_ow_start(int ki, int pad_l) { + return nstl::max(0, + utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); + } + int get_ow_end(int ur_w, int ki, int pad_r) { + return ur_w - nstl::max(0, utils::div_up(pad_r + - (jcp.kw - 1 - ki) + * (jcp.dilate_w + 1), + jcp.stride_w)); + } + + bool maybe_eltwise(int position); + void prepare_output(int ur_w); + void store_output(int ur_w, bool last_oc_block_flag); + void compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded); + void compute_ker(int ur_w, int pad_l, int pad_r, + ic_block_t last_ic_block_flag, bool h_padded = false); + void compute_eltwise(int ur_w); + void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag); + void icb_loop( + int ur_w, int pad_l, int pad_r, bool is_last_spatial_block); + void generate(); + void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op, + bool mask_flag); + const Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false); +}; + +struct jit_avx512_core_x8s8s32x_fwd_kernel { + + jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : + jit_ker(nullptr), + zmm_kernel_(nullptr), + ymm_kernel_(nullptr), + xmm_kernel_(nullptr) { + int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block; + switch (ch_block) { + case 16: + zmm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel( + ajcp, attr); + jit_ker = zmm_kernel_->jit_ker_; + return; + case 8: + ymm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel( + ajcp, attr); + jit_ker = ymm_kernel_->jit_ker_; + return; + case 4: + xmm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel( + ajcp, attr); + jit_ker = xmm_kernel_->jit_ker_; + return; + default: + assert(!"invalid channel blocking"); + } + } + + ~jit_avx512_core_x8s8s32x_fwd_kernel() { + delete xmm_kernel_; + delete ymm_kernel_; + delete zmm_kernel_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_pd, + memory_desc_t &weights_pd, + memory_desc_t &dst_pd, + memory_desc_t &bias_pd, + const primitive_attr_t &attr, + int nthreads); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + void (*jit_ker)(jit_conv_call_s *); + _jit_avx512_core_x8s8s32x_fwd_kernel *zmm_kernel_; + _jit_avx512_core_x8s8s32x_fwd_kernel *ymm_kernel_; + _jit_avx512_core_x8s8s32x_fwd_kernel *xmm_kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp new file mode 100644 index 0000000000..cdbf333d5e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp @@ -0,0 +1,423 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_x8s8s32x_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace nstl; + +using jit_conv_ker_t = void (*)(jit_conv_call_s *); + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? (d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +template +void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast(&w[offset]) : 0; + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; + int group_block = jcp.ch_block; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow; + + parallel(0, [&](const int ithr, const int nthr) { + + int start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_conv_call_s(); + + int n{ 0 }, gg{ 0 }, occ{ 0 }, owb{ 0 }; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg, + nb_groups, n, jcp.mb); + break; + case loop_gncw: + nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks, + owb, jcp.nb_ow); + break; + case loop_ngcw: + nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks, + owb, jcp.nb_ow); + break; + case loop_nwcg: + nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, + gg, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int gb = gg * jcp.nb_ch_blocking; + int g = gb * group_block; + int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.nb_ic * jcp.ic_block; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) : 0; + p.compensation = (jcp.signed_input) ? compensation + g_oc : 0; + p.dst = dst + dst_d.blk_off(n, g_oc, ow_s); + p.src = src + src_d.blk_off(n, g_ic, iw_s); + p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0); + p.scales = &oscales[jcp.is_oc_scale * g_oc]; + p.oc_blocks = jcp.is_depthwise ? gb : ocb; + p.kh_padding = jcp.kh; + p.t_overflow = 0; + p.b_overflow = 0; + p.owb = owb; + + kernel_->jit_ker(&p); + + ++start; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg, nb_groups, + n, jcp.mb); + break; + case loop_gncw: + nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks, owb, + jcp.nb_ow); + break; + case loop_ngcw: + nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks, owb, + jcp.nb_ow); + break; + case loop_nwcg: + nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg, + nb_groups); + break; + default: assert(!"unsupported loop order"); + } + } + }); +} + +template +void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.ch_block == 1); + assert(jcp.nb_ch_blocking == 1); + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast(&w[offset]) : 0; + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk; + int nb_groups = jcp.nb_ch; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow; + + parallel(0, [&](const int ithr, const int nthr) { + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_conv_call_s(); + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 }; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g, + nb_groups, n, jcp.mb, oh_s, jcp.oh); + break; + case loop_ngcw: + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, + owb, jcp.nb_ow, oh_s, jcp.oh); + break; + case loop_nhwcg: + nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, + occ, oc_chunks, g, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + while (start < end) { + for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk; + occ1 += jcp.nb_oc_blocking) { + int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1; + int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; + + int g_ic = g * jcp.nb_ic * jcp.ic_block; + + int work_rem = end - start; + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + if (jcp.loop_order == loop_nhwcg) + oh_e = oh_s + 1; // step instead + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + auto bias_w = bias + ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) + : 0; + int32_t *compensation_w = (jcp.signed_input) + ? compensation + g_oc : 0; + + auto dst_w = dst + dst_d.blk_off(n, g_oc, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); + + auto scales = &oscales[jcp.is_oc_scale * g_oc]; + + for (int oj = oh_s, ij = ih_s; oj < oh_e; + ++oj, ij += jcp.stride_h) { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = nstl::min(jcp.kh, + div_up(max(0, -ij), dilate_h)); + int i_b_overflow = nstl::min(jcp.kh, div_up( + max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1), + dilate_h)); + int kh_padding = nstl::max(0, + jcp.kh - i_t_overflow - i_b_overflow); + + size_t wei_stride = (!jcp.signed_input) + ? i_t_overflow * wht_h_stride : 0; + p.src = src_w + i_t_overflow * dilate_h * src_h_stride; + p.dst = dst_w; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.oc_blocks = ocb; + p.kh_padding = kh_padding; + p.scales = scales; + p.t_overflow = i_t_overflow; + p.b_overflow = i_b_overflow; + p.owb = owb; + + kernel_->jit_ker(&p); + src_w += src_h_stride * jcp.stride_h; + dst_w += dst_h_stride; + } + } + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, g, + nb_groups, n, jcp.mb, oh_s, jcp.oh); + break; + case loop_ngcw: + nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, + oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + break; + case loop_nhwcg: + ++start; + nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ, + oc_chunks, g, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + } + }); +} + +template +void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.ic_block == 1); + assert(jcp.oc_block == 1); + assert(jcp.nb_ic == 1); + assert(jcp.nb_oc == 1); + assert(jcp.nb_oc_blocking == 1); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast(&w[offset]) : 0; + int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; + int group_block = jcp.ch_block; + + parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups, + [&](int n, int oh_s, int owb, int gg) { + + auto p = jit_conv_call_s(); + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + int gb = gg * jcp.nb_ch_blocking; + int g = gb * group_block; + + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0; + int32_t *compensation_w = jcp.signed_input ? compensation + g : 0; + + auto dst_w = dst + dst_d.blk_off(n, g, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, gb, 0); + + auto scales = &oscales[jcp.is_oc_scale * g]; + + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h)); + int i_b_overflow = nstl::min(jcp.kh, + div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1), + dilate_h)); + int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow); + + size_t wei_stride = jcp.signed_input ? 0 : i_t_overflow * wht_h_stride; + p.src = src_w + i_t_overflow * dilate_h * src_h_stride; + p.dst = dst_w; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.oc_blocks = gb; + p.kh_padding = kh_padding; + p.scales = scales; + p.t_overflow = i_t_overflow; + p.b_overflow = i_b_overflow; + p.owb = owb; + + kernel_->jit_ker(&p); + }); +} + +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::u8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::u8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::s8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::s8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::s32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::s32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::f32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::f32>; +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp new file mode 100644 index 0000000000..203ebdf942 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8:", avx512_core, ""), + jit_avx512_core_x8s8s32x_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, data_type::s8, data_type::undef, + dst_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of(bias_md_.data_type, + data_type::f32, data_type::s32, data_type::s8, + data_type::u8)) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_x8s8s32x_fwd_kernel::init_conf( + jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, + *attr(), mkldnn_get_max_threads()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(scratchpad, + jcp_, *attr()); + + return status; + } + + jit_conv_conf_t jcp_; + }; + + jit_avx512_core_x8s8s32x_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_core_x8s8s32x_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + + ~jit_avx512_core_x8s8s32x_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + const auto &_pd = pd(); + if (_pd->ndims() == 3) + execute_forward_1d(ctx); + else if (_pd->jcp_.is_depthwise) + execute_forward_2d_dw(ctx); + else + execute_forward_2d(ctx); + return status::success; + } + +private: + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + void execute_forward_2d_dw(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_x8s8s32x_fwd_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp new file mode 100644 index 0000000000..142af1f541 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp @@ -0,0 +1,1034 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_avx512_core_x8s8s32x_deconvolution.hpp" + +#define GET_OFF(field) offsetof(jit_deconv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +using namespace nstl; + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) : \ + (d).blk_off(__VA_ARGS__)) + +status_t jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( + jit_conv_conf_t &jcp, const deconvolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, const bool with_bias, + memory_desc_t &bias_md, const primitive_attr_t &attr) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper bias_d(&bias_md); + + if (!(mayiuse(avx512_core) + && one_of(src_d.data_type(), data_type::u8, data_type::s8) + && weights_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + + jcp = zero(); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + jcp.signed_input = src_d.data_type() == data_type::s8; + const int ndims = jcp.ndims = dst_d.ndims(); + const bool is_1d = ndims == 3; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; + jcp.is_depthwise = true && with_groups + && utils::everyone_is(1, jcp.ic_without_padding, + jcp.oc_without_padding); + + /* TODO: future work, on hold until depthwise specialized kernel is + * implemented. */ + if (jcp.is_depthwise && jcp.signed_input) + return status::unimplemented; + + format_tag_t dat_tag = utils::pick(ndims - 3, + format_tag::nwc, format_tag::nhwc); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, dat_tag)); + jcp.src_tag = dat_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + } + if (jcp.src_tag != dat_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); + jcp.dst_tag = dat_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + } + if (jcp.dst_tag != dat_tag) + return status::unimplemented; + + auto set_or_check_wei_format = [&]() { + using namespace format_tag; + + format_tag_t wei_tag = is_1d + ? (jcp.is_depthwise + ? Goiw16g : (with_groups ? gOIw4i16o4i : OIw4i16o4i)) + : (jcp.is_depthwise + ? Goihw16g : (with_groups ? gOIhw4i16o4i : OIhw4i16o4i)); + + memory_desc_t want_wei_md = weights_md; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (jcp.signed_input && !jcp.is_depthwise) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md.format_kind == format_kind::any) { + weights_md = want_wei_md; + return true; + } + + return weights_md == want_wei_md; + }; + + if (!set_or_check_wei_format()) + return status::unimplemented; + + jcp.with_bias = with_bias; + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + } + + jcp.prop_kind = cd.prop_kind; + jcp.mb = src_d.dims()[0]; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + + if (jcp.is_depthwise) { + jcp.ch_block = 16; + jcp.oc_block = 1; + jcp.ic_block = 1; + } else { + jcp.ch_block = 1; + jcp.oc_block = 16; + jcp.ic_block = 16; + + if (jcp.ngroups == 1) { + jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block); + jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block); + } + if (jcp.ic % jcp.ic_block != 0) + return status::unimplemented; + } + + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1) + || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1)) + return status::unimplemented; + + /* padding: bottom and right */ + jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.oh + jcp.t_pad - 1); + jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.ow + jcp.l_pad - 1); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.ver = ver_avx512_core; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.dst_dt = dst_d.data_type(); + jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef; + jcp.typesize_bia + = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + + jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + /* kernel blocking params */ + const int regs = jcp.ver == ver_vnni ? 30 : 28; + jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc); + for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) + if (jcp.nb_oc % jcp.nb_oc_blocking == 0 + && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1)) + break; + + jcp.ur_w = regs / (jcp.nb_oc_blocking + 1); + int l_overflow = max( + 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); + + if (jcp.ow < jcp.ur_w) { + jcp.ur_w = jcp.ow; + jcp.ur_w_tail = 0; + } else { + for (; jcp.ur_w >= 1; jcp.ur_w--) { + /* ur_w should be multiple of stride_w in order + to simplify logic for get_ow_start and get_ow_end */ + bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0; + + /* boundary conditions: + These conditions ensure all elements close to boundary + are computed in a single call of compute loop */ + bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + int r_overflow_no_tail + = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - max(0, jcp.r_pad) - jcp.ur_w_tail) + / jcp.stride_w); + bool right_boundary_covered + = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w; + + if (is_multiple_of_stride && left_boundary_covered + && right_boundary_covered) + break; + else if (jcp.ur_w == 1) + /* The boundary conditions above are also important + to maintain simplicity of calls to icb_loop, + if those conditions are not satisfied, + then special cases will need to be added + to use correct l_overflow/r_overflow values + when different iterations of compute loop + work on the locations close to boundary. + So to keep code simple, return unimplemented + for extreme case when a good ur_w cannot be found. + */ + return status::unimplemented; + } + } + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? weights_d.extra().scale_adjust : 1.f; + + jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn; + return status::success; +} + +bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::maybe_eltwise(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + return false; +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_eltwise(int ur_w) { + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + eltwise_injector_->compute_vector_range(0, nb_oc_block * ur_w); +} + +bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: + return (p.contain(sum, 0) && is_eltwise(1)) + || (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, + const primitive_attr_t &attr) { + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max(attr.output_scales_.count_, 16); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w, + int l_overflow, int r_overflow, ker_block_t last_ic_block_flag, + bool h_padded) { + + const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + const int ur_w_stride = jcp.signed_input ? 1 : jcp.stride_w; + + auto src_offset = [=](int oj, int icb, int ki) { + return jcp.typesize_in + * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w) + * jcp.ngroups * jcp.ic_without_padding + + icb * 4); + }; + + auto kernel_offset = [=](int ocb, int icb, int ki) { + return jcp.typesize_in + * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all + + icb * jcp.oc_block * jcp.ic_block / 4 + + ki * ch_block_all); + }; + + auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else if (jcp.is_depthwise) { + vpmulld(zmm_tmp, vreg_src, vreg_wei); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } else { + vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); + vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + for (int ki = 0; ki < jcp.kw; ki++) { + + int jj_start = get_ow_start(ki, l_overflow); + int jj_end = get_ow_end(ur_w, ki, r_overflow); + + int _start = (jcp.signed_input) ? 0 : jj_start; + int _end = (jcp.signed_input) ? ur_w : jj_end; + + int tail_size = jcp.ic_without_padding % 4; + int n_ic_blocks = jcp.is_depthwise ? + 1 : + (last_ic_block_flag & ~no_last_block ? + div_up(jcp.ic_without_padding % jcp.ic_block, + 4) : + jcp.ic_block / 4); + + for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) { + if (h_padded == true) { + /* fill padded area with shifted values */ + Zmm inp = zmm_inp(0, jcp.nb_oc_blocking); + vpxord(inp, inp, inp); + vpsubb(inp, inp, zmm_shift); + } else { + + for (int jj = _start; jj < _end; jj += ur_w_stride) { + + int aux_src_off = src_offset(jj, icb1, ki); + + if (jj >= jj_start && jj < jj_end + && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) { + if (jcp.is_depthwise) { + vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking), + EVEX_compress_addr( + aux_reg_src, aux_src_off)); + } else if ((last_ic_block_flag & last_sp_block) + && tail_size != 0 && icb1 == n_ic_blocks - 1) { + xmm_t xmm_tmp = xmm_t( + zmm_inp(jj, jcp.nb_oc_blocking).getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_tmp, xmm_tmp, + ptr[aux_reg_src + aux_src_off + r], r); + vpbroadcastd( + zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp); + } else { + vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking), + EVEX_compress_addr( + aux_reg_src, aux_src_off)); + } + if (jcp.signed_input) + vpsubb(zmm_inp(jj, jcp.nb_oc_blocking), + zmm_inp(jj, jcp.nb_oc_blocking), zmm_shift); + } else { + /* fill padded area with shifted values */ + if (jcp.signed_input) { + Zmm inp = zmm_inp(jj, jcp.nb_oc_blocking); + vpxord(inp, inp, inp); + vpsubb(inp, inp, zmm_shift); + } + } + } + } + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + int aux_filt_off = kernel_offset(ocb, icb1, ki); + + if (_end - _start > 0) { + if (jcp.is_depthwise) + vpmovsxbd(zmm_wei, + EVEX_compress_addr(aux_reg_filt, aux_filt_off)); + else + vmovups(zmm_wei, + EVEX_compress_addr(aux_reg_filt, aux_filt_off)); + } + for (int jj = _start; jj < _end; jj += ur_w_stride) { + Zmm inp = (h_padded == true) ? + zmm_inp(0, jcp.nb_oc_blocking) : + zmm_inp(jj, jcp.nb_oc_blocking); + compute(zmm_out(jj, ocb), zmm_wei, inp); + } + } + } + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::kh_loop(int ur_w, + int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) { + + int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * jcp.ngroups * jcp.ic_without_padding; + const int stride_h = jcp.signed_input ? 1 : jcp.stride_h; + int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h; + + Label kh_loop_label, skip_kh_loop; + Label t_overflow_label, no_t_overflow_label, b_overflow_label, + no_b_overflow_label; + + mov(aux_reg_src, reg_src); + mov(aux_reg_filt, reg_filt); + + if (jcp.signed_input && jcp.ndims > 3) { + /* Weights are transposed, so first compute 'bottom' padding. */ + mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); + cmp(reg_overflow, 0); + je(no_b_overflow_label, T_NEAR); + L(b_overflow_label); { + compute_ker(ur_w, 0, 0, last_ic_block_flag, true); + + add(aux_reg_filt, shift_filt_kh); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(b_overflow_label, T_NEAR); + } + L(no_b_overflow_label); + } + + mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); + + if (jcp.signed_input || ((!jcp.signed_input) + && ((min(jcp.t_pad, jcp.b_pad) < 0) + || ((jcp.kh - 1) * (jcp.dilate_h + 1) + < nstl::max(jcp.t_pad, jcp.b_pad))))) { + cmp(reg_kh, 0); + je(skip_kh_loop, T_NEAR); + } + + L(kh_loop_label); { + compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false); + sub(aux_reg_src, shift_src_ih); + add(aux_reg_filt, shift_filt_kh); + dec(reg_kh); + + /* Insert weight compensation in stride 'holes' */ + if (jcp.signed_input && jcp.stride_h > 1) { + Label kh_comp_loop; + + cmp(reg_kh, 0); + je(skip_kh_loop, T_NEAR); + mov(reg_comp_strides, jcp.stride_h - 1); + L(kh_comp_loop); + { + compute_ker( + ur_w, 0, 0, last_ic_block_flag, true); + add(aux_reg_filt, shift_filt_kh); + dec(reg_comp_strides); + cmp(reg_comp_strides, 0); + jg(kh_comp_loop, T_NEAR); + } + } + cmp(reg_kh, 0); + jg(kh_loop_label, T_NEAR); + } + L(skip_kh_loop); + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); + cmp(reg_overflow, 0); + je(no_t_overflow_label, T_NEAR); + L(t_overflow_label); { + compute_ker(ur_w, 0, 0, last_ic_block_flag, true); + + add(aux_reg_filt, shift_filt_kh); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(t_overflow_label, T_NEAR); + } + L(no_t_overflow_label); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) { + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + vpxord(zmm, zmm, zmm); + } + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)-128); + vpbroadcastb(zmm_shift, _t8); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::cvt2ps( + data_type_t type_in, zmm_t zmm_in, const Operand &op, bool mask_flag) { + zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in; + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(zmm, op); break; + case data_type::s8: vpmovsxbd(zmm, op); break; + case data_type::u8: vpmovzxbd(zmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(zmm_in, zmm_in); +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output( + int ur_w, bool last_oc_block) { + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + + if (jcp.signed_input) + mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); + + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale + = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + if (jcp.with_bias && jcp.signed_input && jcp.ver != ver_vnni) { + mov(reg_bias_alpha, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_bias_alpha); + vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha()); + } + + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; + int scale_offset + = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block); + + auto zmm_bias = zmm_tmp; + if (jcp.with_bias) { + int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block; + auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); + cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + vmulps(zmm_bias, zmm_bias, zmm_bias_alpha()); + } + if (jcp.signed_input) { + int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block; + auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset); + cvt2ps(data_type::s32, zmm_comp, comp_addr, mask_flag); + } + + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + vcvtdq2ps(zmm, zmm); + if (jcp.signed_input) + vaddps(zmm, zmm, zmm_comp); + if (jcp.with_bias) + vaddps(zmm, zmm, zmm_bias); + zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm; + vmulps(mask_zmm, zmm, + EVEX_compress_addr(reg_ptr_scales, scale_offset)); + } + } + if (maybe_eltwise(0)) + compute_eltwise(ur_w); + if (p_sum_scale) { // post_op: sum + for (int k = 0; k < jcp.nb_oc_blocking; k++) { + const bool mask_flag + = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1; + for (int j = 0; j < ur_w; j++) { + int aux_output_offset + = jcp.typesize_out + * (k * jcp.oc_block + + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_dst, aux_output_offset); + Zmm zmm = zmm_out(j, k); + cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag); + if (*p_sum_scale == 1.f) + vaddps(zmm, zmm_prev_dst); + else + vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + if (maybe_eltwise(1)) + compute_eltwise(ur_w); + + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + if (jcp.dst_dt == data_type::u8) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + vmaxps(zmm, zmm_zero, zmm); + } + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(zmm, zmm); + } + for (int ur = 0; ur < ur_w; ur++) { + int aux_dst_off = jcp.typesize_out + * (ur * jcp.ngroups * jcp.oc_without_padding + + ocb * jcp.oc_block); + auto addr = EVEX_compress_addr(reg_dst, aux_dst_off); + + zmm_t zmm = zmm_out(ur, ocb); + zmm_t r_zmm = mask_flag ? zmm | ktail_mask : zmm; + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: vmovups(addr, r_zmm); break; + case data_type::s8: vpmovsdb(addr, r_zmm); break; + case data_type::u8: vpmovusdb(addr, r_zmm); break; + default: assert(!"unknown dst_dt"); + } + } + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::icb_loop( + int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) { + + int shift_src_icb = jcp.typesize_in * jcp.ic_block; + int shift_filt_icb + = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block; + + prepare_output(ur_w); + + Label skip_icb_loop, icb_loop_label; + + mov(reg_icb, jcp.nb_ic); + L(icb_loop_label); { + + if (jcp.ic_without_padding != jcp.ic) { + Label common_ker, end_ker; + cmp(reg_icb, 1); + jg(common_ker, T_NEAR); + + kh_loop(ur_w, l_overflow, r_overflow, + is_last_sp_block ? last_sp_block : last_ic_block); + jmp(end_ker, T_NEAR); + + L(common_ker); + kh_loop(ur_w, l_overflow, r_overflow, no_last_block); + + L(end_ker); + } else { + kh_loop(ur_w, l_overflow, r_overflow, no_last_block); + } + + add(reg_src, shift_src_icb); + add(reg_filt, shift_filt_icb); + dec(reg_icb); + cmp(reg_icb, 0); + jg(icb_loop_label, T_NEAR); + } + + /* come-back pointers */ + sub(reg_src, jcp.nb_ic * shift_src_icb); + sub(reg_filt, jcp.nb_ic * shift_filt_icb); + L(skip_icb_loop); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + Label common_store, end_store; + mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); + if (jcp.is_depthwise) + cmp(reg_oc_blocks, jcp.nb_ch - 1); + else + cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); + jne(common_store, T_NEAR); + + store_output(ur_w, true); + jmp(end_store, T_NEAR); + + L(common_store); + store_output(ur_w, false); + + L(end_store); + + } else { + store_output(ur_w, false); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() { + preamble(); + + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(zmm_one, _t); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.is_depthwise ? + jcp.ngroups % jcp.ch_block : + jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + Reg32 regw_tmp = reg_nur_w.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + mov(reg_src, ptr[param1 + GET_OFF(src)]); + mov(reg_filt, ptr[param1 + GET_OFF(filt)]); + mov(reg_dst, ptr[param1 + GET_OFF(dst)]); + + int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups + * jcp.oc_without_padding; + int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups + * jcp.ic_without_padding; + + int l_overflow = max( + 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); + int r_overflow + = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad)) + / jcp.stride_w); + + int r_overflow1 + = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) + / jcp.stride_w); + int nur_w = jcp.ow / jcp.ur_w; + if (r_overflow1 > 0) + nur_w--; + + if (jcp.ur_w == jcp.ow) { + icb_loop(jcp.ur_w, l_overflow, r_overflow, true); + } else if (nur_w == 0) { + icb_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + if (jcp.ur_w_tail != 0) + icb_loop(jcp.ur_w_tail, 0, r_overflow, true); + } else { + xor_(reg_nur_w, reg_nur_w); + if (l_overflow > 0) { + icb_loop(jcp.ur_w, l_overflow, 0, false); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + inc(reg_nur_w); + } + if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) { + Label ow_loop_label; + L(ow_loop_label); + { + icb_loop(jcp.ur_w, 0, 0, false); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + inc(reg_nur_w); + cmp(reg_nur_w, nur_w); + jl(ow_loop_label, T_NEAR); + } + } + if (r_overflow1 > 0) { + icb_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + } + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_overflow, true); + } + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +template +void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + auto &jcp = kernel_->jcp; + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch; + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales + = scratchpad(ctx).template get(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + auto w = const_cast(weights); + int32_t *compensation + = (jcp.signed_input) ? reinterpret_cast(&w[offset]) : 0; + + parallel(0, [&](const int ithr, const int nthr) { + int start{ 0 }, end{ 0 }; + int work_amount = jcp.mb * nb_groups * oc_chunks; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_deconv_call_s(); + + int n{ 0 }, g{ 0 }, occ{ 0 }; + if (jcp.loop_order == loop_ngc) + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks); + else if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb); + else + assert(!"unsupported loop order"); + while (start < end) { + + int ocb = occ * jcp.nb_oc_blocking; + int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.ch_block * jcp.ic; + + p.dst = dst + dst_d.blk_off(n, g_oc); + p.src = src + src_d.blk_off(n, g_ic); + p.filt = weights + wht_blk_off(weights_d, g, ocb, 0); + p.bias = jcp.with_bias ? + bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) : + 0; + p.compensation = (jcp.signed_input) ? compensation + g_oc : 0; + p.scales = &oscales[jcp.is_oc_scale * g_oc]; + p.t_overflow = 0; + p.b_overflow = 0; + p.kh_padding = jcp.kh; + p.oc_blocks = jcp.is_depthwise ? g : ocb; + + kernel_->jit_ker(&p); + + ++start; + if (jcp.loop_order == loop_ngc) + nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks); + else if (jcp.loop_order == loop_cgn) + nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb); + else + assert(!"unsupported loop order"); + } + }); +} + +template +void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + auto &jcp = kernel_->jcp; + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch; + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales + = scratchpad(ctx).template get(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + auto w = const_cast(weights); + int32_t *compensation + = (jcp.signed_input) ? reinterpret_cast(&w[offset]) : 0; + + parallel(0, [&](const int ithr, const int nthr) { + int start{ 0 }, end{ 0 }; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_deconv_call_s(); + + /*loop order = cgn*/ + int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }; + if (jcp.loop_order == loop_ngc) + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, + oh_s, jcp.oh); + else if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb, + oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + while (start < end) { + + int ocb = occ * jcp.nb_oc_blocking; + int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.ch_block * jcp.ic; + int work_rem = end - start; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + + auto dst_w = dst + dst_d.blk_off(n, g_oc); + auto src_w = src + src_d.blk_off(n, g_ic); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); + auto bias_w = jcp.with_bias ? + bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) : + 0; + int32_t *compensation_w + = (jcp.signed_input) ? compensation + g_oc : 0; + + auto scales = &oscales[jcp.is_oc_scale * g_oc]; + for (int oj = oh_s; oj < oh_e; oj++) { + int ih_max = 0, kh_lo = 0, kh_len = 0; + if (jcp.dilate_h != 0 && jcp.stride_h == 1) { + /* dilation */ + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int o_t_overflow = div_up( + max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad), + dilate_h); + int o_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - jcp.oh + + oj - jcp.b_pad), + dilate_h); + kh_len = jcp.kh - o_t_overflow - o_b_overflow; + kh_lo = o_b_overflow; + ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h; + } else { + int o_t_overflow = max( + 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h); + int o_b_overflow + = max(0, ((oj + jcp.kh) - (jcp.oh + jcp.b_pad)) + / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 + - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h; + int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h; + + kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h + + 1 - o_t_overflow - o_b_overflow; + kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h; + ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h; + } + + int wei_stride + = (!jcp.signed_input) ? kh_lo * wht_kh_stride : 0; + p.src = src_w + ih_max * src_h_stride; + p.dst = dst_w + oj * dst_h_stride; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.t_overflow = max( + 0, jcp.kh - (kh_lo + max(0, kh_len - 1) * jcp.stride_h + + 1)); + p.b_overflow = kh_lo; + p.kh_padding = kh_len; + p.scales = scales; + p.oc_blocks = jcp.is_depthwise ? g : ocb; + kernel_->jit_ker(&p); + } + if (jcp.loop_order == loop_ngc) + nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, + oc_chunks, oh_s, jcp.oh); + else if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n, + jcp.mb, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + }); +} + +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp new file mode 100644 index 0000000000..901038fa48 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp @@ -0,0 +1,237 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "cpu_primitive.hpp" +#include "cpu_memory.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "nstl.hpp" + +#include "cpu_deconvolution_pd.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef enum { + no_last_block = 0x1U, + last_ic_block = 0x2U, + last_sp_block = 0x4U, +} ker_block_t; + +struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t); + + jit_avx512_core_x8s8s32x_deconv_fwd_kernel( + const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + generate(); + jit_ker = (void (*)(jit_deconv_call_s *))getCode(); + } + + ~jit_avx512_core_x8s8s32x_deconv_fwd_kernel() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const deconvolution_desc_t &cd, + memory_desc_t &src_md, + memory_desc_t &weights_md, + memory_desc_t &dst_md, + const bool with_bias, + memory_desc_t &bias_md, + const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + const jit_conv_conf_t &jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_deconv_call_s *); +private: + jit_uni_eltwise_injector_f32 *eltwise_injector_; + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + using xmm_t = const Xbyak::Xmm; + + reg64_t reg_src = r8; + reg64_t reg_filt = r9; + reg64_t reg_dst = r10; + reg64_t param1 = abi_param1; + reg64_t reg_kh = abi_not_param1; + reg64_t reg_nur_w = rbx; + reg64_t reg_bias = rdx; + reg64_t reg_icb = reg_bias; + reg64_t reg_ptr_scales = rax; + reg64_t reg_oc_blocks = rsi; + + reg64_t aux_reg_src = r11; + reg64_t aux_reg_filt = r12; + + reg64_t reg_compensation = r14; + reg64_t reg_scratch = r14; + reg64_t reg_ptr_sum_scale = r11; + reg64_t reg_bias_alpha = abi_not_param1; + reg64_t reg_overflow = rax; + reg64_t reg_comp_strides = reg_overflow; + + Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); + zmm_t zmm_tmp = zmm_t(28); + zmm_t zmm_one = zmm_t(29); + /* used during write-out section of store_output */ + zmm_t zmm_zero = zmm_t(31); + zmm_t zmm_wei = zmm_t(31); + + /* signed input */ + zmm_t zmm_shift = zmm_t(30); + zmm_t zmm_comp = zmm_t(30); + zmm_t zmm_bias = zmm_t(31); + zmm_t zmm_prev_dst = zmm_t(31); + + zmm_t zmm_out(int i_ur, int i_oc) { + int idx = i_ur * jcp.nb_oc_blocking + i_oc; + assert(idx < 31); + return zmm_t(idx); + } + zmm_t zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return zmm_t(idx); + } + zmm_t zmm_bias_alpha() { + return zmm_t(jcp.nb_oc_blocking * jcp.ur_w); + } + xmm_t xmm_bias_alpha() { + return xmm_t(jcp.nb_oc_blocking * jcp.ur_w); + } + + int get_ow_start(int ki, int l_overflow) { + int res = (jcp.ow - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + return res; + } + + int get_ow_end(int ur_w, int ki, int r_overflow) { + if (utils::one_of(ur_w, jcp.ow, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + return ur_w - res; + } + bool maybe_eltwise(int position); + void compute_eltwise(int ur_w); + void prepare_output(int ur_w); + void store_output(int ur_w, bool last_oc_block); + void compute_ker(int ur_w, int l_overflow, int r_overflow, + ker_block_t last_ic_block_flag, bool h_padded = false); + void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block); + void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block); + void generate(); + void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op, + bool mask_flag); +}; + +template +struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_deconvolution_fwd_pd_t { + using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_deconvolution:", avx512_core, ""), + _jit_avx512_core_x8s8s32x_deconvolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && (desc()->alg_kind & alg_kind::deconvolution_direct) + && desc()->src_desc.data_type == src_type + && desc()->dst_desc.data_type == dst_type + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && desc()->accum_data_type == data_type::s32; + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_x8s8s32x_deconv_fwd_kernel:: + init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, + with_bias(), bias_md_, *attr()); + + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(scratchpad, + jcp_, *attr()); + + return status::success; + } + + jit_conv_conf_t jcp_; + }; + + _jit_avx512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + + ~_jit_avx512_core_x8s8s32x_deconvolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if(pd()->ndims() == 3) + execute_forward_1d(ctx); + else + execute_forward_2d(ctx); + return status::success; + } + +private: + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_core_x8s8s32x_deconv_fwd_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp new file mode 100644 index 0000000000..c09592d5c9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp @@ -0,0 +1,773 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_GENERATOR_HPP +#define CPU_JIT_AVX2_GENERATOR_HPP + +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_isa_traits.hpp" +#include "jit_utils/jit_utils.hpp" + +#if defined(_WIN32) && !defined(__GNUC__) +# define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__ +#else +# define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al))) +#endif + +#if defined(_WIN32) +# define OFFSET_SHADOWSPACE 0x28 +#endif + +#define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \ + const char *name() const override { return STRINGIFY(jit_name); } \ + const char *source_file() const override { return __FILE__; } + +namespace mkldnn { +namespace impl { +namespace cpu { + +// TODO: move this to jit_generator class? +namespace { + +typedef enum { + PAGE_4K = 4096, + PAGE_2M = 2097152, +} cpu_page_size_t; + +// TODO: move this somewhere else? Although this is only used by jit kernels +// (Roma) +static inline int float2int(float x) { + union { + float vfloat; + int vint; + } cvt; + cvt.vfloat = x; + return cvt.vint; +} + +// TODO: A GPR class that hides ABI details from the JIT kernels and allows +// numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and +// stack register (sr). +// +// This will allow using syntax like this: +// +// param = gpr0; +// reg_input = gpr0; +// reg_output = gpr1; +// ... +// +// #ifndef XBYAK64 +// mov(param, ptr[sr]) +// #endif +// +// (Roma) + +#ifdef XBYAK64 +constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = { + Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12, + Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15, +#ifdef _WIN32 + Xbyak::Operand::RDI, Xbyak::Operand::RSI, +#endif +}; + +#ifdef _WIN32 +static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RCX), + abi_param2(Xbyak::Operand::RDX), + abi_param3(Xbyak::Operand::R8), + abi_param4(Xbyak::Operand::R9), + abi_not_param1(Xbyak::Operand::RDI); +#else +static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI), + abi_param2(Xbyak::Operand::RSI), + abi_param3(Xbyak::Operand::RDX), + abi_param4(Xbyak::Operand::RCX), + abi_param5(Xbyak::Operand::R8), + abi_param6(Xbyak::Operand::R9), + abi_not_param1(Xbyak::Operand::RCX); +#endif +#endif + +inline unsigned int get_cache_size(int level, bool per_core = true){ + unsigned int l = level - 1; + // Currently, if XByak is not able to fetch the cache topology + // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core. + if (cpu.getDataCacheLevels() == 0){ + const int L1_cache_per_core = 32000; + const int L2_cache_per_core = 512000; + const int L3_cache_per_core = 1024000; + int num_cores = per_core ? 1 : mkldnn_get_max_threads(); + switch(l){ + case(0): return L1_cache_per_core * num_cores; + case(1): return L2_cache_per_core * num_cores; + case(2): return L3_cache_per_core * num_cores; + default: return 0; + } + } + if (l < cpu.getDataCacheLevels()) { + return cpu.getDataCacheSize(l) + / (per_core ? cpu.getCoresSharingDataCache(l) : 1); + } else + return 0; +} + +} + +class jit_generator : public Xbyak::CodeGenerator +{ +private: + const size_t xmm_len = 16; +#ifdef _WIN32 + const size_t xmm_to_preserve_start = 6; + const size_t xmm_to_preserve = 10; +#else + const size_t xmm_to_preserve_start = 0; + const size_t xmm_to_preserve = 0; +#endif + + const size_t num_abi_save_gpr_regs + = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]); + + const size_t size_of_abi_save_regs + = num_abi_save_gpr_regs * rax.getBit() / 8 + + xmm_to_preserve * xmm_len; + +public: + enum { + _cmp_eq_oq = 0u, + _cmp_lt_os = 1u, + _cmp_le_os = 2u, + _cmp_neq_uq = 4u, + _cmp_nlt_us = 5u, + _cmp_nle_us = 6u, + + _op_floor = 1u, + _op_mxcsr = 4u, + }; + + Xbyak::Reg64 param1 = abi_param1; + const int EVEX_max_8b_offt = 0x200; + const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp; + + inline size_t get_size_of_abi_save_regs() { + return size_of_abi_save_regs; + } + + void preamble() { + if (xmm_to_preserve) { + sub(rsp, xmm_to_preserve * xmm_len); + for (size_t i = 0; i < xmm_to_preserve; ++i) + movdqu(ptr[rsp + i * xmm_len], Xbyak::Xmm(xmm_to_preserve_start + i)); + } + for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) + push(Xbyak::Reg64(abi_save_gpr_regs[i])); + if (mayiuse(avx512_common)) { + mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); + } + } + + void mic_prefetcht0(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht0(a); + } + + void mic_prefetcht1(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht1(a); + } + + void mic_prefetcht2(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht2(a); + } + + void uni_vzeroupper() { + if (mayiuse(avx) && !mayiuse(avx512_mic)) + vzeroupper(); + } + + void postamble() { + for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) + pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i])); + if (xmm_to_preserve) { + for (size_t i = 0; i < xmm_to_preserve; ++i) + movdqu(Xbyak::Xmm(xmm_to_preserve_start + i), ptr[rsp + i * xmm_len]); + add(rsp, xmm_to_preserve * xmm_len); + } + uni_vzeroupper(); + ret(); + } + + template + Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, + T raw_offt, bool bcast = false) + { + using Xbyak::Zmm; + using Xbyak::Reg64; + using Xbyak::Address; + using Xbyak::RegExp; + + assert(raw_offt <= INT_MAX); + auto offt = static_cast(raw_offt); + + int scale = 0; + + if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) { + offt = offt - 2 * EVEX_max_8b_offt; + scale = 1; + } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) { + offt = offt - 4 * EVEX_max_8b_offt; + scale = 2; + } + + auto re = RegExp() + base + offt; + if (scale) + re = re + reg_EVEX_max_8b_offt * scale; + + if (bcast) + return zword_b [re]; + else + return zword [re]; + } + + Xbyak::Address make_safe_addr(const Xbyak::Reg64 ®_out, size_t offt, + const Xbyak::Reg64 &tmp_reg, bool bcast = false) { + if (offt > INT_MAX) { + mov(tmp_reg, offt); + return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg]; + } else { + return bcast ? ptr_b[reg_out + offt] : ptr[reg_out + offt]; + } + } + + Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base, + size_t raw_offt, const Xbyak::Reg64 ®_offt, bool bcast = false) { + if (raw_offt > INT_MAX) { + return make_safe_addr(base, raw_offt, reg_offt, bcast); + } else { + return EVEX_compress_addr(base, raw_offt, bcast); + } + } + + void safe_add(const Xbyak::Reg64 &base, size_t raw_offt, + const Xbyak::Reg64 ®_offt) { + if (raw_offt > INT_MAX) { + mov(reg_offt, raw_offt); + add(base, reg_offt); + } else { + add(base, raw_offt); + } + } + + void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt, + const Xbyak::Reg64 ®_offt) { + if (raw_offt > INT_MAX) { + mov(reg_offt, raw_offt); + sub(base, reg_offt); + } else { + sub(base, raw_offt); + } + } + + // Disallow char-based labels completely + void L(const char *label) = delete; + void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); } + + void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + pxor(x2, op); + } + void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + if (mayiuse(avx2)) { + vpxor(x1, x2, op); + } else { + vxorps(x1, x2, op); + } + } + void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2, + const Xbyak::Operand &op) { + vpxord(x1, x2, op); + } + + void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) { + movss(addr, x); + } + void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) { + vmovss(addr, x); + } + void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) { + movss(x, addr); + } + void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) { + vmovss(x, addr); + } + + void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) { + movsd(addr, x); + } + void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) { + vmovsd(addr, x); + } + void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) { + movsd(x, addr); + } + void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) { + vmovsd(x, addr); + } + + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movdqu(addr, x); + } + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovdqu(addr, x); + } + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) { + vmovdqu32(addr, x); + } + + void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) { + movdqu(x, addr); + } + void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) { + vmovdqu(x, addr); + } + void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) { + vmovdqu32(x, addr); + } + + void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movups(addr, x); + } + void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovups(addr, x); + } + + void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movups(x, op); + } + void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vmovups(x, op); + } + + void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movntps(addr, x); + } + void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovntps(addr, x); + } + + void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movss(x, op); + shufps(x, x, 0x0); + } + void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + if (op.isMEM() || mayiuse(avx2)) { + vbroadcastss(x, op); + } else { + Xbyak::Xmm t(x.getIdx()); + if (t.getIdx() != op.getIdx()) movss(t, op); + vinsertf128(x, x, t, 1); + vshufps(x, x, x, 0); + } + } + + void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movsd(x, op); + pshufd(x, x, 0x0); + } + void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + if (mayiuse(avx2)) { + vpbroadcastd(x, op); + } else { + Xbyak::Xmm t(x.getIdx()); + if (t.getIdx() != op.getIdx()) movsd(t, op); + vinsertf128(x, x, t, 1); + vshufps(x, x, x, 0); + } + } + + void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + rcpss(x, op); + } + void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) { + Xbyak::Xmm x1_(x1.getIdx()); + Xbyak::Xmm x2_(x2.getIdx()); + vrcpss(x1_, x1_, x2_); + } + void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) { + Xbyak::Xmm x_(x.getIdx()); + vrcpss(x_, x_, op); + } + + void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + rcpps(x, op); + } + void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vrcpps(x, op); + } + void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) { + vrcp14ps(x, op); + } + + void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + divps(x, op2); + } + void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vdivps(x, op1, op2); + } + + void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { + movups(buf, op1); + divps(buf, op2); + if (x.getIdx() != buf.getIdx()) { + movups(x, buf); + } + } + + void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Ymm &buf) { + vdivps(x, op1, op2); + } + + void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + addps(x, op2); + } + void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vaddps(x, op1, op2); + } + + void uni_vpsignd(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, + const Xbyak::Operand& op) { + assert(x1.getIdx() == x2.getIdx()); + psignd(x1, op); + } + void uni_vpsignd(const Xbyak::Ymm& x1, const Xbyak::Ymm& x2, + const Xbyak::Operand& op) { + vpsignd(x1, x2, op); + } + + void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + subps(x, op2); + } + void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vsubps(x, op1, op2); + } + + void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { + movups(buf, op1); + subps(buf, op2); + if (x.getIdx() != buf.getIdx()) { + movups(x, buf); + } + } + + void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Ymm &buf) { + vsubps(x, op1, op2); + } + + void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + mulps(x, op2); + } + void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vmulps(x, op1, op2); + } + + void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x1, x2); + addps(x1, op); + } + void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfmadd213ps(x1, x2, op); + } + + void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x2, op); + addps(x1, x2); + } + void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfmadd231ps(x1, x2, op); + } + + void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x2, op); + subps(x1, x2); + } + + void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfnmadd231ps(x1, x2, op); + } + + void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + sqrtps(x, op); + } + void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vsqrtps(x, op); + } + + void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + paddd(x2, op); + } + void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + vpaddd(x1, x2, op); + } + + void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + assert(x1.getIdx() == x2.getIdx()); + andps(x1, op); + } + void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + if (!mayiuse(avx512_common) || x1.getBit() < 512) + vandps(x1, x2, op); + else + vpandd(x1, x2, op); + } + + void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + assert(x1.getIdx() == x2.getIdx()); + orps(x1, op); + } + void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + if (!mayiuse(avx512_common) || x1.getBit() < 512) + vorps(x1, x2, op); + else + vpord(x1, x2, op); + } + + void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + assert(x.getIdx() == op.getIdx()); + pslld(x, imm); + } + void uni_vpslld(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vpslld(x, op, imm); + } + + void uni_vpsrld(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + assert(x.getIdx() == op.getIdx()); + psrld(x, imm); + } + void uni_vpsrld(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vpsrld(x, op, imm); + } + + void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + maxps(x, op2); + } + void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vmaxps(x, op1, op2); + } + + void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + minps(x, op2); + } + void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vminps(x, op1, op2); + } + + void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + cmpps(x1, op, _cmp_nle_us); + } + + void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vcmpgtps(x1, x2, op); + } + + void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + cmpps(x1, op, _cmp_nlt_us); + } + + void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vcmpps(x1, x2, op, _cmp_nlt_us); + } + + void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) { + ptest(x1, op); + } + + void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) { + assert(!(x1.isZMM() || op.isZMM())); + vtestps(x1, op); + } + + void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op, const Xbyak::Xmm &msk) { + assert(x1.getIdx() == x2.getIdx()); + assert(msk.getIdx() == 0); + blendvps(x1, op); + } + void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op, const Xbyak::Ymm &msk) { + vblendvps(x1, x2, op, msk); + } + + void uni_vroundps(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + roundps(x, op, imm); + } + void uni_vroundps(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vroundps(x, op, imm); + } + + void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + cvtps2dq(x, op); + } + void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vcvtps2dq(x, op); + } + + void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + cvtdq2ps(x, op); + } + void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vcvtdq2ps(x, op); + } + + void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) { + movmskps(x1.cvt64(), x2); + } + void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) { + vmovmskps(x1, x2); + } + + void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){ + assert(x1.getIdx() == x1.getIdx()); + packssdw(x1, op); + } + void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){ + vpackssdw(x1, x2, op); + } + + void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){ + assert(x1.getIdx() == x1.getIdx()); + packuswb(x1, op); + } + void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){ + vpackuswb(x1, x2, op); + } + + + void mul_by_const(const Xbyak::Reg &out, + const Xbyak::Reg64 &tmp, int value) { + // Generates a shift + add sequence for multiplicating contents of the + // out register by a known JIT-time value. Clobbers the tmp register. + // + // Pros compared to mul/imul: + // - does not require using known registers + // - not microcoded on Intel(R) Xeon Phi(TM) processors + // Still, there are probably a lot of cases when mul/imul is faster on + // Intel(R) Core(TM) processors. Not intended for critical path. + + // TODO: detect when overflow is emminent (Roma) + // TODO: detect when using mul/imul is a better option (Roma) + + int p = 0; // the current power of 2 + int old_p = 0; // the last seen power of 2 such that value[old_p] != 0 + + xor_(tmp, tmp); + while (value) { + if (value & 1) { + int shift = p - old_p; + if (shift) { + shl(out, shift); + old_p = p; + } + add(tmp, out); + } + value >>= 1; + p++; + } + mov(out, tmp); + } + +public: + jit_generator( + void *code_ptr = nullptr, + size_t code_size = 256 * 1024 + ) : Xbyak::CodeGenerator(code_size, code_ptr) + { + } + virtual ~jit_generator() {} + + virtual const char *name() const = 0; + virtual const char *source_file() const = 0; + + const Xbyak::uint8 *getCode() { + const Xbyak::uint8 *code = CodeGenerator::getCode(); + size_t code_size = getSize(); + jit_utils::register_jit_code(code, code_size, name(), source_file()); + return code; + } + + template const F getCode() { + return (const F)getCode(); + } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp new file mode 100644 index 0000000000..56d7f592e2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp @@ -0,0 +1,481 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_PRIMITIVE_CONF_HPP +#define JIT_PRIMITIVE_CONF_HPP + +#include + +#include "common/primitive_attr.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +/* convolution */ +enum conv_version_t {ver_unused, ver_fma, ver_avx512_core, ver_4fma, ver_vnni}; +enum conv_loop_order_t {loop_cgn, loop_gnc, loop_ngc, loop_gncw, loop_cwgn, + loop_ngcw, loop_nhwcg, loop_nwcg}; +enum conv_1x1_loop_order_t {loop_rbl, loop_rlb, loop_lbr, loop_lrb, loop_blr, + loop_brl}; +enum conv_kernel_kind_t {embd_bcast, expl_bcast}; + +enum { + FLAG_MB_FIRST = 1 << 0, FLAG_MB_LAST = 1 << 1, + FLAG_OC_FIRST = 1 << 2, FLAG_OC_LAST = 1 << 3, + FLAG_IC_FIRST = 1 << 4, FLAG_IC_LAST = 1 << 5, + FLAG_SP_FIRST = 1 << 6, FLAG_SP_LAST = 1 << 7, + FLAG_REDUCE_FIRST = 1<<8, FLAG_REDUCE_LAST = 1<<9, + FLAG_ZERO_FILTER = 1 << 0, /* Controls whether the inner kernel skips + loading weights-data from memory; this + needs to happen on the first Group/16 + iteration. */ + FLAG_ZERO_BIAS = 1 << 1, /* Controls whether the inner kernel skip + loading bias data from memory */ + FLAG_COMPUTE_BIAS = 1 << 2, /* Controls bias computation during execution + pass */ +}; + +struct jit_conv_conf_t { + prop_kind_t prop_kind; + conv_version_t ver; + conv_loop_order_t loop_order; + + int simd_w; + int ndims; + int mb; + int ngroups, ic, oc, oc_without_padding, ic_without_padding; + int id, ih, iw, od, oh, ow; + int f_pad, l_pad, t_pad; + int back_pad, r_pad, b_pad; + int kd, kh, kw; + int stride_d, stride_h, stride_w; + int dilate_d, dilate_h, dilate_w; + format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround + bool with_bias; + bool with_sum; + bool with_eltwise; + + post_ops_t::entry_t::eltwise_t eltwise; + + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + + int idp, ihp, iwp, ohp, owp; + int nb_ic, ic_block; + int nb_oc, oc_block; + int nb_ow, ow_block; + int nb_oc_blocking; /* used in jit kernels for nb_oc work bloking taking + into account vector registers distribution */ + int nb_oc_blocking_thr_chunk; /* used for distibution of nb_oc work + within threads */ + int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work + int nb_ic_L2; + int h_blocking; + int nb_oc_L2; + int ur_h, ur_w; + int ur_w_tail; + bool is_1stconv; + int nonblk_group_off; + /* fma avx512_core */ + conv_kernel_kind_t kernel_kind; + /* 4fma */ + int tr_iw; + int tr_src_num_guard_elems; + /* 1st conv: 4fma */ + int tr_ld; + int kh_step; + /* 4vnni */ + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + /* avx512_u8s8u8 */ + int ic_nb1, ic_nb2; + int oc_nb1; + int ur_ow_max, ur_ow, ur_ow_tail; + int ur_ow_nsteps; + data_type_t bia_dt; + data_type_t dst_dt; + /* avx512: max possible value is nregs(32) - aux_regs(4) */ + int src_offsets[28]; + int src_count; + bool expl_bcast; + bool large_spatial; + int is_oc_scale; + int max_regs_ur; // maximum accumulation registers + // dw conv + int nb_ch, ch_block, nb_ch_blocking; + bool is_depthwise, is_fast_depthwise, is_resrc_depthwise; + int aligned_threads; + // large spatial + int oh_blk_size; + // s8s8 convolution + bool signed_input; + float wei_adj_scale; +}; + +struct jit_conv_conf_2x3_wino_t { + conv_version_t ver; + + int m; + int r; + int alpha; + int tile_h, tile_w; + + int mb; + int ngroups, ic, oc, oc_without_padding; + int ih, iw, oh, ow; + int l_pad, t_pad; + int r_pad, b_pad; + int kh, kw; + int stride_h, stride_w; + int dilate_h, dilate_w; + + int nb_ic, ic_block; + int nb_oc, oc_block; + + int w_block_size, h_block_size; + + data_type_t bia_dt; + data_type_t dst_dt; + + int is_oc_scale; + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + + format_tag_t src_tag, dst_tag; // temporary workaround + bool with_bias; + bool small_mb; + + int xb, yb; + int inp_stride; + int out_stride; + int wei_stride; + int bia_stride; + + int M, N, K; + int m_block, n_block, k_block; + int n2_block, n_chunks; + int k2_block, k_chunks; + + int mb_block, nb_mb; + + size_t size_wino_src, size_wino_wei, size_wino_dst; + + int nthr; +}; + +/* + Winograd sched policy: + + Computation Unit: + W: weights transform + S: src transform + D: dst transform + G: gemm + + Thread grouping by: + i: nb_ic + o: nb_oc + t: tile_block + e: element in tile + + Note: 'i' and 'o' are omited if + i. not comblined with t or + ii. with discrete transforms + + Current policies supported: +*/ +enum winograd_sched_t { + WSCHED_INVALID = 0, + + /* Forward & backward-data */ + /* W_S_G_D implements discrete transforms */ + WSCHED_DATA_W_S_G_D, + /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2*/ + WSCHED_DATA_W_SGD, + + /* Backward-weights */ + WSCHED_WEI_S_D_G_W, + WSCHED_WEI_SDGtWo, + WSCHED_WEI_S_D_Giot_W, + WSCHED_WEI_SDGt_W, +}; + +struct jit_conv_winograd_conf_t : public jit_conv_conf_t { + int itiles; + int jtiles; + int ntiles; + int ic_simd_block=16; + int tile_4fma_padding; + int tile_4fma; + int oc_simd_block=16; + int oc_reg_block; + int ic_reg_block; + int tile_block; + int tile_block_ur; + int nb_tile_block_ur; + + bool double_buffering; + bool with_relu_postsum; + int zmm_start; + int nb_reg; + + int dimK; + int dimK_4fma; + int dimK_reg_block; + int dimK_block; + int dimK_nb_block; + + int dimM; + int dimM_reg_block; + int dimM_simd_block; + int dimM_block; + int dimM_nb_block; + + int dimN; + int dimN_reg_block; + int dimN_bcast_ur; + int dimN_block; + int dimN_nb_block; + + winograd_sched_t sched_policy; +}; + +struct jit_conv_call_s { + const void *src; /* hack, non-const for backward_data */ + const void *dst; /* hack, non-const for forward */ + const void *filt; /* hack, non-const for backward_weights */ + const void *bias; /* hack, non-const for backward_bias */ + const void *src_prf; + const void *dst_prf; + const void *filt_prf; + const void *bias_prf; + const void *scales; + const void *acc_s32; + const void *compensation; + size_t kd_offset; + size_t kd_offset_prf; + size_t d_index; + size_t d_index_prf; + size_t d_worksize; + size_t d_worksize_prf; + size_t kd_padding; + size_t kd_padding_prf; + size_t kh_padding; + size_t kh_padding_prf; + size_t owb; + size_t owb_prf; + size_t kw_padding; + size_t channel; + size_t channel_prf; + size_t oc_blocks; + size_t ur_w; + size_t ur_str_w; + size_t ch_blocks; + size_t t_overflow; + size_t b_overflow; + int flags; +}; + +struct jit_deconv_call_s { + const void *src; /* hack, non-const for backward_data */ + const void *dst; /* hack, non-const for forward */ + const void *filt; /* hack, non-const for backward_weights */ + const void *bias; /* hack, non-const for backward_bias */ + const void *scales; + const void *compensation; + size_t t_overflow; + size_t b_overflow; + size_t kh_padding; + size_t oc_blocks; +}; + +struct jit_dw_conv_call_s { + const void *input; + const void *output; + const void *filter; + const void *bias; + size_t kh_count; + size_t oh_count; + size_t oh_index; + size_t filter_pad_off; + unsigned char + exec_flags; /* Flags passed by driver execution to inner kernel */ +}; + +struct jit_wino_transform_call_s { + size_t tile_block; + size_t tile_block_ur; + size_t nb_tile_block_ur; + size_t tile_count; + size_t tj; + size_t ti; + void *src; + void *dst; + void *Mw; + void *M; + void *T; + void *G; + void *bias; +}; + +struct jit_1x1_conv_conf_t { + prop_kind_t prop_kind; + conv_version_t ver; + + int mb; + int ngroups, ic, oc, oc_without_padding, ic_without_padding; + int iw, ih, ow, oh; + int l_pad, t_pad; + int kh, kw; + int stride_h, stride_w; + format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround + bool with_bias; + bool with_sum; + bool with_eltwise; + + post_ops_t::entry_t::eltwise_t eltwise; + + int is, os; + int ic_block, oc_block; + + int ur, ur_tail; + + int reduce_dim, reduce_block, nb_reduce, + nb_reduce_blocking, nb_reduce_blocking_max; + int load_dim, load_block, nb_load, + nb_load_blocking, nb_load_blocking_max, nb_load_chunk; + int bcast_dim, bcast_block, nb_bcast, + nb_bcast_blocking, nb_bcast_blocking_max; + + int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step; + int load_loop_load_step, load_loop_iter_step; + int bcast_loop_output_step, bcast_loop_output_substep; + int bcast_loop_bcast_step, bcast_loop_bcast_substep; + int fma_step; + int load_grp_count; + conv_1x1_loop_order_t loop_order; + bool use_vmovntps; + /* avx512 core */ + bool expl_bcast; + /* 4vnni */ + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + /* 4fma */ + bool transpose_src; + int tr_is; + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + int is_oc_scale; + data_type_t bia_dt; + data_type_t dst_dt; + bool signed_input; + float wei_adj_scale; +}; + +struct jit_gemm_conv_conf_t { + prop_kind_t prop_kind; + + int mb; + int ngroups, ic, oc; + int iw, ih, id, ow, oh, od; + int l_pad, t_pad, f_pad; + int kh, kw, kd; + int stride_h, stride_w, stride_d; + int dilate_h, dilate_w, dilate_d; + bool with_bias; + + int is, os, ks; + int ic_block, oc_block; + + int nthr; + ptrdiff_t im2col_sz; + bool need_wei_reduction; + bool signed_input; + int oh_block; + int ow_block; + bool outer_threading; +}; + +struct jit_1x1_conv_call_s { + const void *bcast_data; + const void *load_data; + const void *output_data; + const void *bias_data; // used in forward and backward_weights only + const void *acc_s32; + const void *scales; + const void *compensation; + + size_t load_dim; + size_t bcast_dim; + size_t reduce_dim; + + size_t output_stride; // used in backward_weights only + + size_t first_last_flag; +}; + +/* pooling */ +struct jit_pool_conf_t { + int ndims; + int mb, c; + int id, ih, iw, od, oh, ow; + int stride_d, stride_h, stride_w; + int kd, kh, kw; + int f_pad, t_pad, l_pad; + alg_kind_t alg; + bool is_training; + bool pad_w_is_null; + bool is_backward; + bool simple_alg; + data_type_t ind_dt; + + int c_block, c_tail, nb_c; + int ur_c, ur_c_tail; + int ur_w; + int ur_w_tail; + size_t tail[4]; + data_type_t src_dt; + data_type_t dst_dt; +}; + +struct jit_pool_call_s { + const float *src; + const float *dst; + const void *indices; + const float *src_prf; + const float *dst_prf; + const void *indices_prf; + size_t oh; + size_t kd_padding; + size_t kh_padding; + size_t kh_padding_shift; + size_t kd_padding_shift; + size_t kw_padding; + const float* init_value; + float ker_area_h; +}; + + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp new file mode 100644 index 0000000000..94d2101d6e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp @@ -0,0 +1,677 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_sse42_1x1_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_sse42_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, reg_bcast_loop_work); + + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + generate_reduce_loop(load_loop_blk, jcp.ur); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + generate_reduce_loop(load_loop_blk, jcp.ur_tail); + L(bcast_loop_tail_out); + } +} + +void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop( + int load_loop_blk, int ur) +{ + auto reg_load = [=](int i, int n) { + return Xmm(2*ur * load_loop_blk + 2*i + n + 1); + }; + + auto reg_accum = [=](int i, int j, int n) { + return Xmm(2*j * load_loop_blk + 2*i + n + 1); + }; + + auto bias_ptr = [=](int i, int n) { + return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i + n*4*sizeof(float)]; + }; + + auto bcast_ptr = [=](int u, int j) { + assert(j < jcp.ur); + assert(u <= jcp.reduce_loop_unroll); + size_t offt; + if (one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)) { + assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data) + ? jcp.oc_block : jcp.ic_block); + auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is; + offt = (u == jcp.reduce_loop_unroll) + ? (height + j) * jcp.reduce_loop_unroll + : j * jcp.reduce_loop_unroll + u; + } else + offt = u * jcp.ic_block + j; + return ptr[aux_reg_bcast_data + sizeof(float) * offt]; + }; + + auto load_ptr = [=](int u, int i, int n) { + size_t offt; + size_t u0 = u % jcp.reduce_loop_unroll; + size_t u1 = u / jcp.reduce_loop_unroll; + switch (jcp.prop_kind) { + case backward_data: + offt = (i * jcp.oc_block + u0) * jcp.ic_block; + break; + case backward_weights: + offt = (i * jcp.os + u0) * jcp.oc_block; + break; + default: + offt = (i * jcp.ic + u0) * jcp.oc_block; + } + return ptr[aux_reg_load_data + + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt + n * 4 * sizeof(float)]; + }; + + auto output_ptr = [=](int i, int j, int n) { + switch (jcp.prop_kind) { + case backward_data: + return ptr[aux_reg_output_data + + (i * jcp.is + j) * jcp.ic_block * sizeof(float) + n * 4 * sizeof(float)]; + case backward_weights: + return ptr[aux_reg_output_data + + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale + + sizeof(float) * jcp.oc_block * j + n * 4 * sizeof(float)]; + default: + return ptr[aux_reg_output_data + + (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)]; + } + }; + + auto init = [=]() { + Label init_done; + Label init_zero; + + if (jcp.with_bias && one_of(jcp.prop_kind, forward_training, + forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero); + + for (int i = 0; i < load_loop_blk; i++) + for (int j = 0; j < ur; ++j) { + movups(reg_accum(i, j, 0), bias_ptr(i, 0)); + movups(reg_accum(i, j, 1), bias_ptr(i, 1)); + } + jmp(init_done); + } + + L(init_zero); + for (int i = 0; i < load_loop_blk; ++i) + for (int j = 0; j < ur; ++j) { + auto r0 = reg_accum(i, j, 0); + auto r1 = reg_accum(i, j, 1); + xorps(r0, r0); + xorps(r1, r1); + } + + L(init_done); + + // load weights + for (int i = 0; i < load_loop_blk; ++i) { + movups(reg_load(i, 0), load_ptr(0, i, 0)); + movups(reg_load(i, 1), load_ptr(0, i, 1)); + } + + movss(reg_bcast, bcast_ptr(0, 0)); + shufps(reg_bcast, reg_bcast, 0); + }; // init() + + auto store = [=]() { + Label store_noadd; + + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + auto r0 = reg_accum(i, j, 0); + auto r1 = reg_accum(i, j, 1); + addps(r0, output_ptr(i, j, 0)); + addps(r1, output_ptr(i, j, 1)); + } + + L(store_noadd); + + if (jcp.with_eltwise) { + assert(ur * load_loop_blk < 14); + + Label store_norelu; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_norelu, T_NEAR); + + eltwise_injector_->compute_vector_range(1, + 2 * ur * load_loop_blk + 1); + + L(store_norelu); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + movups(output_ptr(i, j, 0), reg_accum(i, j, 0)); + movups(output_ptr(i, j, 1), reg_accum(i, j, 1)); + } + }; + + auto fma_block = [=](bool last_block) { + for (int u = 0; u < jcp.reduce_loop_unroll; ++u) { + for (int j = 0; j < ur; ++j) { + for (int i = 0; i < load_loop_blk; ++i) { + mulps(reg_load(i, 0), reg_bcast); + mulps(reg_load(i, 1), reg_bcast); + addps(reg_accum(i, j, 0), reg_load(i, 0)); + addps(reg_accum(i, j, 1), reg_load(i, 1)); + + if (j == ur - 1 && !(last_block && u == jcp.reduce_loop_unroll - 1)) { + movups(reg_load(i, 0), load_ptr(u + 1, i, 0)); + movups(reg_load(i, 1), load_ptr(u + 1, i, 1)); + } + } + if (j < ur - 1) { + movss(reg_bcast, bcast_ptr(u, j + 1)); + shufps(reg_bcast, reg_bcast, 0); + } + } // for ur + if (!last_block || u < jcp.reduce_loop_unroll - 1) { + movss(reg_bcast, bcast_ptr(u + 1, 0)); + shufps(reg_bcast, reg_bcast, 0); + } + } // for reduce_loop_unroll + }; + + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} // reduce_loop() + +void jit_sse42_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) +{ + if (!jcp.with_bias || jcp.prop_kind != backward_weights) + return; + + Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out; + Label diff_bias_load; + + auto diff_bias_ptr = [=](int i, int n) { + return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)+ 4*n*sizeof(float)]; + }; + + auto load_ptr = [=](int u, int i, int n) { + return ptr[aux_reg_load_data + + (i * jcp.os + u) * jcp.oc_block * sizeof(float) + 4*n*sizeof(float)]; + }; + + auto diff_bias_reg = [=](int i, int n) { return Xmm(2*i + n + 1); }; + + mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]); + cmp(reg_diff_bias_data, 0); + je(diff_bias_loop_out, T_NEAR); + + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(diff_bias_load, T_NEAR); + + for (int i = 0; i < load_loop_blk; ++i) { + auto r0 = diff_bias_reg(i, 0); + auto r1 = diff_bias_reg(i, 1); + xorps(r0, r0); + xorps(r1, r1); + } + jmp(diff_bias_init_out, T_NEAR); + + L(diff_bias_load); + for (int i = 0; i < load_loop_blk; ++i) { + movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0)); + movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1)); + } + + L(diff_bias_init_out); + mov(aux_reg_load_data, reg_load_data); + mov(reduce_loop_iter, reg_reduce_loop_work); + L(diff_bias_loop); { + for(int u = 0; u < jcp.reduce_loop_unroll; ++u) + for (int i = 0; i < load_loop_blk; ++i) { + addps(diff_bias_reg(i, 0), load_ptr(u, i, 0)); + addps(diff_bias_reg(i, 1), load_ptr(u, i, 1)); + } + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jnz(diff_bias_loop, T_NEAR); + } + + for (int i = 0; i < load_loop_blk; i++) { + movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0)); + movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1)); + } + + add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + + L(diff_bias_loop_out); +} + +void jit_sse42_1x1_conv_kernel_f32::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + if (jcp.with_bias) { + if (jcp.prop_kind == backward_weights) { + sub(rsp, stack_space_needed); + mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + } else + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + } + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto generate_load_loop_body = [=] (int load_loop_blk) { + generate_bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + add(reg_output_data, + load_loop_blk * jcp.os * jcp.oc_block * sizeof(float)); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.is * jcp.ic_block * sizeof(float)); + break; + case backward_weights: + for (int i = 0; i < load_loop_blk; i++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + Label load_loop_blk_8; + Label load_loop_blk_16; + Label load_loop_blk_24; + Label load_loop_blk_end; + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16, T_NEAR); + + cmp(reg_load_loop_work, 16); + jle(load_loop_blk_16, T_NEAR); + + L(load_loop_blk_24); { + generate_diff_bias_loop(3); + generate_load_loop_body(3); + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16); + cmp(reg_load_loop_work, 24); + jge(load_loop_blk_24); + } + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + L(load_loop_blk_16); { + generate_diff_bias_loop(2); + generate_load_loop_body(2); + cmp(reg_load_loop_work, 16); + jge(load_loop_blk_16); + } + + L(load_loop_blk_8); { + cmp(reg_load_loop_work, 0); + je(load_loop_blk_end, T_NEAR); + generate_diff_bias_loop(1); + generate_load_loop_body(1); + } + + L(load_loop_blk_end); + + if (jcp.with_bias && jcp.prop_kind == backward_weights) + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(sse42)) + return status::unimplemented; + + // TODO (Roma): this code is duplicated from the generic kernel; maybe the + // configuration struct could do some stuff below + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + const int is_bwd_d = jcp.prop_kind == backward_data; + + format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c; + format_tag_t wei_tag = with_groups + ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o, + gOIhw8o8i) + : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, + OIhw8o8i); + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + const int simd_w = 4; + jcp.ic_block = jcp.oc_block = simd_w*2; + + args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.ur = 1; + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.is * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + load_blocking = 120; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 192; + reduce_blocking = 128; // affects L1$ utilization + } else if (jcp.prop_kind == backward_data) { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.os; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.os * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.ic * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.load_loop_iter_step = jcp.ic_block; + + load_blocking = 96; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 196; + reduce_blocking = 64; // affects L1$ utilization + } else if (jcp.prop_kind == backward_weights) { + jcp.reduce_dim = jcp.os; + jcp.reduce_block = 1; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); + jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float); + jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + while (true) { + if (load_blocking <= 32) break; + else if (load_blocking % 2 == 0) load_blocking /= 2; + else if (load_blocking % 3 == 0) load_blocking /= 3; + else break; + } + load_blocking *= jcp.load_block; + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + while (true) { + if (bcast_blocking <= 9) break; + else if (bcast_blocking % 2 == 0) bcast_blocking /= 2; + else if (bcast_blocking % 3 == 0) bcast_blocking /= 3; + else break; + } + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + reduce_blocking = 128; // affects L1$ utilization + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + + assert(jcp.bcast_block % jcp.ur == 0); + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp new file mode 100644 index 0000000000..b314a5098c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp @@ -0,0 +1,104 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_SSE42_1x1_CONV_KERNEL_F32_HPP +#define JIT_SSE42_1x1_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_1x1_conv_kernel_f32: public jit_generator { + jit_sse42_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode(); + } + + ~jit_sse42_1x1_conv_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_1x1_conv_kernel_f32) + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using xmm_t = const Xbyak::Xmm; + + reg64_t reg_bcast_data = rax; + reg64_t reg_load_data = rsi; + reg64_t reg_output_data = rbx; + reg64_t aux_reg_bcast_data = rdx; + reg64_t aux1_reg_bcast_data = abi_not_param1; + reg64_t aux_reg_load_data = abi_param1; + reg64_t aux_reg_output_data = rbp; + reg64_t reg_load_loop_work = r9; + reg64_t reg_bcast_loop_work = r10; + reg64_t reg_reduce_loop_work = r11; + reg64_t load_loop_iter = r13; + reg64_t imm_addr64 = load_loop_iter; + reg64_t bcast_loop_iter = r14; + reg64_t reduce_loop_iter = r15; + reg64_t reg_reduce_pos_flag = r8; + reg64_t reg_output_stride = r12; + reg64_t reg_bias_data = r12; + reg64_t reg_diff_bias_data = bcast_loop_iter; + + int reg_diff_bias_data_stack_offt = 0; + int stack_space_needed = 8; + + xmm_t reg_bcast = xmm_t(15); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + void generate_bcast_loop(int load_loop_blk); + void generate_reduce_loop(int load_loop_blk, int ur); + void generate_diff_bias_loop(int load_loop_blk); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp new file mode 100644 index 0000000000..30c137641e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp @@ -0,0 +1,134 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "jit_sse42_1x1_convolution.hpp" +#include "utils.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +void jit_sse42_1x1_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int ndims = src_d.ndims(); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + parallel(0, [&](const int ithr, const int nthr) { + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + auto par_conv = jit_1x1_conv_call_s(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int iwork = start; + while (iwork < end) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + const int bcast_step_rem = jcp.nb_bcast - osb; + int bcast_step = bcast_step_rem <= jcp.nb_bcast_blocking_max + ? bcast_step_rem : jcp.nb_bcast_blocking; + bcast_step = nstl::min(bcast_step, end - iwork); + + const int os = osb * os_block; + const int ow = os % jcp.ow; + const int oh = os / jcp.ow; + const int iw = nstl::max(ow * jcp.stride_w - jcp.l_pad, 0); + const int ih = nstl::max(oh * jcp.stride_h - jcp.t_pad, 0); + + par_conv.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + + int ocb = 0; + while (ocb < jcp.nb_load) { + const int load_step_rem = jcp.nb_load - ocb; + const int load_step = load_step_rem < jcp.nb_load_blocking_max + ? load_step_rem : jcp.nb_load_blocking; + + const size_t _ocb = g * nb_oc + ocb; + par_conv.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + load_step * jcp.oc_block); + + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + par_conv.output_data = &dst[dst_off]; + + par_conv.bias_data = &bias[_ocb * jcp.oc_block]; + + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + par_conv.first_last_flag = 0 + | (icb == 0) * FLAG_REDUCE_FIRST + | (icb + nb_ic_blocking >= nb_ic) * FLAG_REDUCE_LAST; + + par_conv.reduce_dim = this_block_size(icb * jcp.ic_block, + jcp.ic, nb_ic_blocking * jcp.ic_block); + + const size_t _icb = g * nb_ic + icb; + const size_t src_off = data_blk_off(src_d, n, _icb, ih, iw); + par_conv.bcast_data = &src[src_off]; + + par_conv.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + kernel_->jit_ker(&par_conv); + } + + ocb += load_step; + } + + iwork += bcast_step; + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp new file mode 100644 index 0000000000..b32b1e4784 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp @@ -0,0 +1,96 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_SSE42_1x1_CONVOLUTION_HPP +#define CPU_JIT_SSE42_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "jit_sse42_1x1_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_1x1_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", sse42, ""), + jit_sse42_1x1_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + return jit_sse42_1x1_conv_kernel_f32::init_conf(jcp_, *desc(), + *src_md(), *weights_md(), *dst_md(), *attr()); + } + + jit_1x1_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_sse42_1x1_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) { + kernel_ = new jit_sse42_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + } + ~jit_sse42_1x1_convolution_fwd_t() { delete kernel_; }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_sse42_1x1_conv_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp new file mode 100644 index 0000000000..17cabc1186 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp @@ -0,0 +1,497 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "cpu_memory.hpp" + +#include "jit_sse42_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_sse42_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int kh = jcp.kh; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); + int jj_end = ur_w + - nstl::max(0, div_up(ki*dilate_w + pad_r - (kw-1)*dilate_w, stride_w)); + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + int inp_off; + if (one_of(jcp.src_tag, ncw, nchw)) + inp_off = ifm2*ih*iw + (ki*dilate_w + jj*stride_w - pad_l); + else + inp_off = (ki*dilate_w + jj*stride_w - pad_l)*ic_blk + ifm2; + + movss(Xmm(oc_blocks * ur_w + jj + 1), + ptr[aux_reg_input + sizeof(float) * inp_off]); + shufps(Xmm(oc_blocks * ur_w + jj + 1), + Xmm(oc_blocks * ur_w + jj + 1), 0x0); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + int ker_off = ii * nb_ic * kh * kw * ic_blk * oc_blk + + ki * ic_blk * oc_blk + ifm2 * oc_blk; + + for (int jj = jj_start; jj < jj_end; jj++) + { + movups(xmm0, + ptr[aux_reg_kernel + sizeof(float) * ker_off]); + mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); + addps(Xmm(ur_w * ii + jj + 1), xmm0); + } + } + } + } +} + +void jit_sse42_conv_fwd_kernel_f32::oh_step_nopad(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + Label kw_loop; + + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int kh = jcp.kh; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + xor_(ki_iter, ki_iter); + L(kw_loop); + { + int jj_start = 0; + int jj_end = ur_w; + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + int inp_off; + if (one_of(jcp.src_tag, ncw, nchw)) + inp_off = ifm2 * ih * iw + (jj * stride_w - pad_l); + else + inp_off = (jj * stride_w - pad_l) * ic_blk + ifm2; + + movss(Xmm(oc_blocks * ur_w + jj + 1), + ptr[aux_reg_input + sizeof(float) * inp_off]); + shufps(Xmm(oc_blocks * ur_w + jj + 1), + Xmm(oc_blocks * ur_w + jj + 1), 0x0); + } + for (int ii = 0; ii < oc_blocks; ii++) { + int aux_kernel_offset = ii * nb_ic * kh * kw * ic_blk * oc_blk + + ifm2 * oc_blk; + for (int jj = jj_start; jj < jj_end; jj++) { + movups(xmm0, + ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]); + mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); + addps(Xmm(ur_w * ii + jj + 1), xmm0); + } + } + } + add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw) ? + dilate_w : ic_blk * dilate_w)); + + inc(ki_iter); + cmp(ki_iter, kw); + jl(kw_loop, T_NEAR); + } +} + +void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ow = jcp.ow; + int oh = jcp.oh; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw) + ? dilate_h : ic_blk * dilate_h; + const int inp_off = one_of(jcp.src_tag, ncw, nchw) + ? dilate_w : ic_blk * dilate_w; + + xor_(simd_iter, simd_iter); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + Label init_simd_iter_loop; + Label init_done; + Label init_first; + + L(init_simd_iter_loop); + + if (!jcp.with_sum) { + test(reg_ci_flag, FLAG_IC_FIRST); + jne(init_first, T_NEAR); + } + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + movups(Xmm(ur_w * ii + jj + 1), xword[reg_output + + sizeof(float) * (ii * oh * ow + jj) * oc_blk]); + + if (jcp.with_sum && jcp.with_bias) { + test(reg_ci_flag, FLAG_IC_FIRST); + je(init_done, T_NEAR); + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + addps(Xmm(ur_w * ii + jj + 1), + xword[reg_bias + sizeof(float) * ii * oc_blk]); + } + + jmp(init_done); + + L(init_first); + if (this->jcp.with_bias) { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + movups(Xmm(ur_w * ii + jj + 1), + xword[reg_bias + sizeof(float) * ii * oc_blk]); + } else { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + pxor(Xmm(ur_w * ii + jj + 1), Xmm(ur_w * ii + jj + 1)); + } + + L(init_done); + + Label skip_kh_loop; + mov(kj, reg_kh); + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + cmp(kj, 0); + je(skip_kh_loop, T_NEAR); + } + Label kh_loop; + L(kh_loop); + { + if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) { + oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks); + sub(aux_reg_input, sizeof(float) * kw * inp_off); + add(aux_reg_input, sizeof(float) * iw * inp_mult); + } else { + oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks); + add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * iw * inp_mult); + } + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + L(skip_kh_loop); + + if (jcp.with_eltwise) { + Label regular_store; + test(reg_ci_flag, FLAG_IC_LAST); + je(regular_store, T_NEAR); + + eltwise_injector_->compute_vector_range(1, oc_blocks * ur_w + 1); + + L(regular_store); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + const size_t o_off = (ii * oh * ow + jj) * oc_blk; + + Xmm reg_out = Xmm(ur_w * ii + jj + 1); + movups(xword[reg_output + sizeof(float) * o_off], reg_out); + } + } + + mov(aux_reg_kernel, reg_kernel); + mov(aux_reg_input, reg_input); + add(aux_reg_kernel, sizeof(float) * 4); + add(reg_output, sizeof(float) * 4); + add(reg_bias, sizeof(float) * 4); + + inc(simd_iter); + cmp(simd_iter, 2); + jl(init_simd_iter_loop, T_NEAR); + + sub(reg_output, sizeof(float) * 8); + sub(reg_bias, sizeof(float) * 8); +} + +inline void jit_sse42_conv_fwd_kernel_f32::solve_common(int oc_blocks) +{ + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int n_oi = jcp.ow / ur_w; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + int dilate_w = jcp.dilate_w + 1; + int str_w = jcp.stride_w; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw) ? 1 : ic_blk; + + int l_pad = jcp.l_pad; + int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1)); + int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) + width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad" + else + width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad" + add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + Label ow_loop; + xor_(oi_iter, oi_iter); + + if (n_oi > 0) { + L(ow_loop); + + width_blk_step(ur_w, 0, 0, oc_blocks); // "middle" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + + if (r_pad1 > 0 && n_oi >=0) { + width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + if (ur_w_tail != 0) + width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail" +} + +void jit_sse42_conv_fwd_kernel_f32::generate() +{ + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); + + int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; + Label tail, exit; + + cmp(reg_oc_blocks, jcp.nb_oc_blocking); + jne(nb_oc_tail ? tail : exit, T_NEAR); + + solve_common(jcp.nb_oc_blocking); + jmp(exit, T_NEAR); + + if (nb_oc_tail) { + L(tail); + cmp(reg_oc_blocks, nb_oc_tail); + jne(exit, T_NEAR); + solve_common(nb_oc_tail); + } + + L(exit); + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(sse42)) return status::unimplemented; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[0]; + jcp.dilate_w = cd.dilates[ndims - 3]; + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c); + } + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + const bool flat = jcp.ic == 3; + const bool mimo = !flat; + + bool args_ok = true + && IMPLICATION(flat, one_of(jcp.src_tag, ncw, nwc, nchw, nhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o)) + && IMPLICATION(mimo, one_of(jcp.src_tag, nCw8c, nChw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c); + if (!args_ok) return status::unimplemented; + + const int simd_w = 8; // 2 SSE vectors processing at once + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.ur_w = 3; + if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */ + + args_ok = true + && jcp.oc % simd_w == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) + || (jcp.stride_w == 1 && jcp.stride_h == 1)) + && IMPLICATION(mimo, jcp.ic % simd_w == 0); + if (!args_ok) return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + + // kernel needs 1 temporary YMM register + const int num_avail_regs = 15; + if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) { + /* recalculate ur_w, nb_oc_blocking and ur_w_tail */ + jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail, + nstl::min(jcp.ow, num_avail_regs / 2)); + jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + /* check again ... */ + r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail)) + return status::unimplemented; + } + assert(jcp.nb_oc_blocking > 0); + assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs); + + jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.nb_ic_blocking = 12; + jcp.nb_ic_blocking_max = 16; + } else { + jcp.nb_ic_blocking = 1; + jcp.nb_ic_blocking_max = jcp.nb_ic_blocking; + } + + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp new file mode 100644 index 0000000000..33c26ef081 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp @@ -0,0 +1,93 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_SSE42_CONV_KERNEL_F32_HPP +#define JIT_SSE42_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_conv_fwd_kernel_f32: public jit_generator { + jit_sse42_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_sse42_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_conv_fwd_kernel_f32) + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t aux_reg_input = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r9; + reg64_t reg_output = rsi; + reg64_t reg_bias = rbx; + + reg64_t kj = r10; + reg64_t oi_iter = r11; + reg64_t ki_iter = r12; + reg64_t reg_kh = abi_not_param1; + reg64_t simd_iter = r15; + reg64_t reg_oc_blocks = r14; + reg64_t imm_addr64 = reg_oc_blocks; + Xbyak::Reg32 reg_ci_flag = r13d; + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, + int oc_blocks); + inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); + inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks); + inline void solve_common(int oc_blocks); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp new file mode 100644 index 0000000000..5f77d692f5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp @@ -0,0 +1,136 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "jit_sse42_convolution.hpp" +#include "mkldnn_thread.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +#define src_blk_off(f, n, c, h, w) \ + (pd()->ndims() == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w) + +#define wht_blk_off_(f, g, ...) \ + pd()->with_groups() \ + ? (f).blk_off(g, __VA_ARGS__) \ + : (f).blk_off(__VA_ARGS__) +#define wht_blk_off(f, g, oc, ic, kh, kw) \ + pd()->ndims() == 3 \ + ? wht_blk_off_(f, g, oc, ic, kw) \ + : wht_blk_off_(f, g, oc, ic, kh, kw) + +void jit_sse42_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); + const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh; + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + int icbb = 0; + while (icbb < jcp.nb_ic) { + int icb_step = jcp.nb_ic_blocking; + int icb_step_rem = jcp.nb_ic - icbb; + if (icb_step_rem < jcp.nb_ic_blocking_max) + icb_step = icb_step_rem; + + size_t n{0}, g{0}, ocbb{0}, oh{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + oh, jcp.oh); + for (size_t iwork = start; iwork < end; ++iwork) { + int ocb = ocbb * jcp.nb_oc_blocking; + int ocb_num = jcp.nb_oc_blocking; + + for (int icb = icbb; icb < icbb + icb_step; ++icb) { + auto par_conv = jit_conv_call_s(); + + const int ij = oh * jcp.stride_h; + const int i_t_overflow = nstl::max(0, jcp.t_pad - ij); + const int i_b_overflow = nstl::max(jcp.ih, ij + + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih; + + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic + icb; + + const int ih = nstl::max(ij - jcp.t_pad + + div_up(i_t_overflow, + (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0); + par_conv.src = &src[src_blk_off(src_d, n, + jcp.ic == 3 ? 0 : _ic, ih, 0)]; + + par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, oh, 0)]; + + const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); + par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb, + jcp.ic == 3 ? 0 : icb, wh, 0)]; + + if (icb == 0) { + if (bias) + par_conv.bias = + &bias[bias_d.blk_off(_oc * jcp.oc_block)]; + par_conv.flags |= FLAG_IC_FIRST; + } + + if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) { + par_conv.flags |= FLAG_IC_LAST; + } + + par_conv.oc_blocks = + nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb; + + par_conv.kw_padding = 0; + const int kh_padding = jcp.kh + - div_up(i_t_overflow, (jcp.dilate_h + 1)) + - div_up(i_b_overflow, (jcp.dilate_h + 1)); + par_conv.kh_padding = nstl::max(0, kh_padding); + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + oh, jcp.oh); + } + icbb += icb_step; + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp new file mode 100644 index 0000000000..d2f0a38c5c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_SSE42_CONVOLUTION_HPP +#define CPU_JIT_SSE42_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_sse42_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", sse42, ""), + jit_sse42_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + return jit_sse42_conv_fwd_kernel_f32::init_conf(jcp_, *desc(), + *src_md(), *weights_md(), *dst_md(), *attr()); + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + const bool flat = IC() == 3; + auto src_tag = flat + ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + }; + + jit_sse42_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_sse42_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); } + ~jit_sse42_convolution_fwd_t() { delete kernel_; }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_sse42_conv_fwd_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp new file mode 100644 index 0000000000..0e734f7265 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp @@ -0,0 +1,1192 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "jit_generator.hpp" +#include "cpu_barrier.hpp" + +#include "jit_transpose_src_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +#define GET_OFF(x) offsetof(ctx_t, x) + +struct jit_trans_iw_ic_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_t) + + jit_trans_iw_ic_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + + enum { typesize = sizeof(float), transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t k3333 = k1; + opmask_t k5555 = k2; + opmask_t kAAAA = k3; + opmask_t kCCCC = k4; + opmask_t k0F0F = k5; + opmask_t kF0F0 = k6; + opmask_t kTail = k7; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_iw_ic_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto pf_src_t0 = [=](int i) { + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_src, + (transpose_size + i) * src_stride)); + }; + + auto pf_tr_src_t0 = [=](int i) { + int offset = (transpose_size) * typesize + i * tr_src_stride; + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, offset)); + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, + offset + 64)); + }; + + auto pf_src_t1 = [=](int i) { + if(enable_prefetch) prefetcht1(EVEX_compress_addr(reg_src_prf, + i * src_stride)); + }; + + auto pf_tr_src_t1 = [=](int i) { + if(enable_prefetch) prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, + i * tr_src_stride)); + }; + + auto src_zmm = [=](int i) { + assert(i >= 0 && i < 16); + return Zmm(i); + }; + + auto tmp_zmm = [=](int i) { + assert(i >= 0 && i < 16); + return Zmm(16 + i); + }; + + auto load = [=](int i) { + vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto store = [=](Zmm r, int i) { + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto padding = [=] (Reg64 reg, int pad) { + kmovw(kTail, (1 << pad) - 1); + auto k = kTail; + auto base = reg; + base.setOpmaskIdx(k.getIdx(), true); + + auto zmm_zero = r; + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, zmm_zero); + }; + + mov(reg_tr_src_tmp, reg_tr_src); + if (l_pad > 0) + add(reg_tr_src_tmp, l_pad * typesize); + + if (tail != transpose_size) + kmovw(kTail, (1 << tail) - 1); + + // Xbyak does not allow k0 to be specified explicitly via the '|' + // operator, so we have to do this via a method call (implicitly + // EVEX encoding uses k0 to mean 'no mask') + bool partial_store = nrows < 16; + auto k = partial_store ? kTail : k0; + auto base = reg_tr_src_tmp; + base.setOpmaskIdx(k.getIdx(), true); + + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + if (nontemporal_stores && !partial_store) + vmovntps(addr, r); + else + vmovups(addr, r); + + if (r_pad > 0) { + add(reg_tr_src_tmp, tail * typesize); + padding(reg_tr_src_tmp, r_pad); + } + + if (l_pad > 0) { + padding(reg_tr_src, l_pad); + } + }; + + auto transpose16x8 = [=](int base_idx) { + assert(base_idx == 0 || base_idx == 8); + + // swap 1 + for (int i = 0; i < 4; i++) { + int src_idx0 = base_idx + i * 2; + int src_idx1 = src_idx0 + 1; + + int next_src_idx0 = src_idx0 + 2; + int next_src_idx1 = src_idx1 + 2; + bool load_next = base_idx == 0 || i < 3; + + if (base_idx == 0 && i == 0) { + load(src_idx0); + load(src_idx1); + } + + auto tmp0 = tmp_zmm(src_idx0); + auto tmp1 = tmp_zmm(src_idx1); + auto src0 = src_zmm(src_idx0); + auto src1 = src_zmm(src_idx1); + + if (next_src_idx0 < nrows && load_next) + load(next_src_idx0); + valignd(tmp0, src0, src0, 0x1); + pf_src_t1(base_idx + i); + + if (next_src_idx1 < nrows && load_next) + load(next_src_idx1); + valignd(tmp1, src1, src1, 0xf); + pf_src_t0(base_idx + i); + + vmovaps(src0 | kAAAA, tmp1); + vmovaps(src1 | k5555, tmp0); + } + // swap 2 + for (int i = 0; i < 4; i++) { + int select_half = (i < 2) ? 0 : 2; + int src_idx0 = base_idx + i + select_half + 0; + int src_idx2 = src_idx0 + 2; + + auto tmp0 = tmp_zmm(src_idx0); + auto tmp1 = tmp_zmm(src_idx2); + auto src0 = src_zmm(src_idx0); + auto src2 = src_zmm(src_idx2); + + valignd(tmp0, src0, src0, 0x2); + pf_src_t1(base_idx + 4 + i); + valignd(tmp1, src2, src2, 0xe); + pf_src_t0(base_idx + 4 + i); + vmovaps(src2 | k3333, tmp0); + vmovaps(src0 | kCCCC, tmp1); + } + + // swap 4 + for (int i = 0; i < 4; i++) { + int src_idx0 = base_idx + i; + int src_idx4 = src_idx0 + 4; + + auto tmp0 = tmp_zmm(src_idx0); + auto src0 = src_zmm(src_idx0); + auto src4 = src_zmm(src_idx4); + + vmovaps(tmp0, src0); + vshuff32x4(src0 | kF0F0, src4, src4, 0xb1); + pf_tr_src_t1(base_idx / 2 + i); + vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1); + pf_tr_src_t0(base_idx / 2 + i); + } + }; + + auto fixup16x16 = [=]() { + // swap 8 + for (int i = 0; i < 8; i++) { + auto tmp = tmp_zmm(i); + auto src0 = src_zmm(i); + auto src8 = src_zmm(8 + i); + vshuff64x2(tmp, src0, src8, 0x44); + store(tmp, i); + if (i % 2 == 0) { + pf_tr_src_t1(8 + i / 2); + pf_tr_src_t0(8 + i / 2); + } + } + + for (int i = 0; i < 8; i++) { + auto tmp = tmp_zmm(8 + i); + auto src0 = src_zmm(i); + auto src8 = src_zmm(8 + i); + vshuff64x2(tmp, src0, src8, 0xee); + store(tmp, 8 + i); + if (i % 2 == 0) { + pf_tr_src_t1(12 + i / 2); + pf_tr_src_t0(12 + i / 2); + } + } + }; + + transpose16x8(0); + transpose16x8(8); + fixup16x16(); +} + +void jit_trans_iw_ic_t::generate() { + preamble(); + + const int ic_block = conf_->ic_block; + const int iw = conf_->iw; + const int tr_iw = conf_->tr_iw; + const int transposes = utils::div_up(iw, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = iw - loop_iters * transpose_size; + + src_stride = ic_block * typesize; + assert(src_stride == 64); + tr_src_stride = tr_iw * typesize; + + bool nontemporal_stores = false; + enable_prefetch = iw > small_spatial ? 1 : 0; + + assert(transpose_size == ic_block); + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * typesize; + + const int left_pad = conf_->l_pad; + const int right_pad = tr_iw - iw - left_pad; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(k3333, 0x3333); // 0011001100110011 + kmovw(k5555, 0x5555); // 0101010101010101 + kmovw(kAAAA, 0xaaaa); // 1010101010101010 + kmovw(kCCCC, 0xcccc); // 1100110011001100 + kmovw(k0F0F, 0x0f0f); // 0000111100001111 + kmovw(kF0F0, 0xf0f0); // 1111000011110000 + + if (left_pad > 0 && loop_iters > 0) { + loop_iters--; + transpose(transpose_size, left_pad, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step + left_pad * typesize); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step + left_pad * typesize); + } + + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + if (transposes > 1) + transpose(tail, 0, right_pad, nontemporal_stores); + else + transpose(tail, left_pad, right_pad, nontemporal_stores); + + postamble(); +} + +struct jit_trans_iw_ic_int16_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_int16_t) + jit_trans_iw_ic_int16_t(const jit_conv_conf_t *conf): + jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + + enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t kFFFF = k1; + opmask_t k5555 = k2; + opmask_t kAAAA = k3; + opmask_t kAA = k4; + opmask_t k55 = k5; + opmask_t kCC = k6; + opmask_t k33 = k7; + opmask_t kTail = k1; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + reg64_t imm_addr64 = rbx; + + Xbyak::Zmm vidx1 = zmm31; + Xbyak::Zmm vidx2 = zmm30; + Xbyak::Zmm vidx3 = zmm29; + Xbyak::Zmm vidx4 = zmm28; + Xbyak::Zmm vidx5 = zmm27; + Xbyak::Zmm zmm_tmp = zmm26; + + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_iw_ic_int16_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto src_zmm = [=](int i) { + return Zmm(i); + }; + + auto src_ymm = [=](int i) { + assert(i >= 0 && i < 16); + return Ymm(i); + }; + + auto load_ymm = [=](int i) { + vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto store = [=](Zmm r, int i) { + + auto padding = [=] (Reg64 reg, int pad) { + kmovw(kTail, (1 << pad) - 1); + auto k = kTail; + auto base = reg; + base.setOpmaskIdx(k.getIdx(), true); + + auto zmm_zero = zmm_tmp; + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, zmm_zero); + }; + + int store_tail = (nrows%2) ? nrows+1 : nrows; + + int store_pad = (l_pad%2) ? l_pad/2 + 1 : l_pad/2; + mov(reg_tr_src_tmp, reg_tr_src); + if (l_pad > 0) { + padding(reg_tr_src, store_pad); + add(reg_tr_src_tmp, l_pad * typesize); + } + if (r_pad > 0) { + store_pad = (r_pad%2) ? r_pad/2 + 1 : r_pad/2; + int addr_shift = (r_pad%2) ? 1 : 0; + add(reg_tr_src_tmp, (nrows - addr_shift) * typesize); + padding(reg_tr_src_tmp, store_pad); + } + + mov(reg_tr_src_tmp, reg_tr_src); + add(reg_tr_src_tmp, l_pad * typesize); + + kmovw(kTail, (1 << store_tail/2) - 1); + auto k = kTail; + auto base = reg_tr_src_tmp; + base.setOpmaskIdx(k.getIdx(), true); + + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, r); + + }; + + kmovw(kFFFF, 0xffff); + //all loads + for (int i=0; i<16; i++){ + vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); + } + + for (int i = 0; i < nrows/2; i++) { + auto src0 = src_ymm(2*i); + auto src1 = src_ymm(2*i+1); + auto zmm_src0 = src_zmm(2*i); + load_ymm(2*i); + + vpunpcklwd(src1, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vpunpckhwd(src0, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vinserti64x4(zmm_src0, zmm_src0, src1, 1); + vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); + } + + // for odd numbers we need to mix row with zeroes + if (nrows%2) { + int i = nrows-1; + auto src0 = src_ymm(i); + auto src1 = src_ymm(i+1); //zero + + auto zmm_src0 = src_zmm(i); + vpxor(src1, src1, src1); + + load_ymm(i); + vpunpckhwd(src0, src0, src1); + vinserti64x4(zmm_tmp, zmm_tmp, src0, 0); + vpxor(src0, src0, src0); + load_ymm(i); + vpunpcklwd(src1, src0, src1); + vinserti64x4(zmm_tmp, zmm_tmp, src1, 1); + vpxord(zmm_src0, zmm_src0, zmm_src0); + vmovups(zmm_src0, zmm_tmp); + vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); + } + + // swap 1 + for (int i=0; i<4; i++) { + auto zmm0 = src_zmm(4*i); + auto zmm1 = src_zmm(4*i+2); + auto tmp0 = src_zmm(4*i+1); + auto tmp1 = src_zmm(4*i+3); + + vmovups(tmp0, zmm0); + vmovups(tmp1, zmm1); + + vpermps(tmp0 | kAAAA, vidx3, zmm1); + vpermps(tmp1 | k5555, vidx3, zmm0); + } + // swap 2 + int base_idx; + base_idx=0; + for (int i=0; i<2; i++) { + auto zmm0 = src_zmm(base_idx+2*i+1); + auto zmm1 = src_zmm(base_idx+2*i+5); + + auto tmp0 = src_zmm(base_idx+2*i); + auto tmp1 = src_zmm(base_idx+2*i+4); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kAA, vidx2, zmm1); + vpermpd(tmp1 | k55, vidx2, zmm0); + } + base_idx=8; + for (int i=0; i<2; i++) { + auto zmm0 = src_zmm(base_idx+2*i+1); + auto zmm1 = src_zmm(base_idx+2*i+5); + + auto tmp0 = src_zmm(base_idx+2*i); + auto tmp1 = src_zmm(base_idx+2*i+4); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kAA, vidx2, zmm1); + vpermpd(tmp1 | k55, vidx2, zmm0); + } + + // swap 3 + for (int i=0; i<4; i++) { + auto zmm0 = src_zmm(2*i); + auto zmm1 = src_zmm(2*i+8); + + auto tmp0 = src_zmm(2*i+1); + auto tmp1 = src_zmm(2*i+9); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kCC, vidx1, zmm1); + vpermpd(tmp1 | k33, vidx1, zmm0); + } + + // all stores + for (int i=0; i<8; i++) + vextracti64x4(src_ymm(2*i), src_zmm(2*i+1), 1); + + store(src_zmm(1), 0); + store(src_zmm(0), 1); + store(src_zmm(3), 2); + store(src_zmm(2), 3); + store(src_zmm(9), 4); + store(src_zmm(8), 5); + store(src_zmm(11), 6); + store(src_zmm(10), 7); + store(src_zmm(5), 8); + store(src_zmm(4), 9); + store(src_zmm(7), 10); + store(src_zmm(6), 11); + store(src_zmm(13), 12); + store(src_zmm(12), 13); + store(src_zmm(15), 14); + store(src_zmm(14), 15); + +} + +void jit_trans_iw_ic_int16_t::generate() { + preamble(); + + alignas(64) static constexpr const int64_t idx1[8] + = { 2, 3, 0, 1, 6, 7, 4, 5 }; + alignas(64) static constexpr const int64_t idx2[8] + = { 1, 0, 3, 2, 5, 4, 7, 6 }; + alignas(64) static constexpr const int32_t idx3[16] + = { 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14 }; + alignas(64) static constexpr const int32_t idx4[16] + = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 }; + alignas(64) static constexpr const int32_t idx5[16] + = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 }; + + const int ic_block = conf_->ic_block; + const int iw = conf_->iw; + const int tr_iw = conf_->tr_iw; + const int transposes = utils::div_up(iw, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = iw - loop_iters * transpose_size; + + src_stride = ic_block * typesize; + tr_src_stride = tr_iw * typesize; + + bool nontemporal_stores = false; + enable_prefetch = iw > small_spatial ? 1 : 0; + + assert(transpose_size == ic_block); + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * typesize; + + const int left_pad = conf_->l_pad; + const int right_pad = tr_iw - iw - left_pad; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(kFFFF, 0xffff); + kmovw(k5555, 0x5555); + kmovw(kAAAA, 0xaaaa); + kmovw(kAA, 0xaa); + kmovw(k55, 0x55); + kmovw(kCC, 0xcc); + kmovw(k33, 0x33); + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa32(z, ptr[imm_addr64]); + }; + + vmovdqa64(vidx1, idx1); + vmovdqa64(vidx2, idx2); + vmovdqa32(vidx3, idx3); + vmovdqa32(vidx4, idx4); + vmovdqa32(vidx5, idx5); + + if (left_pad > 0 && loop_iters > 0) { + loop_iters--; + transpose(transpose_size, left_pad, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step + left_pad * typesize); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step + left_pad * typesize); + } + + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + if (transposes > 1) + transpose(tail, 0, right_pad, nontemporal_stores); + else + transpose(tail, left_pad, right_pad, nontemporal_stores); + + postamble(); + +} + +struct jit_trans_ow_oc_t: public jit_trans_dst_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_ow_oc_t) + jit_trans_ow_oc_t(const jit_conv_conf_t *conf): jit_trans_dst_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + using zmm = const Xbyak::Zmm; + + enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t kFF = k1; + + zmm vidx1 = zmm31; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + reg64_t imm_addr64 = rbx; + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_ow_oc_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto src_zmm = [=](int i) { + return Zmm(i); + }; + + auto src_ymm = [=](int i) { + assert(i >= 0 && i < 16); + return Ymm(i); + }; + + auto load_ymm = [=](int i) { + vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + + auto store = [=](Zmm r, int i) { + auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride); + if (nontemporal_stores) + vmovntps(addr, r); + else + vmovups(addr, r); + }; + + for (int i = 0; i < nrows/2; i++) { + auto src0 = src_ymm(2*i); + auto src1 = src_ymm(2*i+1); + auto zmm_src0 = src_zmm(2*i); + load_ymm(2*i); + vpunpcklwd(src1, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vpunpckhwd(src0, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vinserti64x4(zmm_src0, zmm_src0, src1, 1); + vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); + store(zmm_src0, 2*i); + } + if (r_pad > 0) { + auto src0 = src_ymm(nrows-1); + auto src1 = src_ymm(nrows); + auto zmm_src0 = src_zmm(30); + load_ymm(nrows-1); + + vpxor(src1, src1, src1); + vpunpckhwd(src1, src0, src1); + vinserti64x4(zmm_src0, zmm_src0, src1, 0); + vpxor(src1, src1, src1); + vpunpcklwd(src0, src0, src1); + vinserti64x4(zmm_src0, zmm_src0, src0, 1); + vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); + store(zmm_src0, nrows-1); + } +} + +void jit_trans_ow_oc_t::generate() { + preamble(); + + alignas(64) static constexpr const int64_t idx1[8] + = { 4, 5, 0, 1, 6, 7, 2, 3 }; + + const int oc_block = conf_->oc_block; + const int ow = conf_->ow; + const int transposes = utils::div_up(ow, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = ow - loop_iters * transpose_size; + + src_stride = oc_block * typesize; + tr_src_stride = oc_block * typesize; + + bool nontemporal_stores = false; + enable_prefetch = ow > small_spatial ? 1 : 0; + + const int src_step = oc_block * transpose_size * typesize; + const int tr_src_step = oc_block * transpose_size * typesize; + const int right_pad = ow % 2; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(kFF, 0xFF); + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + vmovdqa64(vidx1, idx1); + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + transpose(tail, 0, right_pad, nontemporal_stores); + + postamble(); +} + +struct jit_trans_iw_x4_4x_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_x4_4x_t) + + jit_trans_iw_x4_4x_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + + void generate(); + enum { typesize = (int)sizeof(float) }; +}; + +/** @brief transposition of the form [:][iw/4][4] -> [:][4][iw/4] + * required for 1st 4fma backward by weights convolution */ +void jit_trans_iw_x4_4x_t::generate() { + using namespace utils; + + /* TODO: put into code */ + static int mask[16] = { + 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, }; + + const auto &c = *conf_; + const int simd_w = cpu_isa_traits::vlen / typesize; + const int niters = c.tr_ld / simd_w; + + assert(niters <= 4); /* [bwd_w:tr_src:r1] */ + + Reg64 reg_ptr_src = r8; + Reg64 reg_ptr_tr_src = r9; + + Reg64 reg_ih = rax; + Reg64 reg_ih_end = rbx; + + Reg64 reg_nthr_oc_b = rsi; + Reg64 reg_ptr_tr_src_bctx = abi_not_param1; + + Reg64 reg_tmp = rdx; + + Zmm vmsk = Zmm(31); + Opmask kmsk = k7; + + auto emit_tr_sync = [&]() { + simple_barrier::generate(*this, reg_ptr_tr_src_bctx, reg_nthr_oc_b); + }; + + auto emit_tr_iw = [&]() { + auto vreg = [](int iter, int i) { + assert(4 * iter + i < 24); + return Zmm(4 * iter + i); + }; + auto vtmp = [](int i) { return Zmm(24 + i); }; + + auto emit_load = [&](int iter) { + for (int i = 0; i < 4; ++i) { + auto v = vreg(iter, i); + const int off = (iter * 4 + i) * simd_w; + + if (off + simd_w <= c.iw) + vmovups(v, ptr[reg_ptr_src + off * typesize]); + else if (off < c.iw) + vmovups(v | kmsk | T_z, ptr[reg_ptr_src + off * typesize]); + else + vpxord(v, v, v); + } + }; + + auto emit_tr = [&](int iter) { + for (int i = 0; i < 4; ++i) + vpermps(vreg(iter, i), vmsk, vreg(iter, i)); + + vshuff32x4(vtmp(0), vreg(iter, 0), vreg(iter, 1), 0x88); + vshuff32x4(vtmp(1), vreg(iter, 0), vreg(iter, 1), 0xdd); + vshuff32x4(vtmp(2), vreg(iter, 2), vreg(iter, 3), 0x88); + vshuff32x4(vtmp(3), vreg(iter, 2), vreg(iter, 3), 0xdd); + + vshuff32x4(vreg(iter, 0), vtmp(0), vtmp(2), 0x88); + vshuff32x4(vreg(iter, 2), vtmp(0), vtmp(2), 0xdd); + vshuff32x4(vreg(iter, 1), vtmp(1), vtmp(3), 0x88); + vshuff32x4(vreg(iter, 3), vtmp(1), vtmp(3), 0xdd); + }; + + auto emit_store = [&]() { + for (int i = 0; i < 4; ++i) { + for (int iter = 0; iter < niters; ++iter) { + const size_t off = i * c.tr_ld + iter * simd_w; + vmovups(ptr[reg_ptr_tr_src + off * typesize], vreg(iter, i)); + } + } + }; + + for (int iter = 0; iter < niters; ++iter) + emit_load(iter); + + for (int iter = 0; iter < niters; ++iter) + emit_tr(iter); + + emit_store(); + }; + + preamble(); + + mov(reg_ptr_src, ptr[abi_param1 + GET_OFF(src)]); + mov(reg_ptr_tr_src, ptr[abi_param1 + GET_OFF(tr_src)]); + + mov(reg_nthr_oc_b.cvt32(), ptr[abi_param1 + GET_OFF(nthr_oc_b)]); + mov(reg_ih.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_start)]); + mov(reg_ih_end.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_end)]); + mov(reg_ptr_tr_src_bctx, ptr[abi_param1 + GET_OFF(tr_src_bctx)]); + + emit_tr_sync(); + + Label l_ih_loop, l_tr_done; + cmp(reg_ih, reg_ih_end); + je(l_tr_done, T_NEAR); + + mov(reg_tmp, (size_t)&mask[0]); + vmovups(vmsk, ptr[reg_tmp]); + + if (c.iw % simd_w) { + const char load_mask = (1 << (c.iw % simd_w)) - 1; + mov(reg_tmp, load_mask); + kmovw(kmsk, reg_tmp.cvt32()); + } + + /* src += ih_start * c.iw; */ + imul(reg_tmp, reg_ih, c.iw * typesize); + add(reg_ptr_src, reg_tmp); + /* tr_src += ih_start * c.stride_w * c.tr_ld; */ + imul(reg_tmp, reg_ih, c.stride_w * c.tr_ld * typesize); + add(reg_ptr_tr_src, reg_tmp); + + L(l_ih_loop); { + emit_tr_iw(); + + add(reg_ptr_src, c.iw * typesize); + add(reg_ptr_tr_src, c.stride_w * c.tr_ld * typesize); + + inc(reg_ih); + cmp(reg_ih, reg_ih_end); + jl(l_ih_loop, T_NEAR); + } + + L(l_tr_done); + + emit_tr_sync(); + + postamble(); +} + +/* +// ------------------------------------------------- +// jit_transpose4x16_src +// ------------------------------------------------- +*/ + +void jit_transpose4x16_src::transpose(int nrows) +{ + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 4, "Unsupported transpose size"); + if (!nrows) + return; + + auto pf_src_t0 = [=](int i) { + if (tparams->src_pf0_distance) + prefetcht0(EVEX_compress_addr( + reg_src, (tparams->src_pf0_distance + i) * src_stride)); + }; + + auto pf_tr_src_t0 = [=](int i) { + if (tparams->tr_src_pf0_distance) + prefetcht0(EVEX_compress_addr(reg_tr_src, + (tparams->tr_src_pf0_distance + i) * src_stride)); + }; + + auto pf_src_t1 = [=](int i) { + if (tparams->src_pf1) + prefetcht1(EVEX_compress_addr(reg_src_prf, i * src_stride)); + }; + + auto pf_tr_src_t1 = [=](int i) { + if (tparams->tr_src_pf1) + prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, i * tr_src_stride)); + }; + + auto src_zmm = [=](int i) { + assert(i >= 0 && i < 4); + return Zmm(i); + }; + + auto tmp_zmm = [=](int i) { + assert(i >= 0 && i < 4); + return Zmm(4 + i); + }; + + auto load = [=](int i) { + vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto store = [=](Zmm r, int i) { + vmovups(EVEX_compress_addr(reg_tr_src, i * tr_src_stride), r); + }; + + auto tmp0 = tmp_zmm(0); + auto tmp1 = tmp_zmm(1); + auto tmp2 = tmp_zmm(2); + auto tmp3 = tmp_zmm(3); + + auto src0 = src_zmm(0); + auto src1 = src_zmm(1); + auto src2 = src_zmm(2); + auto src3 = src_zmm(3); + for (int i = 0; i < nrows; i++) { + load(i); + } + + for (size_t i = nrows; i < 4; i++) { + vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); + } + + vmovupd(tmp0, src0); + vmovupd(tmp1, src1); + pf_src_t0(0); + vpermpd(tmp0 | kF0, vidx01, src2); + vpermpd(tmp1 | kF0, vidx01, src3); + + valignd(src0, src0, src0, 8); + valignd(src1, src1, src1, 8); + pf_src_t0(1); + vmovupd(tmp2, src0); + vmovupd(tmp3, src1); + pf_src_t0(2); + vpermpd(tmp2 | kF0, vidx10, src2); + vpermpd(tmp3 | kF0, vidx10, src3); + pf_src_t0(3); + + vmovupd(src0, tmp0); + pf_src_t1(0); + vmovupd(src1, tmp2); + pf_src_t1(1); + vmovupd(src2, tmp1); + pf_src_t1(2); + vmovupd(src3, tmp3); + pf_src_t1(3); + vpermpd(src0 | kCC, vidx1, tmp1); + vpermpd(src1 | kCC, vidx1, tmp3); + pf_tr_src_t0(0); + vpermpd(src2 | k33, vidx1, tmp0); + vpermpd(src3 | k33, vidx1, tmp2); + pf_tr_src_t0(1); + + vmovupd(tmp0, src0); + vmovupd(tmp1, src2); + pf_tr_src_t0(2); + vmovupd(tmp2, src1); + vmovupd(tmp3, src3); + pf_tr_src_t0(3); + vpermps(tmp0 | kFFFF, vidxP, src0); + pf_tr_src_t1(0); + vpermps(tmp1 | kFFFF, vidxP, src2); + pf_tr_src_t1(1); + vpermps(tmp2 | kFFFF, vidxP, src1); + pf_tr_src_t1(3); + vpermps(tmp3 | kFFFF, vidxP, src3); + pf_tr_src_t1(4); + + store(tmp0, 0); + store(tmp1, 1); + store(tmp2, 2); + store(tmp3, 3); +} + +alignas(64) static constexpr const int64_t idx01[8] + = { 0, 0, 0, 0, 0, 1, 2, 3 }; +alignas(64) static constexpr const int64_t idx10[8] + = { 0, 0, 0, 0, 4, 5, 6, 7 }; +alignas(64) static constexpr const int64_t idx1[8] = { 2, 3, 0, 1, 6, 7, 4, 5 }; +alignas(64) static constexpr const int32_t idxP[16] + = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + +void jit_transpose4x16_src::generate() +{ + preamble(); + + const int ic_block = params->ic_block; + const int is = params->is; + int tail = is % transpose_size; + + src_stride = ic_block * typesize; + assert(src_stride == 64); + tr_src_stride = ic_block * typesize; + + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * transpose_size * typesize; + +#define GET_TR_OFF(x) offsetof(jit_src_transpose_s, x) + mov(reg_loop, ptr[param1 + GET_TR_OFF(size)]); + mov(reg_src, ptr[param1 + GET_TR_OFF(src)]); + mov(reg_tr_src, ptr[param1 + GET_TR_OFF(tr_src)]); + mov(reg_src_prf, ptr[param1 + GET_TR_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr[param1 + GET_TR_OFF(tr_src_prf)]); +#undef GET_TR_OFF + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa32(z, ptr[imm_addr64]); + }; + + kmovw(kF0, 0xf0); // 11110000 + kmovw(kCC, 0xcc); // 11001100 + kmovw(k33, 0x33); // 00110011 + kmovw(kFFFF, 0xffff); // 1111111111111111 + + vmovdqa64(vidx01, idx01); + vmovdqa64(vidx10, idx10); + vmovdqa64(vidx1, idx1); + vmovdqa32(vidxP, idxP); + + Label loop_label; + Label tail_label; + + cmp(reg_loop, transpose_size); + jl(tail_label, T_NEAR); + + L(loop_label); + { + transpose(transpose_size); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, transpose_size); + cmp(reg_loop, transpose_size); + jge(loop_label, T_NEAR); + } + L(tail_label); + transpose(tail); + + postamble(); +} + +jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf) { + if (conf->ver == ver_4fma && !conf->is_1stconv) + return new jit_trans_iw_ic_t(conf); + if (conf->ver == ver_4fma && conf->is_1stconv) + return new jit_trans_iw_x4_4x_t(conf); + assert(!"unsupported configuration"); + return nullptr; +} + +jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf) { + assert(!"unsupported configuration"); + return nullptr; +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp new file mode 100644 index 0000000000..565e97e4fc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp @@ -0,0 +1,145 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_TRANSPOSE_SRC_HPP +#define CPU_JIT_TRANSPOSE_SRC_HPP + +#include "cpu_barrier.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_trans_src_t { + struct ctx_t { + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; + + /* 1st conv 4fma: backward by weights */ + int nthr_oc_b; /* number of threads process given src image */ + int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ + simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ + }; + + jit_trans_src_t(const jit_conv_conf_t *conf) + : conf_(conf), ker_(nullptr) {} + virtual ~jit_trans_src_t() {} + + void operator()(const ctx_t *ctx) + { assert(ker_); ker_(ctx); } + + const jit_conv_conf_t *conf_; + void (*ker_)(const ctx_t *); +}; + +struct jit_src_transpose_s { + size_t size; + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; +}; + +struct jit_trans_dst_t { + struct ctx_t { + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; + + /* 1st conv 4fma: backward by weights */ + int nthr_oc_b; /* number of threads process given src image */ + int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ + simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ + }; + + jit_trans_dst_t(const jit_conv_conf_t *conf) + : conf_(conf), ker_(nullptr) {} + virtual ~jit_trans_dst_t() {} + + void operator()(const ctx_t *ctx) + { assert(ker_); ker_(ctx); } + + const jit_conv_conf_t *conf_; + void (*ker_)(const ctx_t *); +}; + +struct jit_transpose4x16_src_t { + int src_pf0_distance; + int tr_src_pf0_distance; + bool src_pf1; + bool tr_src_pf1; +}; + +struct jit_transpose4x16_src : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src) + + jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams, + jit_transpose4x16_src_t *tparams_) + : params(aparams), tparams(tparams_) + { + this->generate(); + jit_ker = (decltype(jit_ker))this->getCode(); + } + + const jit_1x1_conv_conf_t *params; + const jit_transpose4x16_src_t *tparams; + void (*jit_ker)(jit_src_transpose_s *); + + void operator()(jit_src_transpose_s *arg) { jit_ker(arg); } + + static const int transpose_size = 4; +private: + static const int typesize = sizeof(float); + + int src_stride, tr_src_stride; + + Xbyak::Reg64 imm_addr64 = rbx; + + Xbyak::Opmask kF0 = k1; + Xbyak::Opmask kCC = k2; + Xbyak::Opmask k33 = k3; + Xbyak::Opmask kFFFF = k4; + + Xbyak::Zmm vidx01 = zmm31; + Xbyak::Zmm vidx10 = zmm30; + Xbyak::Zmm vidx1 = zmm29; + Xbyak::Zmm vidxP = zmm28; + + Xbyak::Reg64 reg_src = r8; + Xbyak::Reg64 reg_tr_src = r9; + Xbyak::Reg64 reg_src_prf = r10; + Xbyak::Reg64 reg_tr_src_prf = r11; + Xbyak::Reg64 reg_loop = r12; + Xbyak::Reg64 reg_tr_src_tmp = r13; + Xbyak::Reg32 regw_tmp = r14d; + + void transpose_block(int ur, int nrows); + void transpose(int nrows); + void generate(); +}; + +jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf); +jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp new file mode 100644 index 0000000000..53313f9f01 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp @@ -0,0 +1,327 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_1x1_CONV_UTILS_HPP +#define JIT_UNI_1x1_CONV_UTILS_HPP + +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; + +struct reduce_to_unit_stride_t { + convolution_desc_t conv_d_; + bool reduce_src_; + size_t space_per_thread_; +}; + +/* 1x1-kernel does not support non-unit strides so far, so the idea is: + * - for fwd or bwd_weights: to copy src to a scratch memory (with strides + * equal to 1) and then call the kernel + * - for bwd_data: reduce the problem to the one with unit stride by + * performing computations in a scratch memory (with strides equal to 1) + * and then copy the result to diff_src */ +template +inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d, + const memory_desc_t *&src_d, const memory_desc_t *dst_d) { + const bool is_bwd_data = self->desc()->prop_kind + == prop_kind::backward_data; + + const int ndims = src_d->ndims; + const auto dat_tag = ndims == 3 + ? memory_desc_wrapper(dst_d).matches_one_of_tag( + format_tag::nCw8c, format_tag::nCw16c) + : memory_desc_wrapper(dst_d).matches_one_of_tag( + format_tag::nChw8c, format_tag::nChw16c); + + bool rtus_applicable = true + && utils::pick(ndims - 3, + (conv_d->strides[0] != 1 && !one_of(conv_d->src_desc.data_type, + data_type::s32)), + (conv_d->strides[0] != 1 || conv_d->strides[1] != 1)) + && dat_tag != format_tag::undef; + for (int d = 2; d < ndims; ++d) { + /* TODO: relax these conditions (by improving reducer) */ + rtus_applicable = rtus_applicable + && conv_d->padding[0][d - 2] == 0 + && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d]; + } + + if (rtus_applicable) { + self->rtus_.reduce_src_ = true; + conv_d = &(self->rtus_.conv_d_ = *conv_d); + self->rtus_.conv_d_.strides[0] = 1; + if (ndims == 4) + self->rtus_.conv_d_.strides[1] = 1; + utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2); + if (ndims == 4) + utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2); + const int ic = src_d->dims[1]; + if (is_bwd_data) { + src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d); + self->rtus_.conv_d_.diff_src_desc.dims[1] = ic; + memory_desc_wrapper::compute_blocking( + self->rtus_.conv_d_.diff_src_desc, dat_tag); + } else { + data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type; + src_d = &(self->rtus_.conv_d_.src_desc = *dst_d); + self->rtus_.conv_d_.src_desc.dims[1] = ic; + self->rtus_.conv_d_.src_desc.data_type = data_type; + memory_desc_wrapper::compute_blocking( + self->rtus_.conv_d_.src_desc, dat_tag); + } + } +} + +template +inline void rtus_prepare_space_info(conv_pd_t *self, + memory_tracking::registrar_t &scratchpad) { + const auto &jcp = self->jcp_; + + const int max_threads = mkldnn_get_max_threads(); + const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind, + jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking); + size_t typesize = types::data_type_size( + conv_prop_invariant_src_d(self->desc())->data_type); + + self->rtus_.space_per_thread_ = factor * jcp.is * jcp.ic_block; + scratchpad.book(memory_tracking::names::key_conv_rtus_space, + typesize * max_threads * self->rtus_.space_per_thread_); +} + +template +struct rtus_driver_t: public jit_generator { + + struct call_params_t { + const void *ws; /* reduced image (w/ strides = 1) */ + const void *src; /* source image (w/ non-unit strides) */ + size_t icb; + size_t os; + size_t iw_start; + }; + + void (*ker_)(const call_params_t *p); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional::type; + + Xbyak::Reg64 reg_ws = abi_param1; + Xbyak::Reg64 reg_src = abi_not_param1; + Xbyak::Reg64 reg_icb = rdx; + Xbyak::Reg64 reg_os = r11; + Xbyak::Reg64 reg_iw_start = r8; + + Xbyak::Reg64 reg_cur_os = rax; + Xbyak::Reg64 reg_cur_iw = r9; + Xbyak::Reg64 reg_cur_src = r10; + + int iw_, stride_w_; + int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_; + bool src_to_ws_; + size_t typesize_; + Vmm reg_zero; + Vmm reg_v; + + rtus_driver_t(int iw, int stride_w, int src_step_h, + int src_step_icb, int ws_step_icb, bool src_to_ws, size_t typesize) + : iw_(iw), stride_w_(stride_w), src_step_h_(src_step_h) + , src_step_icb_(src_step_icb), ws_step_icb_(ws_step_icb) + , src_to_ws_(src_to_ws), typesize_(typesize) + { + using namespace Xbyak; + vlen_ = cpu_isa_traits::vlen; + vlen_shift_ = cpu_isa_traits::vlen_shift; + if (typesize_ == 2) { + vlen_ /= 2; + vlen_shift_--; + } + + reg_zero = Vmm(0); + reg_v = Vmm(1); + + generate(); + } + + void loop_is() { + using namespace Xbyak; + + mov(reg_cur_src, reg_src); + mov(reg_cur_iw, reg_iw_start); + mov(reg_cur_os, reg_os); + + Label is_loop, skip_h_step; + L(is_loop); + + if (src_to_ws_) { + vmovups(reg_v, ptr[reg_cur_src]); + vmovups(ptr[reg_ws], reg_v); + } else { + vmovups(reg_v, ptr[reg_ws]); + vmovups(ptr[reg_cur_src], reg_v); + for (int w = 1; w < stride_w_; ++w) + vmovups(ptr[reg_cur_src + w * vlen_], reg_zero); + } + + add(reg_ws, vlen_); + + add(reg_cur_iw, stride_w_); + add(reg_cur_src, stride_w_ * vlen_); + + cmp(reg_cur_iw, iw_); + jl(skip_h_step); + /* for 1d convolution the loop over h should be skipped */ + if (src_step_icb_ == iw_) jmp(skip_h_step); + + if (src_to_ws_) { + add(reg_cur_src, (src_step_h_ - iw_) * vlen_); + } else { + Xbyak::Reg64 reg_cur_src_fin = reg_cur_iw; /* just reuse */ + mov(reg_cur_src_fin, reg_cur_src); + add(reg_cur_src_fin, (src_step_h_ - iw_) * vlen_); + Label ih_loop; + L(ih_loop); + + for (int w = 0; w < stride_w_; ++w) + vmovups(ptr[reg_cur_src + w * vlen_], reg_zero); + + add(reg_cur_src, stride_w_ * vlen_); + cmp(reg_cur_src, reg_cur_src_fin); + jl(ih_loop); + } + xor_(reg_cur_iw, reg_cur_iw); + + L(skip_h_step); + + sub(reg_cur_os, vlen_); + jnz(is_loop); + + /* restore dst */ + sub(reg_ws, reg_os); + } + + void generate() { + using namespace Xbyak; + assert(isa == avx2 || isa == avx512_common + || isa == avx512_core || isa == avx512_mic); + +#if defined(_WIN32) + assert(reg_src == abi_not_param1 && abi_not_param1 == rdi); + push(rdi); +#endif + +#define READ_PARAM(what) \ + mov(reg_ ## what, ptr[abi_param1 + offsetof(call_params_t, what)]) + READ_PARAM(src); + READ_PARAM(icb); + READ_PARAM(os); + READ_PARAM(iw_start); + + assert(reg_ws == abi_param1); + READ_PARAM(ws); /* reg_ws should always be read the last */ +#undef READ_PARAM + + shl(reg_os, vlen_shift_); + + if (!src_to_ws_) + uni_vpxor(reg_zero, reg_zero, reg_zero); + + Label icb_loop; + L(icb_loop); + + loop_is(); + + add(reg_ws, ws_step_icb_ * vlen_); + add(reg_src, src_step_icb_ * vlen_); + + dec(reg_icb); + jnz(icb_loop, T_NEAR); + +#if defined(_WIN32) + pop(rdi); +#endif + + uni_vzeroupper(); + ret(); + this->ker_ = reinterpret_cast(const_cast( + this->getCode())); + } +}; + +template +inline void init_rtus_driver(conv_t *self) { + const auto &conf = *self->pd(); + if (!conf.rtus_.reduce_src_) return; + + const auto &cd = *conf.desc(); + const int ndims = conf.ndims(); + const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0]; + const int stride_w = cd.strides[ndims - 3]; + + const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data; + const auto &src_d = is_bwd_data ? *conf.diff_src_md() : *conf.src_md(); + + const int ih = ndims == 3 ? 1 : src_d.dims[2]; + const int iw = src_d.dims[ndims - 1]; + + const int src_step_h = stride_h * iw; + const int src_step_icb = ih * iw; + const int ws_step_icb = conf.jcp_.is; + const bool src_to_ws = !is_bwd_data; + const size_t typesize = types::data_type_size( + conv_prop_invariant_src_d(self->pd()->desc())->data_type); + + self->rtus_driver_ = new rtus_driver_t(iw, stride_w, src_step_h, + src_step_icb, ws_step_icb, src_to_ws, typesize); +} + +inline int best_divider(int value, int min_divider, int max_divider, + bool find_max, int step = 1) +{ + max_divider = nstl::max(1, nstl::min(max_divider, value)); + min_divider = nstl::max(1, nstl::min(min_divider, max_divider)); + + auto loss_ratio = [](int total, int chunk) + { return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk); }; + + float min_loss = FLT_MAX; + int x_divider = max_divider; + for (int divider = max_divider; divider >= min_divider; divider -= step) { + const float loss = loss_ratio(value, divider); + if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) { + min_loss = loss; + x_divider = divider; + } + } + return x_divider; +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp new file mode 100644 index 0000000000..72fe3a8109 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp @@ -0,0 +1,1407 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "jit_uni_batch_normalization.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace memory_tracking::names; + +using namespace Xbyak; +namespace barrier = simple_barrier; + +typedef float data_t; + +template +struct jit_bnorm_t: public jit_generator { + struct call_params_t { + // keep all sizes at 8 bytes -- jit code expects this + size_t N_ithr, N_nthr; + size_t coff_max, soff_max; + size_t mb_stride_Bc, spat_size, spat_size_loc; + size_t S_s, S_tail; + size_t is_cblk_tail; + data_t chan_size, eps, one; + const data_t *scale_shift; + const data_t *mean, *var; + const data_t *diff_scale_shift; + const data_t *src, *dst; + const data_t *diff_src, *diff_dst; + const data_t *rbuf1, *rbuf2; + const uint8_t *ws; + barrier::ctx_t *barrier; + }; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional3::type; + const AddressFrame &vmmword = (isa == sse42) ? xword : + (isa == avx2) ? yword : zword; + + const int vlen = isa == sse42 ? 32 : cpu_isa_traits::vlen; + + const batch_normalization_pd_t *bdesc_; + bool is_spatial_thr_; + + void (*ker)(const call_params_t *); + void operator()(const call_params_t *p) { (*ker)(p); } + + Reg64 reg_param = abi_param1; + + Reg64 reg_scale_shift = rbx; + Reg64 reg_rbuf1 = abi_not_param1; + Reg64 reg_rbuf2 = rdx; + + Reg64 reg_mean = rbp; + Reg64 reg_var = reg_param; + Reg64 reg_diff_scale_shift = rax; + + Reg64 reg_coff = r8; + Reg64 reg_coff_max = r9; + Reg64 reg_soff = r10; + Reg64 reg_soff_max = r11; + Reg64 reg_ctr = r12; + Reg64 reg_roff = r13; + + Reg64 reg_mb_stride_Bc = r14; + + Reg64 reg_src = r15; + Reg64 reg_diff_src = reg_rbuf1; + Reg64 reg_dst = rsi; + Reg64 reg_diff_dst = reg_dst; + + Reg64 reg_tmp_off = reg_roff; + + // Reuse loop counters + Reg64 reg_bar = reg_coff; + Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff + Reg64 reg_tmp = reg_ctr; + + // Relu section + bool with_relu, with_relu_inf_only; + Vmm vzero; // is_fwd() ? vdiff_beta : vbeta + Reg64 reg_ws = reg_roff; + Label l_relu_mask_avx2; + Opmask kstore_mask = Opmask(1); + + // channel tail processing + Opmask ktail_mask = Opmask(2); + + size_t unroll_blocks; + size_t unroll_regs; + Vmm vbuf = Vmm(isa == avx512_common ? 20 : 5); + Vmm vdiff_beta = Vmm(isa == avx512_common ? 21 : 6); + Vmm vdiff_gamma = Vmm(isa == avx512_common ? 22 : 7); + Vmm vsqrtvar = Vmm(isa == avx512_common ? 23 : 8); + Vmm vone = Vmm(isa == avx512_common ? 24 : 9); + Vmm vmean = Vmm(isa == avx512_common ? 25 : 10); + Vmm vgamma = Vmm(isa == avx512_common ? 26 : 11); + Vmm vbeta = Vmm(isa == avx512_common ? 27 : 12); + Vmm veps = Vmm(isa == avx512_common ? 28 : 13); + Vmm vchan_size = Vmm(isa == avx512_common ? 29 : 14); + Vmm vtail_mask = Vmm(isa == avx512_common ? 30 : 15); + + size_t t0_pf_offt; + size_t t1_pf_offt; + size_t spat_size; + size_t chan_data_offt; + + enum { + stack_off_N_nthr = 0, + stack_off_N_ithr = 8, + stack_off_src = 16, + stack_off_dst = 24, + stack_off_diff_src = 32, + stack_off_diff_dst = 40, + stack_off_diff_scale_shift = 48, + stack_off_ws = 56, + stack_off_barrier = 64, + stack_off_spat_size_loc = 72, + stack_off_s_s = 80, + stack_off_s_tail = 88, + stack_off_is_cblk_tail = 96, + stack_size_required = 104, + }; + + bool is_c_padded() const { + const memory_desc_wrapper data_d(bdesc_->src_md()); + return bdesc_->C() != data_d.padded_dims()[1]; + } + + void compute_static_strides() { + spat_size = bdesc_->D() * bdesc_->W() * bdesc_->H(); + chan_data_offt = bdesc_->C() * sizeof(data_t); + + if (isa == avx512_mic) { + t0_pf_offt = 4096; + t1_pf_offt = 0; + } else { + t0_pf_offt = 0; + t1_pf_offt = 0; + } + } + + void load_common_params() { +# define PARAM_OFF(x) offsetof(call_params_t, x) + mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]); + if (bdesc_->is_bwd()) + mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]); + mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]); + mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]); + mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]); + shl(reg_coff_max, 2); + shl(reg_soff_max, 2); + shl(reg_mb_stride_Bc, 2); + + mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]); + mov(reg_scale_shift, ptr[reg_param + PARAM_OFF(scale_shift)]); + + uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]); + uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]); + uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]); + + mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]); + mov(ptr[rsp + stack_off_N_nthr], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]); + mov(ptr[rsp + stack_off_N_ithr], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]); + mov(ptr[rsp + stack_off_src], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]); + mov(ptr[rsp + stack_off_dst], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]); + mov(ptr[rsp + stack_off_diff_src], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]); + mov(ptr[rsp + stack_off_diff_dst], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]); + mov(ptr[rsp + stack_off_ws], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]); + mov(ptr[rsp + stack_off_barrier], reg_tmp); + if (is_spatial_thr_) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]); + mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]); + mov(ptr[rsp + stack_off_s_s], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]); + mov(ptr[rsp + stack_off_s_tail], reg_tmp); + } + if (is_c_padded()) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]); + mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp); + } + + if (bdesc_->is_fwd()) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); + mov(reg_var, reg_tmp); + } else { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale_shift)]); + mov(ptr[rsp + stack_off_diff_scale_shift], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); + mov(reg_var, reg_tmp); + } +# undef PARAM_OFF + } + + void prepare_tail_mask_avx512_common() { + if (!is_c_padded()) return; + + const int tail = bdesc_->C() % (int)(vlen / sizeof(float)); + const int mask = (1 << tail) - 1; + + Reg32 regw_tmp = reg_tmp.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + void prepare_tail_mask_avx2_common() { + if (!is_c_padded()) return; + + const int tail = bdesc_->C() % (int)(vlen / sizeof(float)); + static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0, 0, 0, 0, 0, 0, 0, 0}; + + mov(reg_tmp, reinterpret_cast(&mask[8 - tail])); + vmovups(vtail_mask, ptr[reg_tmp]); + } + + void prepare_relu() { + with_relu = bdesc_->is_fwd() + ? bdesc_->with_relu_post_op() || bdesc_->fuse_bn_relu() + : bdesc_->fuse_bn_relu(); + with_relu_inf_only = with_relu && bdesc_->is_fwd() + && !(bdesc_->fuse_bn_relu() && bdesc_->is_training()); + + vzero = bdesc_->is_fwd() ? vdiff_beta : vbeta; + if (with_relu) { + uni_vpxor(vzero, vzero, vzero); + if (!bdesc_->is_fwd() && isa == avx2) + prepare_l_relu_mask_avx2(); + } + } + + void prepare_l_relu_mask_avx2() { + Label l_mask_after; + jmp(l_mask_after); + align(32); + L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */ + for (int i = 0; i < 8; ++i) dd(1< + void spat_loop(size_t len, size_t blocks, size_t regs, + init_t init, body_t body, fini_t fini) { + size_t factor = regs * blocks; + size_t loop_unroll = len / factor * factor; + size_t loop_tail = len - loop_unroll; + size_t num_active_regs = (len < regs) ? len : regs; + for (size_t i = 0; i < num_active_regs; i++) + init(i); + if (loop_unroll) { + if (is_spatial_thr_) { + mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]); + add(reg_soff, ptr[rsp + stack_off_s_s]); + } else { + mov(reg_ctr, loop_unroll); + } + Label label; + L(label); { + for (size_t i = 0; i < factor; i++) { + size_t base_reg = i % regs; + body(base_reg, i); + } + add(reg_soff, factor * vlen); + sub(reg_ctr, factor); + jnz(label); + } + if (is_spatial_thr_) { + add(reg_soff, ptr[rsp + stack_off_s_tail]); + } + } + + for (size_t i = 0; i < loop_tail; i++) { + size_t base_reg = i % regs; + body(base_reg, i); + } + if (loop_tail) + add(reg_soff, loop_tail * vlen); + + for (size_t i = 0; i < num_active_regs; i++) + fini(i); + } + + void mean_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + spat_loop(spat_size, unroll_blocks, + unroll_regs, + [=](size_t base_reg) { + Vmm v = Vmm(base_reg * 2); + if (base_reg) + uni_vpxor(v, v, v); + }, + [=](size_t base_reg, size_t i) { + Vmm v0 = Vmm(base_reg * 2 + 0); + Vmm v1 = Vmm(base_reg * 2 + 1); + size_t offt = i * vlen; + uni_vmovups(v1, + vmmword[reg_src + reg_soff + offt]); + uni_vaddps(v0, v0, v1); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b = Vmm(0); + Vmm v = Vmm(base_reg * 2); + if (base_reg) + uni_vaddps(b, b, v); + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void var_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + spat_loop(spat_size, unroll_blocks, unroll_regs, + [=](size_t base_reg) { + Vmm v = Vmm(base_reg * 3); + if (base_reg > 0) + uni_vpxor(v, v, v); + }, + [=](size_t base_reg, size_t i) { + Vmm v = Vmm(3 * base_reg); + Vmm vtmp0 = Vmm(3 * base_reg + 1); + Vmm vtmp1 = Vmm(3 * base_reg + 2); + size_t offt = i * vlen; + uni_vmovups(vtmp0, + vmmword[reg_src + reg_soff + offt]); + if (isa == sse42) { + movups(vtmp1, vmean); + subps(vtmp1, vtmp0); + } else { + vsubps(vtmp1, vmean, vtmp0); + } + uni_vfmadd231ps(v, vtmp1, vtmp1); + + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b = Vmm(0); + Vmm v = Vmm(base_reg * 3); + if (base_reg) + uni_vaddps(b, b, v); + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void compute_mean_variance() { + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + xor_(reg_coff, reg_coff); + Label zero_rbuf; + L(zero_rbuf); { + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(zero_rbuf); + } + + mov(reg_src, ptr[rsp + stack_off_src]); + + xor_(reg_soff, reg_soff); + Label mean_spatial; + L(mean_spatial); { + xor_(reg_coff, reg_coff); + + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + mean_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + + mean_channels(); + + sub(reg_src, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(mean_spatial); + } + + Label no_mean_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + jne(no_mean_reduction); + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + Label mean_reduction_channels; + L(mean_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + mov(reg_ctr, reg_nnthr); + Label mean_reduction_thrs; + L(mean_reduction_thrs); { + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); + uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0)); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(mean_reduction_thrs); + } + uni_vdivps(Vmm(1), Vmm(1), vchan_size); + uni_vmovups_maybe_tail(mean_ptr(), Vmm(1)); + + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + + cmp(reg_coff, reg_coff_max); + jne(mean_reduction_channels); + } + } + L(no_mean_reduction); + barrier(); + + xor_(reg_soff, reg_soff); + Label var_spatial; + L(var_spatial); { + xor_(reg_coff, reg_coff); + + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + var_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + + var_channels(); + + sub(reg_src, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(var_spatial); + } + + Label no_var_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + jne(no_var_reduction); + + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + Label var_reduction_channels; + L(var_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + mov(reg_ctr, reg_nnthr); + Label var_reduction_thrs; + L(var_reduction_thrs); { // TODO: unroll (?) + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(var_reduction_thrs); + } + uni_vdivps(Vmm(1), Vmm(1), vchan_size); + uni_vmovups_maybe_tail(var_ptr(), Vmm(1)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + + cmp(reg_coff, reg_coff_max); + jne(var_reduction_channels); + } + } + L(no_var_reduction); + barrier(); + } + + void forward_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + + if (bdesc_->use_scaleshift()) { + uni_vmovups_maybe_tail(vgamma, gamma_ptr()); + uni_vmovups_maybe_tail(vbeta, beta_ptr()); + } + + Vmm vscale = bdesc_->use_scaleshift() ? vgamma : vone; + Vmm vdiv = bdesc_->use_scaleshift() ? vgamma : vsqrtvar; + + if (isa == sse42) { + movups(vbuf, vscale); + divps(vbuf, vsqrtvar); + movups(vdiv, vbuf); + } else { + vdivps(vdiv, vscale, vsqrtvar); + } + + auto compute = [=](bool output_is_aligned) { + spat_loop(spat_size, unroll_blocks, unroll_regs, + [](size_t base_reg) {UNUSED(base_reg);}, + [=](size_t base_reg, size_t i) { + Vmm v = Vmm(base_reg); + size_t offt = i * vlen; + uni_vmovups(v, + vmmword[reg_src + reg_soff + offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + uni_vsubps(v, v, vmean); + if (bdesc_->use_scaleshift()) { + uni_vfmadd213ps(v, vgamma, vbeta); + } else { + uni_vmulps(v, v, vsqrtvar); + } + if (with_relu_inf_only) { + uni_vmaxps(v, v, vzero); + } else if (with_relu) { + if (isa == avx512_common) + fwd_process_relu_avx512_common(v, offt); + else + fwd_process_relu_avx2(v, offt, Vmm(3)); + } + if (output_is_aligned) { + uni_vmovntps( + vmmword[reg_dst + reg_soff + offt], v); + } else { + uni_vmovups( + vmmword[reg_dst + reg_soff + offt], v); + } + }, + [](size_t base_reg) {UNUSED(base_reg);}); + }; + + Label unaligned_store, end_store; + test(reg_dst, vlen - 1); + jnz(unaligned_store, T_NEAR); + compute(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + compute(false); + } + L(end_store); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void forward() { + mov(reg_src, ptr[rsp + stack_off_src]); + mov(reg_dst, ptr[rsp + stack_off_dst]); + mov(reg_ws, ptr[rsp + stack_off_ws]); + + xor_(reg_soff, reg_soff); + Label dst_spatial; + L(dst_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + forward_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + add(reg_dst, vlen / 2); + mov(reg_coff, vlen / 2); + + forward_channels(); + + sub(reg_src, vlen / 2); + sub(reg_dst, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jnz(dst_spatial); + } + } + + void backward_sh_channels() { + Label sh_channels; + L(sh_channels); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]); + spat_loop(spat_size, 1, 1, + [=](size_t base_reg) { + if (base_reg > 0) { + for (int i = 0; i < 2; i++) { + Vmm v(base_reg * 5 + i); + uni_vpxor(v, v, v); + } + } + }, + [=](size_t base_reg, size_t i) { + Vmm o0 = Vmm(base_reg * 5 + 0); + Vmm o1 = Vmm(base_reg * 5 + 1); + Vmm t1 = Vmm(base_reg * 5 + 2); + Vmm t2 = Vmm(base_reg * 5 + 3); + Vmm t3 = Vmm(base_reg * 5 + 4); + size_t offt = i * vlen; + uni_vmovups(t1, vmmword[reg_src + reg_soff + offt]); + uni_vmovups(t2, vmmword[reg_diff_dst + reg_soff + + offt]); + if (with_relu) { + if (isa == avx512_common) + bwd_process_relu_avx512_common(t2, offt); + else if (isa == avx2) + bwd_process_relu_avx2(t2, offt, t3); + else + assert(false); + } + uni_vsubps(t3, vmean, t1, t3); + if (isa == sse42) { + mulps(t3, t2); + subps(o0, t3); + } else { + vfnmadd231ps(o0, t3, t2); + } + uni_vaddps(o1, o1, t2); + mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_diff_dst + reg_soff + offt + + t1_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b0 = Vmm(0); + Vmm b1 = Vmm(1); + if (base_reg) { + uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0)); + uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1)); + } + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1)); + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(sh_channels); + } + } + + void backward_diff_channels() { + Label diff_channels; + L(diff_channels); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf); + if (bdesc_->use_scaleshift()) + uni_vmovups_maybe_tail(vgamma, gamma_ptr()); + uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr()); + uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr()); + uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar); + uni_vdivps(vdiff_beta, vdiff_beta, vchan_size); + uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size); + + auto compute = [=](bool output_is_aligned) { + spat_loop(spat_size, unroll_blocks, unroll_regs, + [=](size_t base_reg) {UNUSED(base_reg);}, + [=](size_t base_reg, size_t i) { + Vmm v(base_reg * 2 + 0); + Vmm t(base_reg * 2 + 1); + Vmm t1(base_reg * 2 + 2); + size_t offt = i * vlen; + uni_vmovups(v, vmmword[reg_diff_dst + reg_soff + + offt]); + if (with_relu) { + if (isa == avx512_common) + bwd_process_relu_avx512_common(v, offt); + else if (isa == avx2) + bwd_process_relu_avx2(v, offt, t); + else + assert(false); + } + if (!bdesc_->use_global_stats()) { + uni_vsubps(v, v, vdiff_beta); + uni_vmovups(t, vmmword[reg_src + reg_soff + + offt]); + uni_vsubps(t, vmean, t, t1); + uni_vmulps(t, t, vdiff_gamma); + uni_vaddps(v, v, t); + } + uni_vmulps(v, v, vsqrtvar); + if (bdesc_->use_scaleshift()) { + uni_vmulps(v, v, vgamma); + } + if (output_is_aligned) { + uni_vmovntps( + vmmword[reg_diff_src + reg_soff + offt], + v); + } else { + uni_vmovups( + vmmword[reg_diff_src + reg_soff + offt], + v); + } + mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_diff_dst + reg_soff + + offt + t1_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) {UNUSED(base_reg);}); + }; + + Label unaligned_store, end_store; + test(reg_diff_src, vlen - 1); + jnz(unaligned_store, T_NEAR); + compute(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + compute(false); + } + L(end_store); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(diff_channels); + } + } + + void backward() { + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + xor_(reg_coff, reg_coff); + Label zero_rbuf, sh_spatial; + + L(zero_rbuf); { + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(zero_rbuf); + } + + mov(reg_src, ptr[rsp + stack_off_src]); + mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]); + if (with_relu) { + assert(isa == avx2 || isa == avx512_common); + mov(reg_ws, ptr[rsp + stack_off_ws]); + } + + xor_(reg_soff, reg_soff); + L(sh_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) { + mov(reg_tmp_off, reg_soff); + } + backward_sh_channels(); + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_diff_dst, vlen / 2); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + backward_sh_channels(); + sub(reg_diff_dst, vlen / 2); + sub(reg_src, vlen / 2); + } + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(sh_spatial); + } + + mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]); + + Label no_sh_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + Label sh_reduction_channels; + jne(no_sh_reduction, T_NEAR); + + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + L(sh_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf); + mov(reg_ctr, reg_nnthr); + Label sh_reduction_thrs; + L(sh_reduction_thrs); { // TODO: unroll (?) + uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]); + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(sh_reduction_thrs); + } + uni_vmulps(Vmm(0), Vmm(0), vsqrtvar); + uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0)); + uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(sh_reduction_channels); + } + } + L(no_sh_reduction); + barrier(); + + mov(reg_diff_src, ptr[rsp + stack_off_diff_src]); + if (with_relu) { + assert(isa == avx2 || isa == avx512_common); + mov(reg_ws, ptr[rsp + stack_off_ws]); + } + + xor_(reg_soff, reg_soff); + Label diff_spatial; + L(diff_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) { + mov(reg_tmp_off, reg_soff); + } + backward_diff_channels(); + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_diff_dst, vlen / 2); + add(reg_diff_src, vlen / 2); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + backward_diff_channels(); + sub(reg_diff_dst, vlen / 2); + sub(reg_diff_src, vlen / 2); + sub(reg_src, vlen / 2); + } + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(diff_spatial); + } + } + + jit_bnorm_t(const batch_normalization_pd_t *bdesc): bdesc_(bdesc) { + static_assert(isa == sse42 || isa == avx2 || isa == avx512_common + || isa == avx512_mic, "unsupported isa"); + + const int simd_w = isa == sse42 ? 8 : + cpu_isa_traits::vlen / sizeof(data_t); + is_spatial_thr_ = + bnorm_utils::is_spatial_thr(bdesc_, simd_w, sizeof(data_t)); + + unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1; + unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1; + + preamble(); + + if (isa == avx512_common) + prepare_tail_mask_avx512_common(); + else if (isa == avx2) + prepare_tail_mask_avx2_common(); + + compute_static_strides(); + sub(rsp, stack_size_required); + load_common_params(); + prepare_relu(); + + if (bdesc_->is_fwd()) { + if (!bdesc_->stats_is_src()) { + compute_mean_variance(); + } + forward(); + } else { + backward(); + } + add(rsp, stack_size_required); + postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); + } +}; + +template +struct uni_bnorm_driver_t: public c_compatible { + uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc) + : bdesc_(bdesc), ker_(bdesc_) + { + const int nthrs = mkldnn_get_max_threads(); + const dim_t C_PADDED = get_c_padded(bdesc_); + + size_t data_size = sizeof(data_t) * bdesc_->MB() * C_PADDED + * bdesc_->D() * bdesc_->H() * bdesc_->W(); + l3_size_ = get_cache_size(3, true) * nthrs / 2; + do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0); + } + + ~uni_bnorm_driver_t() {} + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const batch_normalization_pd_t *bdesc) { + int nthrs = mkldnn_get_max_threads(); + dim_t C_PADDED = get_c_padded(bdesc); + + int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED; + int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED; + int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs; + + scratchpad.book(key_bnorm_tmp_stats, sizeof(data_t) * sbuf_sz); + scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * pbuf_sz); + scratchpad.book(key_bnorm_reduction, sizeof(data_t) * rbuf_sz); + + if (mkldnn_thr_syncable()) { + int n_barriers = C_PADDED / simd_w; + scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers); + } + } + + void exec(int ithr, int nthr, const data_t *src, data_t *diff_src, + data_t *dst, const data_t *diff_dst, const data_t *scale_shift, + data_t *diff_scale_shift, const data_t *mean, const data_t *var, + const uint8_t *ws, const memory_tracking::grantor_t &scratchpad) { + auto sbuf = scratchpad.get(key_bnorm_tmp_stats); + auto pbuf = scratchpad.get(key_bnorm_tmp_diff_ss); + auto rbuf = scratchpad.get(key_bnorm_reduction); + auto barriers = scratchpad.get(key_barrier); + + dim_t N = bdesc_->MB(); + dim_t C = bdesc_->C(); + dim_t C_PADDED = get_c_padded(bdesc_); + dim_t D = bdesc_->D(); + dim_t H = bdesc_->H(); + dim_t W = bdesc_->W(); + dim_t SP = D * H * W; + dim_t img_size = C_PADDED * D * H * W; + const int vlen = isa == sse42 ? 32 : cpu_isa_traits::vlen; + + typename jit_bnorm_t::call_params_t p; + + p.eps = bdesc_->desc()->batch_norm_epsilon; + p.one = 1.0f; + p.spat_size = D * H * W; + p.chan_size = 1.0f * N * p.spat_size; + + dim_t C_blks = C_PADDED / simd_w; + + int C_ithr{0}, C_nthr{0}, N_ithr{0}, N_nthr{0}, S_ithr{0}, S_nthr{0}; + dim_t C_blk_s{0}, C_blk_e{0}, N_s{0}, N_e{0}, S_s{0}, S_e{0}; + + dim_t C_blks_per_iter{ 1 }; + int64_t iters{ 1 }; + if (do_blocking_) { + int num_tensors = bdesc_->is_fwd() ? 1 : 2; + size_t working_set_size + = (N * D * H * W * simd_w * sizeof(data_t)) * num_tensors; + bnorm_utils::cache_balance(working_set_size, C_blks, + C_blks_per_iter, iters); + } + + bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_, + true, ithr, nthr, N, do_blocking_ ? C_blks_per_iter : C_blks, + SP, C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e, + S_ithr, S_nthr, S_s, S_e); + + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + assert(IMPLICATION(!mkldnn_thr_syncable(), SP_N_nthr == 1)); + + p.N_ithr = SP_N_ithr; + p.N_nthr = SP_N_nthr; + + int last_iter_blks = C_blks - (iters - 1) * C_blks_per_iter; + int global_C_blk_s; + int global_barriers_per_iter = C_nthr; + + for (int64_t it = 0; it < iters; it++) { + if (it == iters - 1 && iters > 1) { + C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + + // Update call parameters for JIT, last iteration + p.N_ithr = N_ithr * S_nthr + S_ithr; + p.N_nthr = N_nthr * S_nthr; + } + + global_C_blk_s = do_blocking_ ? + (C_blk_s == -1) ? -1 : it * C_blks_per_iter + C_blk_s : + C_blk_s; + + int C_blks_thr = C_blk_e - C_blk_s; + int N_thr = N_e - N_s; + + size_t coff_base = global_C_blk_s * simd_w; + size_t soff_base + = global_C_blk_s * p.spat_size * simd_w + N_s * img_size; + + p.spat_size_loc = S_e - S_s; + p.S_s = S_s * vlen; + p.S_tail = (p.spat_size - S_e) * vlen; + p.coff_max = C_blks_thr * simd_w; + p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base; + p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base; + p.scale_shift = scale_shift + coff_base; + p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_) + ? pbuf : diff_scale_shift) + coff_base; + + p.soff_max = N_thr * img_size; + p.src = src + soff_base; + p.dst = dst + soff_base; + p.diff_src = diff_src + soff_base; + p.diff_dst = diff_dst + soff_base; + p.ws = ws + soff_base / 8; + + p.mb_stride_Bc = img_size - p.coff_max * p.spat_size; + + // use SP_N_nthr which is the same as p.N_nthr except maybe for + // the last iteration. + p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr + + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w; + // rbuf1 and rbuf2 have to be disjoint + p.rbuf2 = p.rbuf1 + C_PADDED * nthr; + p.is_cblk_tail = (it * C_blks_per_iter + C_blk_e) * simd_w > C; + + size_t iter_bariers + = do_blocking_ ? it * global_barriers_per_iter : 0; + p.barrier = barriers + C_ithr + iter_bariers; + if (p.soff_max != 0 && p.coff_max != 0) + ker_(&p); + } + } + + void init_barriers(const memory_tracking::grantor_t &scratchpad) { + auto barriers = scratchpad.get(key_barrier); + if (barriers) { + const int n_barriers = get_c_padded(bdesc_) / simd_w; + for (int i = 0; i < n_barriers; ++i) + barrier::ctx_init(&barriers[i]); + } + } + +private: + enum { + simd_w = isa == sse42 ? 8 : cpu_isa_traits::vlen / sizeof(data_t) + }; + + static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) { + return true + && !bdesc->stats_is_src() + && bdesc->desc()->prop_kind == prop_kind::forward_inference; + } + + static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc) + { + return false + || (bdesc->is_bwd() && !bdesc->use_scaleshift()) + || bdesc->desc()->prop_kind == prop_kind::backward_data; + } + + static dim_t get_c_padded(const batch_normalization_pd_t *bdesc) + { return bdesc->src_md()->padded_dims[1]; } + + const batch_normalization_pd_t *bdesc_; + bool do_blocking_; + size_t l3_size_; + + jit_bnorm_t ker_; +}; + +} + +using namespace data_type; +using namespace format_tag; +using namespace utils; + +/* fwd */ + +template +status_t jit_uni_batch_normalization_fwd_t::pd_t::init() { + auto desired_fmt_tag = (ndims() == 4) + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? nCdhw16c : nCdhw8c; + + bool ok = true + && mayiuse(isa) + && is_fwd() + && !has_zero_dim_memory() + && one_of(ndims(), 4, 5) + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) { + if (isa < avx2) return status::unimplemented; + init_default_ws(1); + } + + if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() + && isa < avx2) + return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + uni_bnorm_driver_t::init_scratchpad(scratchpad, this); + + return status::success; +} + +template +jit_uni_batch_normalization_fwd_t::jit_uni_batch_normalization_fwd_t( + const pd_t *apd): cpu_primitive_t(apd) +{ bnorm_driver_ = new uni_bnorm_driver_t(pd()); } + +template +status_t jit_uni_batch_normalization_fwd_t::execute( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + auto mean = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)) + : CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + auto var = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)) + : CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto scratchpad = this->scratchpad(ctx); + + bnorm_driver_->init_barriers(scratchpad); + + parallel(0, [&](const int ithr, const int nthr) { + bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr, + scale_shift, nullptr, mean, var, ws, scratchpad); + }); + + return status::success; +} + +template +jit_uni_batch_normalization_fwd_t::~jit_uni_batch_normalization_fwd_t() +{ delete bnorm_driver_; } + +/* bwd */ + +template +status_t jit_uni_batch_normalization_bwd_t::pd_t::init() { + auto desired_fmt_tag = (ndims() == 4) + ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c + : one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c; + + bool ok = true + && mayiuse(isa) + && is_bwd() + && !has_zero_dim_memory() + && one_of(ndims(), 4, 5) + && everyone_is(f32, src_md()->data_type, diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() + && isa < avx2) + return status::unimplemented; + + if (fuse_bn_relu()) { + if (isa < avx2) return status::unimplemented; + init_default_ws(1); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + /* TODO: extra checks required */ + + auto scratchpad = scratchpad_registry().registrar(); + uni_bnorm_driver_t::init_scratchpad(scratchpad, this); + + return status::success; +} + +template +jit_uni_batch_normalization_bwd_t::jit_uni_batch_normalization_bwd_t( + const pd_t *apd): cpu_primitive_t(apd) +{ bnorm_driver_ = new uni_bnorm_driver_t(pd()); } + +template +status_t jit_uni_batch_normalization_bwd_t::execute( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto var = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scale_shift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + + bnorm_driver_->init_barriers(scratchpad); + + parallel(0, [&](const int ithr, const int nthr) { + bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst, + scale_shift, diff_scale_shift, mean, var, ws, scratchpad); + }); + + return status::success; +} + +template +jit_uni_batch_normalization_bwd_t::~jit_uni_batch_normalization_bwd_t() +{ delete bnorm_driver_; } + +/* struct instantiation */ +template struct jit_uni_batch_normalization_fwd_t; +template struct jit_uni_batch_normalization_bwd_t; +template struct jit_uni_batch_normalization_fwd_t; +template struct jit_uni_batch_normalization_bwd_t; +template struct jit_uni_batch_normalization_fwd_t; +template struct jit_uni_batch_normalization_bwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp new file mode 100644 index 0000000000..96410ec84e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp @@ -0,0 +1,100 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_BATCH_NORMALIZATION_HPP +#define JIT_UNI_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_isa_traits.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { template struct uni_bnorm_driver_t; } + +template +struct jit_uni_batch_normalization_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_batch_normalization_fwd_t); + + status_t init(); + }; + + typedef typename prec_traits::type data_t; + + jit_uni_batch_normalization_fwd_t(const pd_t *apd); + ~jit_uni_batch_normalization_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + uni_bnorm_driver_t *bnorm_driver_; +}; + +template +struct jit_uni_batch_normalization_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_batch_normalization_bwd_t); + + status_t init(); + }; + + typedef typename prec_traits::type data_t; + + jit_uni_batch_normalization_bwd_t(const pd_t *apd); + ~jit_uni_batch_normalization_bwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + uni_bnorm_driver_t *bnorm_driver_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp new file mode 100644 index 0000000000..b7dc5f85c5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp @@ -0,0 +1,1302 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_uni_dw_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +template +void jit_uni_dw_conv_fwd_kernel_f32::load_src(int ur_ch_blocks, int ur_w) { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int ow = 0; ow < ur_w; ow++) { + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); + + int b_off = ch*jcp.ch_block + i*4; + if (this->jcp.with_bias) + uni_vmovups(vmm_acc, + vmmword[reg_bias + b_off*sizeof(float)]); + else + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + + int o_off = ch*jcp.oh*jcp.ow*jcp.ch_block + + ow*jcp.ch_block + i*4; + if (this->jcp.with_sum) + uni_vaddps(vmm_acc, vmm_acc, + vmmword[reg_output + o_off*sizeof(float)]); + } + } + } +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::apply_filter( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + mov(iter_kw, reg_kw); + mov(aux1_reg_input, aux_reg_input); + mov(aux1_reg_kernel, aux_reg_kernel); + + Label kw_label; + L(kw_label); { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*4; + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux1_reg_kernel + + ker_off*sizeof(float)]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ch*jcp.ih*jcp.iw*ch_blk + + ow*stride_w*ch_blk + i*4; + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux1_reg_input + + inp_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + + ch*ur_w + ow); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + add(aux1_reg_kernel, ch_blk*sizeof(float)); + add(aux1_reg_input, ch_blk*dilate_w*sizeof(float)); + + dec(iter_kw); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float)); + add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float)); + + dec(iter_kh); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::apply_filter_unrolled( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int kw = 0; kw < jcp.kw; kw++) { + int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*4; + + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux_reg_kernel + + ker_off*sizeof(float)]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ch*jcp.ih*jcp.iw*ch_blk + + ow*stride_w*ch_blk + kw*ch_blk*dilate_w + i*4; + + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux_reg_input + + inp_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + + ch*ur_w + ow); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + } + + add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float)); + add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float)); + + dec(iter_kh); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::apply_activation( + int ur_ch_blocks, int ur_w) { + if (this->jcp.with_eltwise) { + int repeats = isa == sse42 ? 2 : 1; + eltwise_injector_->compute_vector_range(4, repeats * ur_w * ur_ch_blocks + 4); + } +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::store_dst( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int ow = 0; ow < ur_w; ow++) { + int o_off = ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4; + Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); + + uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst); + } + } + } +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::loop_body(int ur_ch_blocks) { + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + cmp(reg_ur_w, ur_w); + jl(tail_w_label, T_NEAR); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + load_src(ur_ch_blocks, ur_w); + apply_filter_unrolled(ur_ch_blocks, ur_w); + apply_activation(ur_ch_blocks, ur_w); + store_dst(ur_ch_blocks, ur_w); + + add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_output, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + cmp(reg_ur_w, ur_w); + jl(exit_label, T_NEAR); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + load_src(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + apply_activation(ur_ch_blocks, ur_w); + store_dst(ur_ch_blocks, ur_w); + + add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_output, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::generate() { + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); + mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); + mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]); + + Label ch_blocks_tail_label; + Label exit_label; + + int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; + + cmp(reg_ch_blocks, jcp.nb_ch_blocking); + jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR); + + loop_body(jcp.nb_ch_blocking); // channel main loop + + if (ch_blocks_tail) { + L(ch_blocks_tail_label); + + cmp(reg_ch_blocks, ch_blocks_tail); + jne(exit_label, T_NEAR); + + loop_body(ch_blocks_tail); // channel tail loop + } + + L(exit_label); + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +template +bool jit_uni_dw_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +template +status_t jit_uni_dw_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(isa)) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 16 : 8; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + if (!with_groups) return status::unimplemented; + + jcp.ngroups = weights_d.dims()[0]; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1]; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1]; + + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + + jcp.kh = weights_d.dims()[3]; + jcp.kw = weights_d.dims()[4]; + + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.b_pad = cd.padding[1][0]; + jcp.r_pad = cd.padding[1][1]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + bool ok_to_pad_channels = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && one_of(isa, avx512_common, avx2); + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.oc, simd_w); + jcp.ngroups = rnd_up(jcp.ngroups, simd_w); + } + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && jcp.ngroups % simd_w == 0 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ngroups <= weights_d.padded_dims()[0]; + if (!args_ok) return status::unimplemented; + + jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; + + jcp.ch_block = simd_w; + jcp.nb_ch = jcp.oc / jcp.ch_block; + jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2; + if (jcp.nb_ch < jcp.nb_ch_blocking) + jcp.nb_ch_blocking = jcp.nb_ch; + + return status::success; +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc_without_padding != jcp.oc) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +template struct jit_uni_dw_conv_fwd_kernel_f32; +template struct jit_uni_dw_conv_fwd_kernel_f32; +template struct jit_uni_dw_conv_fwd_kernel_f32; + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::load_ddst( + int ur_ch_blocks, int ur_str_w) { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + } + } + } +} + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::apply_filter( + int ur_ch_blocks, int ur_str_w) { + int kw = jcp.kw; + int kh = jcp.kh; + int ow = jcp.ow; + int oh = jcp.oh; + + int ch_blk = jcp.ch_block; + int stride_h = jcp.stride_h; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + mov(aux1_reg_ddst, aux_reg_ddst); + mov(aux1_reg_kernel, aux_reg_kernel); + + mov(iter_kw, reg_kw); + Label kw_label; + L(kw_label); { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int ker_off = ch*kh*kw*ch_blk + i*4; + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux1_reg_kernel + + ker_off*sizeof(float)]); + + for (int w = 0; w < ur_str_w; w++) { + int ddst_off = (ch*oh*ow + w)*ch_blk + i*4; + + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux1_reg_ddst + + ddst_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + + add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float)); + sub(aux1_reg_ddst, ch_blk*sizeof(float)); + + sub(iter_kw, stride_w); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + + add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float)); + sub(aux_reg_ddst, ow*ch_blk*sizeof(float)); + + sub(iter_kh, stride_h); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::store_dsrc( + int ur_ch_blocks, int ur_str_w) { + int ch_blk = jcp.ch_block; + int iw = jcp.iw; + int ih = jcp.ih; + int stride_w = jcp.stride_w; + + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4; + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + + uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc); + } + } + } +} + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::loop_body( + int ur_ch_blocks) { + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + cmp(reg_ur_str_w, ur_w); + jl(tail_w_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + cmp(reg_ur_str_w, ur_w); + jl(exit_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +template +void jit_uni_dw_conv_bwd_data_kernel_f32::generate() { + preamble(); + + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); + mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); + mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); + mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]); + + Label ch_blocks_tail_label; + Label exit_label; + + int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; + + cmp(reg_ch_blocks, jcp.nb_ch_blocking); + jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR); + + loop_body(jcp.nb_ch_blocking); // channel main loop + + if (ch_blocks_tail) { + L(ch_blocks_tail_label); + + cmp(reg_ch_blocks, ch_blocks_tail); + jne(exit_label, T_NEAR); + + loop_body(ch_blocks_tail); // channel tail loop + } + + L(exit_label); + + this->postamble(); +} + +template +status_t jit_uni_dw_conv_bwd_data_kernel_f32::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) { + if (!mayiuse(isa)) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 16 : 8; + + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + if (!with_groups) return status::unimplemented; + + jcp.ngroups = weights_d.dims()[0]; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1]; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1]; + + jcp.ih = diff_src_d.dims()[2]; + jcp.iw = diff_src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + + jcp.kh = weights_d.dims()[3]; + jcp.kw = weights_d.dims()[4]; + + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.b_pad = cd.padding[1][0]; + jcp.r_pad = cd.padding[1][1]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + + bool ok_to_pad_channels = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && one_of(isa, avx512_common, avx2); + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.oc, simd_w); + jcp.ngroups = rnd_up(jcp.ngroups, simd_w); + } + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && jcp.ngroups % simd_w == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1 + && jcp.ic <= diff_src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ngroups <= weights_d.padded_dims()[0]; + if (!args_ok) return status::unimplemented; + + jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; + + jcp.ch_block = simd_w; + jcp.nb_ch = jcp.ic / jcp.ch_block; + jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2; + if (jcp.nb_ch < jcp.nb_ch_blocking) + jcp.nb_ch_blocking = jcp.nb_ch; + + return status::success; +} + +template +void jit_uni_dw_conv_bwd_data_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +template struct jit_uni_dw_conv_bwd_data_kernel_f32; +template struct jit_uni_dw_conv_bwd_data_kernel_f32; +template struct jit_uni_dw_conv_bwd_data_kernel_f32; + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::zero_filter() { + for (int r = 0; r < reg_repeats; ++r) { + for (int i = 0; i < jcp.kw; ++i) { + Vmm vmm_acc = get_acc_reg(r * jcp.kw + i); + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::load_filter() { + for (int r = 0; r < reg_repeats; ++r) { + const int reg_set = r * jcp.kw; + for (int i = 0; i < jcp.kw; ++i) { + int off_filter = (reg_set + i) * simd_w; + Vmm vmm_acc = get_acc_reg(reg_set + i); + uni_vmovups(vmm_acc, + vmmword[reg_tmp_filter + off_filter * sizeof(float)]); + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::zero_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vpxor(vmm_bias, vmm_bias, vmm_bias); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::load_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vmovups( + vmm_bias, vmmword[reg_bias_baddr + r * simd_w * sizeof(float)]); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_ow_step_unroll( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const int iw_block = ow_block * jcp.stride_w; + const int right_border = jcp.iw - iw_block; + + const int cascade_input = nstl::min(jcp.stride_w, jcp.kw); + + /* preamble count for number of cascaded LOAD + FMA operation */ + const int input_overlap = nstl::max(jcp.kw - l_pad, 0); + + /* LOAD initial input registers, then cascade LOADs and FMAs*/ + for (int r = 0; r < reg_repeats; ++r) { + for (int i_ur = 0; i_ur < unroll_w; ++i_ur) { + int off_output = (i_ur * reg_repeats + r) * simd_w; + Vmm vmm_output = get_output_reg(r); + uni_vmovups(vmm_output, + ptr[reg_tmp_output + off_output * sizeof(float)]); + if (i_ur == 0) { + for (int c = 0; c < input_overlap; ++c) { + int off_input + = ((c - pad_offset) * reg_repeats + r) * simd_w; + Vmm vmm_input + = get_input_reg((c % jcp.kw) * reg_repeats + r); + uni_vmovups(vmm_input, + ptr[reg_tmp_input + off_input * sizeof(float)]); + } + } else { + for (int c = 0; c < cascade_input; ++c) { + int overlap = (i_ur - 1) * jcp.stride_w + input_overlap; + int off_input + = ((overlap + c - pad_offset) * reg_repeats + r) + * simd_w; + Vmm vmm_input = get_input_reg( + ((overlap + c) % jcp.kw) * reg_repeats + r); + uni_vmovups(vmm_input, + ptr[reg_tmp_input + off_input * sizeof(float)]); + } + } + + for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) { + int io_overlap = i_kw + (i_ur * jcp.stride_w); + + /* Don't apply FMAs that fall into the padded region */ + if (io_overlap - l_pad < 0 + || io_overlap - jcp.l_pad >= right_border) + continue; + + Vmm vmm_input = get_input_reg( + ((io_overlap - l_pad) % jcp.kw) * reg_repeats + r); + Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r); + Vmm vmm_aux = isa == sse42 ? get_aux_reg() : vmm_input; + if (isa == sse42) + uni_vmovups(vmm_aux, vmm_input); + uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output); + } + } + } +} + +template +inline void +jit_uni_dw_conv_bwd_weights_kernel_f32::compute_bias_step_unroll( + const int unroll_w) { + for (int r = 0; r < reg_repeats; ++r) { + for (int i = 0; i < unroll_w; ++i) { + Vmm vmm_bias = get_bias_reg(r); + int off_output = (i * reg_repeats + r) * simd_w; + if (isa == sse42) { + /* Need to support unaligned address loads for SSE42*/ + Vmm vmm_output = get_output_reg(1 + r); + uni_vmovups(vmm_output, + ptr[reg_tmp_output + off_output * sizeof(float)]); + uni_vaddps(vmm_bias, vmm_bias, vmm_output); + } else { + uni_vaddps(vmm_bias, vmm_bias, + vmmword[reg_tmp_output + off_output * sizeof(float)]); + } + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::store_filter() { + for (int r = 0; r < reg_repeats; ++r) { + const int reg_set = r * jcp.kw; + for (int i = 0; i < jcp.kw; ++i) { + int off_filter = (i + reg_set) * simd_w; + Vmm vmm_acc = get_acc_reg(i + reg_set); + uni_vmovups(vmmword[reg_tmp_filter + off_filter * sizeof(float)], + vmm_acc); + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::store_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vmovups( + vmmword[reg_bias_baddr + r * simd_w * sizeof(float)], vmm_bias); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_bias_loop( + const int block_size) { + Label oh_label; + Label ow_blk_label; + + const int unroll_w = nstl::min(block_size, jcp.ow); + const int unroll_w_trips = jcp.ow / unroll_w; + const int tail_w = jcp.ow > block_size ? jcp.ow % block_size : 0; + + const int ch_offset = jcp.ch_block; + + mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]); + mov(reg_oh_worksize, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]); + + mov(reg_tmp_output, reg_output_baddr); + L(oh_label); + { + + mov(iter_ow_blk, unroll_w_trips); + L(ow_blk_label); + { + + compute_bias_step_unroll(unroll_w); + add(reg_tmp_output, unroll_w * ch_offset * sizeof(float)); + + dec(iter_ow_blk); + cmp(iter_ow_blk, 0); + jg(ow_blk_label, T_NEAR); + } + + if (tail_w > 0) { + compute_bias_step_unroll(tail_w); + add(reg_tmp_output, tail_w * ch_offset * sizeof(float)); + } + + inc(reg_oh); + cmp(reg_oh, reg_oh_worksize); + jl(oh_label, T_NEAR); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_zero_filter() { + + const int ch_offset = jcp.ch_block; + + Label kh_loop_label, skip_zeroing_label; + + mov(reg_exec_flags, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]); + and_(reg_exec_flags, FLAG_ZERO_FILTER); + test(reg_exec_flags, reg_exec_flags); + je(skip_zeroing_label); + + zero_filter(); + + mov(reg_tmp_filter, reg_filter_baddr); + mov(reg_kh, jcp.kh); + L(kh_loop_label); + { + store_filter(); + + add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_loop_label); + } + + /* Comeback pointers */ + sub(reg_tmp_filter, jcp.kh * jcp.kw * ch_offset * sizeof(float)); + + L(skip_zeroing_label); +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_h_step( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const int ch_offset = jcp.ch_block; + + Label kh_loop_label, skip_loop_label; + + cmp(reg_kh_count, 0); + je(skip_loop_label, T_NEAR); + + mov(reg_kh, reg_kh_count); + L(kh_loop_label); + { + load_filter(); + compute_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block); + store_filter(); + + add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + add(reg_tmp_input, jcp.iw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_loop_label); + } + + /* Comeback pointers */ + Label kh_comeback_label; + mov(reg_kh, reg_kh_count); + L(kh_comeback_label); + { + sub(reg_tmp_input, jcp.iw * ch_offset * sizeof(float)); + sub(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_comeback_label, T_NEAR); + } + + L(skip_loop_label); +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_h_loop( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const size_t io_overlap = jcp.ih / jcp.stride_h < jcp.oh ? + jcp.ih / jcp.stride_h - 1 : + jcp.oh - jcp.b_pad - 1; + const int ch_offset = jcp.ch_block; + const int t_overlap_off = jcp.t_pad % jcp.stride_h == 0 ? jcp.stride_h : 1; + const int b_overlap_off = jcp.b_pad % jcp.stride_h == 0 ? jcp.stride_h : 1; + + Label tpad_loop_label, h_loop_label, skip_tpad_label, skip_bpad_label, + end_h_loop_label; + + mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]); + mov(reg_oh_worksize, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]); + mov(reg_kh_count, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, kh_count)]); + + mov(reg_tmp_output, reg_output_baddr); + mov(reg_tmp_input, reg_input_baddr); + mov(reg_tmp_filter, reg_filter_baddr); + + L(h_loop_label); + { + + compute_h_step(unroll_w, l_pad, pad_offset, ow_block); + + add(reg_tmp_output, jcp.ow * ch_offset * sizeof(float)); + + /* If within the top_pad region */ + if (jcp.t_pad > 0) { + /* Skip t_pad area if no longer in initial h_block */ + cmp(reg_oh, jcp.t_pad); + jg(skip_tpad_label, T_NEAR); + + cmp(reg_kh_count, jcp.kh); + jge(skip_tpad_label, T_NEAR); + + add(reg_kh_count, t_overlap_off); + sub(reg_tmp_filter, + t_overlap_off * jcp.kw * ch_offset * sizeof(float)); + + /* kernel has moved beyond padding (adjust for stride effects) */ + if (jcp.t_pad % jcp.stride_h != 0) { + int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h; + add(reg_tmp_input, + inp_corr * jcp.iw * ch_offset * sizeof(float)); + } + jmp(tpad_loop_label, T_NEAR); + } + + L(skip_tpad_label); + + cmp(reg_oh, io_overlap); + jl(skip_bpad_label, T_NEAR); + sub(reg_kh_count, b_overlap_off); + + L(skip_bpad_label); + add(reg_tmp_input, jcp.stride_h * jcp.iw * ch_offset * sizeof(float)); + + L(tpad_loop_label); + + cmp(reg_oh, jcp.ih / jcp.stride_h); + jge(end_h_loop_label, T_NEAR); + + inc(reg_oh); + + cmp(reg_oh, reg_oh_worksize); + jl(h_loop_label, T_NEAR); + } + L(end_h_loop_label); +} + +template +inline void +jit_uni_dw_conv_bwd_weights_kernel_f32::compute_ow_block_unroll() { + + const int ch_offset = jcp.ch_block; + int ow = jcp.ow; + int pad_offset = 0; + int l_pad = jcp.l_pad; + + /* Calculate effective padding */ + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + + /* Is this strictly defined by: + * -code-size (?) + * -address size (?) */ + const int max_unroll_w = 30; + const int block_size = 15; + + int unroll_w_tail = 0; + int unroll_w = 0; + int unroll_w_trips = 0; + + if (jcp.ow > max_unroll_w) { + unroll_w = nstl::min(block_size, jcp.ow); + unroll_w_trips = ow / unroll_w; + /* calculate tail */ + unroll_w_tail = ow % unroll_w; + /* Perform some rebalancing if tail too small*/ + if ((unroll_w_tail == 0 && r_pad != 0) + || (r_pad > 0 && r_pad >= unroll_w_tail)) { + if (unroll_w_trips > 1) { + unroll_w_tail += unroll_w; + unroll_w_trips--; + } else { + /* Idealy, this case shouldn't happen */ + unroll_w_tail += (unroll_w - unroll_w / 2); + unroll_w = unroll_w / 2; + } + } + } else { + unroll_w = jcp.ow; + unroll_w_trips = nstl::max(1, ow / unroll_w); + } + if (jcp.with_bias) { + Label skip_load_bias; + mov(reg_bias_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]); + + zero_bias(); + + mov(reg_exec_flags, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]); + and_(reg_exec_flags, FLAG_ZERO_BIAS); + test(reg_exec_flags, reg_exec_flags); + jne(skip_load_bias); + + load_bias(); + + L(skip_load_bias); + compute_bias_loop(block_size); + + store_bias(); + } + + /* Pass filter address, then offset for h_padding. */ + compute_zero_filter(); + mov(reg_kh_offset, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter_pad_off)]); + add(reg_filter_baddr, reg_kh_offset); + + /* compute left padded block */ + if (l_pad) { + compute_h_loop(unroll_w, l_pad, 0, 0); + add(reg_output_baddr, unroll_w * ch_offset * sizeof(float)); + add(reg_input_baddr, + unroll_w * jcp.stride_w * ch_offset * sizeof(float)); + unroll_w_trips--; + pad_offset = l_pad; + l_pad = 0; + } + + /* compute middle block */ + Label ow_blk_label; + + /* Insert loop for 'ow' block when middle block needs to execute more + * than once */ + bool do_ow_blk_loop = unroll_w_trips > 1; + if (do_ow_blk_loop) { + mov(iter_ow_blk, unroll_w_trips); + L(ow_blk_label); + } + if (unroll_w_trips > 0) { + compute_h_loop(unroll_w, l_pad, pad_offset, 0); + add(reg_output_baddr, unroll_w * ch_offset * sizeof(float)); + add(reg_input_baddr, + unroll_w * jcp.stride_w * ch_offset * sizeof(float)); + } + if (do_ow_blk_loop) { + dec(iter_ow_blk); + cmp(iter_ow_blk, 0); + jg(ow_blk_label, T_NEAR); + } + + /* compute right padded block */ + if (unroll_w_tail) { + compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail); + } +} + +template +void jit_uni_dw_conv_bwd_weights_kernel_f32::generate() { + preamble(); + + mov(reg_input_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, input)]); + mov(reg_output_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, output)]); + mov(reg_filter_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter)]); + + compute_ow_block_unroll(); + + this->postamble(); +} + +template +status_t jit_uni_dw_conv_bwd_weights_kernel_f32::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d, int nthreads) { + if (!mayiuse(isa)) + return status::unimplemented; + + jcp.ngroups = diff_weights_d.dims()[0]; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + + jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic); + + if (!jcp.is_depthwise) + return status::unimplemented; + + jcp.ch_block = isa == avx512_common ? 16 : 8; + + jcp.mb = src_d.dims()[0]; + + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + + jcp.kh = diff_weights_d.dims()[3]; + jcp.kw = diff_weights_d.dims()[4]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0 + && jcp.dilate_w == 0 && jcp.kw <= 3 + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1; + if (!args_ok) + return status::unimplemented; + + jcp.nb_ch = jcp.ngroups / jcp.ch_block; + + /* kernel applicability check wrt boundaries + * the conditions are quite general across the kernels we have, + * but ideally the check should belong to a specific kernel... */ + const int max_hpad = (jcp.kh - 1 + 1) / 2; + const int max_wpad = (jcp.kw - 1 + 1) / 2; + const bool boundaries_ok = true && jcp.t_pad <= max_hpad + && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad + && jcp.r_pad <= max_wpad; + if (!boundaries_ok) + return status::unimplemented; + + balance(jcp, nthreads); + + return status::success; +} + +template +void jit_uni_dw_conv_bwd_weights_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + /* Notes: if splitting thread work on 'mb', then a reduction has to take + * place. Hence, book a per-thread, local weights-buffer for the + * reduction */ + if (jcp.nthr_mb > 1) { + const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw; + scratchpad.book(key_conv_wei_reduction, + sizeof(float) * wei_size * (jcp.nthr_mb - 1)); + + if (jcp.with_bias) + scratchpad.book(key_conv_bia_reduction, + sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1)); + } +} + +template +void jit_uni_dw_conv_bwd_weights_kernel_f32::balance(jit_conv_conf_t &jcp, + int nthreads) { + jcp.nthr = nthreads; + jcp.nthr_g = jcp.nthr_mb = 1; + + /* Basic-Heuristics for parallel strategy: + * 1) Tries to parallel on the number of Groups (g) where tasks are + * independent. Otherwise, + * 2) Tries to split the work across g and MiniBatch (mb). + * Parallelizing on mb requires computing a reduction for weights. + * + * NOTE: because of 'task partitioning' scheme, there will be unbalanced + * per-thread load when the number of threads is high (e.g. > 16). + */ + jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr); + jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb); + + jcp.nthr = jcp.nthr_g * jcp.nthr_mb; +} + +template struct jit_uni_dw_conv_bwd_weights_kernel_f32; +template struct jit_uni_dw_conv_bwd_weights_kernel_f32; +template struct jit_uni_dw_conv_bwd_weights_kernel_f32; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp new file mode 100644 index 0000000000..9c08fc4a09 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp @@ -0,0 +1,253 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_DW_CONV_KERNEL_F32_HPP +#define JIT_UNI_DW_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_dw_conv_fwd_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32) + + jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp) + : jcp(ajcp), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_uni_dw_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const Xbyak::Reg64; + const Xbyak::AddressFrame &vmmword = (isa == sse42) + ? xword : (isa == avx2) ? yword : zword; + const int vlen = cpu_isa_traits::vlen; + + // dw convolution + reg64_t reg_input = r8; + reg64_t aux_reg_input = r9; + reg64_t aux1_reg_input = r10; + reg64_t reg_kernel = r11; + reg64_t aux_reg_kernel = r12; + reg64_t aux1_reg_kernel = r13; + reg64_t reg_output = r14; + reg64_t reg_bias = r15; + reg64_t reg_kh = rax; + reg64_t reg_kw = rbx; + reg64_t iter_kh = rdx; + reg64_t iter_kw = rsi; + reg64_t reg_ur_w = rbp; + reg64_t reg_ch_blocks = aux1_reg_input; + reg64_t imm_addr64 = aux1_reg_input; + + inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } + inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } + + inline void load_src(int ur_ch_blocks, int ur_w); + inline void apply_filter(int ur_ch_blocks, int ur_w); + inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w); + inline void apply_activation(int ur_ch_blocks, int ur_w); + inline void store_dst(int ur_ch_blocks, int ur_w); + inline void loop_body(int ur_ch_blocks); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + void generate(); +}; + +template +struct jit_uni_dw_conv_bwd_data_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32) + + jit_uni_dw_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const Xbyak::Reg64; + + inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } + inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } + + reg64_t reg_ddst = rax; + reg64_t aux_reg_ddst = r8; + reg64_t aux1_reg_ddst = abi_not_param1; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r10; + reg64_t aux1_reg_kernel = rbp; + reg64_t reg_dsrc = rsi; + + reg64_t reg_ur_str_w = r9; + reg64_t reg_ch_blocks = rbx; + + reg64_t iter_kh = r11; + reg64_t iter_kw = r12; + reg64_t reg_kh = r13; + reg64_t reg_kw = r14; + + inline void loop_body(int ur_ch_blocks); + inline void load_ddst(int ur_ch_blocks, int ur_str_w); + inline void apply_filter(int ur_ch_blocks, int ur_str_w); + inline void store_dsrc(int ur_ch_blocks, int ur_str_w); + + void generate(); +}; + +template +struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator { + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32) + + jit_uni_dw_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) { + this->generate(); + jit_ker = (void (*)(jit_dw_conv_call_s *)) this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d, int nthreads); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + static void balance(jit_conv_conf_t &jcp, int nthreads); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_dw_conv_call_s *); + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const Xbyak::Reg64; + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int reg_repeats = (isa == sse42) ? 2 : 1; + + const Xbyak::AddressFrame &vmmword + = (isa == sse42) ? xword : (isa == avx2) ? yword : zword; + + /* XXX: offset between input and accummulators is 3, therefore, assume 'kw' + * is no larger than 3*/ + inline Vmm get_bias_reg(int idx = 0) { return Vmm(idx); } + inline Vmm get_output_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_input_reg(int idx) { return Vmm(idx + 4 * reg_repeats + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 1 * reg_repeats + 1); } + inline Vmm get_aux_reg() { return Vmm(0); } + + reg64_t reg_tmp_input = r9; + reg64_t reg_tmp_output = r10; + reg64_t reg_tmp_filter = r13; + reg64_t reg_kh_offset = rax; + + /* parameter passed by driver into kernel */ + Xbyak::Reg8 reg_exec_flags = bl; + + reg64_t reg_oh_worksize = r14; + reg64_t reg_oh = rax; + + reg64_t iter_ow_blk = r11; + + reg64_t reg_kh = rsi; + reg64_t reg_kh_count = rdx; + + /* Base addresses for convolution parameters. */ + reg64_t reg_input_baddr = r15; + reg64_t reg_output_baddr = r12; + reg64_t reg_filter_baddr = abi_not_param1; + reg64_t reg_bias_baddr = r13; + + /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs + */ + inline void compute_ow_step_unroll( + int unroll_w, int l_pad, int pad_offset, int ow_block); + + /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */ + inline void compute_h_step( + int unroll_w, int l_pad, int pad_offset, int ow_block); + inline void compute_h_loop( + int unroll_w, int l_pad, int pad_offset, int ow_block); + + /* Write 'width' micro-kernel JITs; depending on the padding and convolution + * size, write a micro-kernel for the left ow-block, middle ow-block(s), and + * right ow-block.*/ + inline void compute_ow_block_unroll(); + + inline void compute_zero_filter(); + inline void load_filter(); + inline void zero_filter(); + inline void load_bias(); + inline void zero_bias(); + inline void compute_bias_step_unroll(const int unroll_w); + inline void compute_bias_loop(const int block_size); + inline void store_filter(); + inline void store_bias(); + + void generate(); +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp new file mode 100644 index 0000000000..58449601a3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp @@ -0,0 +1,427 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "jit_uni_dw_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +template +void _jit_uni_dw_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + if (pd()->wants_padded_bias()) { + auto padded_bias = this->scratchpad(ctx).template get( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + int dil_h = jcp.dilate_h + 1; + int dil_w = jcp.dilate_w + 1; + int str_h = jcp.stride_h; + int str_w = jcp.stride_w; + + auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh, + int kh_padding, int ch, int ch_num, int n) { + auto par_conv = jit_conv_call_s(); + + const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w)); + const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w + + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw; + + const int iw = nstl::max((ow*str_w - jcp.l_pad + + div_up(i_l_overflow, dil_w)*dil_w), 0); + const int kw = div_up(i_l_overflow, dil_w); + + const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w) + - div_up(i_r_overflow, dil_w); + + par_conv.src = &src[src_d.blk_off(n, ch, ih, iw)]; + par_conv.dst = &dst[dst_d.blk_off(n, ch, oh, ow)]; + + par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)]; + if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block)]; + + par_conv.kh_padding = (size_t)nstl::max(0, kh_padding); + par_conv.kw_padding = (size_t)nstl::max(0, kw_padding); + + par_conv.ur_w = (size_t)ur_w_step; + + par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch; + + return par_conv; + }; + + const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); + parallel_nd(jcp.mb, chb_work, jcp.oh, + [&](int n, int chb, int oh) { + int ch = chb * jcp.nb_ch_blocking; + int ch_num = jcp.nb_ch_blocking; + + const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h)); + const int i_b_overflow = nstl::max(jcp.ih, + (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih; + + const int ih = nstl::max((int)(oh*str_h - jcp.t_pad + + div_up(i_t_overflow, dil_h)*dil_h), 0); + const int kh = div_up(i_t_overflow, dil_h); + const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h) + - div_up(i_b_overflow, dil_h); + + // left border + int ow = 0; + int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow); + int ur_w_step = 1; + for (; ow < l_border; ow++) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + + // main loop + ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1) + / jcp.stride_w - ow + 1; + if (ur_w_step > 0) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + + ow += ur_w_step; + } + + // right border + ur_w_step = 1; + for (; ow < jcp.ow; ow++) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +template struct _jit_uni_dw_convolution_fwd_t; +template struct _jit_uni_dw_convolution_fwd_t; +template struct _jit_uni_dw_convolution_fwd_t; + +template +void _jit_uni_dw_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih, + int i_t_overflow, int i_b_overflow, int stride_off_h, + int ch, int ch_num, int n) { + auto par_conv = jit_conv_call_s(); + + const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad)); + const int i_r_overflow = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw) + - jcp.r_pad)); + + int ow = iw + jcp.l_pad - i_r_overflow; + int stride_off_w = ow % jcp.stride_w; + ow /= jcp.stride_w; + + par_conv.src = &diff_src[diff_src_d.blk_off(n, ch, ih, iw)]; + par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, ch, oh, ow)]; + par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, i_b_overflow + + stride_off_h, i_r_overflow + stride_off_w)]; + + par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow + - stride_off_h); + par_conv.kw_padding = nstl::max(0, jcp.kw - i_l_overflow - i_r_overflow + - stride_off_w); + + par_conv.ur_str_w = ur_str_w; + + par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch; + + return par_conv; + }; + + const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); + parallel_nd(jcp.mb, chb_work, jcp.ih, + [&](int n, int chb, int ih) { + int ch = chb * jcp.nb_ch_blocking; + int ch_num = jcp.nb_ch_blocking; + + const int i_t_overflow = nstl::max(0, (int)(jcp.kh - 1 - ih + - jcp.t_pad)); + const int i_b_overflow = nstl::max(0, (int)(jcp.kh - 1 + - (jcp.ih - 1 - ih) - jcp.b_pad)); + + int oh = ih + jcp.t_pad - i_b_overflow; + int stride_off_h = oh % jcp.stride_h; + oh /= jcp.stride_h; + + for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) { + // left border + int iw = i_str_w; + int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw); + int ur_str_w = 1; + for (; iw < l_border; iw += jcp.stride_w) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + + // main loop + ur_str_w = nstl::min((jcp.iw - jcp.kw + jcp.r_pad - iw) + / jcp.stride_w, jcp.iw); + if (ur_str_w > 0) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + + iw += ur_str_w * jcp.stride_w; + } + + // right border + ur_str_w = 1; + for (; iw < jcp.iw; iw += jcp.stride_w) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + } + }); +} + +template struct _jit_uni_dw_convolution_bwd_data_t; +template struct _jit_uni_dw_convolution_bwd_data_t; +template struct _jit_uni_dw_convolution_bwd_data_t; + +template +_jit_uni_dw_convolution_bwd_weights_t:: +_jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), acc_ker_(nullptr) +{ + kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32(pd()->jcp_); + if (pd()->jcp_.nthr_mb > 1 && do_parallel_reduction()) + acc_ker_ = new cpu_accumulator_1d_t(); +} + +template +void _jit_uni_dw_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto diff_wei_reduction_buf = + scratchpad(ctx).template get(key_conv_wei_reduction); + auto diff_bia_reduction_buf = + scratchpad(ctx).template get(key_conv_bia_reduction); + + const auto &jcp = kernel_->jcp; + + /* Used when executing a parallel reduction */ + simple_barrier::ctx_t reduction_bctx; + simple_barrier::ctx_init(&reduction_bctx); + + const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw; + const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0; + + const int ch_block = jcp.ch_block; + + auto set_kernel_params = [&](jit_dw_conv_call_s *conv_params, + const int batch, const int group, const int oh_start, + const int work_size, const unsigned char exec_flag, + const size_t kh_padding, const size_t filter_off) { + const int tpad_underflow_off = jcp.t_pad - filter_off; + + conv_params->exec_flags = exec_flag; + conv_params->kh_count = jcp.kh - kh_padding; + + const int oh_s = oh_start; + const int oh_e = oh_start + work_size; + const int ih_s = oh_s * jcp.stride_h; + + conv_params->filter_pad_off + = filter_off * jcp.kw * ch_block * sizeof(float); + conv_params->oh_index = oh_s; + conv_params->oh_count = oh_e; + + size_t diff_dst_off + = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh + + oh_start) + * jcp.ow; + + size_t src_off = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih + + ih_s - tpad_underflow_off) * jcp.iw; + + conv_params->output = &diff_dst[diff_dst_off * ch_block]; + conv_params->input = &src[src_off * ch_block]; + }; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + assert(nthr == jcp.nthr); + + auto conv_params = jit_dw_conv_call_s(); + const int h_block_size = 15; + + /* assign iteration space to thread */ + const int ithr_g = ithr % jcp.nthr_g; + const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb; + + /* split dimensions */ + int g_start{ 0 }, g_end{ 0 }; + balance211(jcp.nb_ch, jcp.nthr_g, ithr_g, g_start, g_end); + + int mb_start{ 0 }, mb_end{ 0 }; + balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end); + + auto diff_wei = ithr_mb == 0 + ? diff_weights : diff_wei_reduction_buf + (ithr_mb - 1) * wei_size; + auto diff_bia = ithr_mb == 0 + ? diff_bias : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size; + + for (int g = g_start; g < g_end; ++g) { + unsigned char zero_filter_flag = FLAG_ZERO_FILTER; + unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0; + + size_t diff_wei_off = g * jcp.kh * jcp.kw; + conv_params.filter = &diff_wei[diff_wei_off * ch_block]; + + if (jcp.with_bias) + conv_params.bias = &diff_bia[g * ch_block]; + + for (int mb = mb_start; mb < mb_end; ++mb) { + int oh = 0; + while (oh < jcp.oh) { + const int h_work = nstl::min(h_block_size, jcp.oh - oh); + auto kh_t_padding = nstl::max(0, jcp.t_pad - oh); + auto kh_b_padding + = (oh * jcp.stride_h + jcp.kh - 1 > jcp.ih) ? + jcp.b_pad - (h_work - 1) : + 0; + + set_kernel_params(&conv_params, mb, g, oh, h_work, + zero_filter_flag | zero_bias_flag, + kh_t_padding + kh_b_padding, kh_t_padding); + kernel_->jit_ker(&conv_params); + + zero_bias_flag &= ~FLAG_ZERO_BIAS; + zero_filter_flag &= ~FLAG_ZERO_FILTER; + oh += h_work; + } + } + } + + if (do_parallel_reduction() && jcp.nthr_mb > 1) { + size_t reduct_start{ 0 }, reduct_end{ 0 }; + balance211(wei_size, nthr, ithr, reduct_start, reduct_end); + + const int acc_size = reduct_end - reduct_start; + const size_t reduct_off = reduct_start; + auto *acc_data = diff_weights + reduct_off; + + simple_barrier::barrier(&reduction_bctx, nthr); + + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + auto *src_data = diff_wei_reduction_buf + + (thr_mb - 1) * wei_size + reduct_off; + acc_ker_->accumulate(acc_data, src_data, acc_size); + } + } + }); + + if (jcp.nthr_mb <= 1) return; + + /* Apply single-threaded 'mb' reduction */ + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + size_t mb_accum_offset = (thr_mb - 1) * wei_size; + size_t b_accum_offset = (thr_mb - 1) * bias_size; + + for (int g = 0; g < jcp.nb_ch; ++g) { + /* Reduction on Bias */ + if (jcp.with_bias) { + PRAGMA_OMP_SIMD() + for (int g_block = 0; g_block < ch_block; ++g_block) { + size_t bias_offset = g * ch_block + g_block; + diff_bias[bias_offset] += diff_bia_reduction_buf[ + b_accum_offset + bias_offset]; + } + } + + if (do_parallel_reduction()) continue; + + for (int kh = 0; kh < jcp.kh; ++kh) + for (int kw = 0; kw < jcp.kw; ++kw) + { + size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw; + PRAGMA_OMP_SIMD() + for (int g_block = 0; g_block < ch_block; ++g_block) { + const size_t off = wei_offset * ch_block + g_block; + diff_weights[off] += + diff_wei_reduction_buf[mb_accum_offset + off]; + } + } + } + } +} + +template struct _jit_uni_dw_convolution_bwd_weights_t; +template struct _jit_uni_dw_convolution_bwd_weights_t; +template struct _jit_uni_dw_convolution_bwd_weights_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp new file mode 100644 index 0000000000..ca53749ec2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp @@ -0,0 +1,266 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_DW_CONVOLUTION_HPP +#define CPU_JIT_UNI_DW_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_uni_dw_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _jit_uni_dw_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_uni_dw_conv_fwd_kernel_f32::init_conf( + jcp_, *desc(), src_md(), *weights_md(), *dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_fwd_kernel_f32::init_scratchpad(scratchpad, + jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_dw_conv_fwd_kernel_f32(pd()->jcp_); } + + ~_jit_uni_dw_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_fwd_kernel_f32 *kernel_; +}; + +using jit_avx512_common_dw_convolution_fwd_t = + _jit_uni_dw_convolution_fwd_t; +using jit_avx2_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t; +using jit_sse42_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t; + +template +struct _jit_uni_dw_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + + if (!ok) return status::unimplemented; + + status_t status = jit_uni_dw_conv_bwd_data_kernel_f32:: + init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_bwd_data_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_dw_conv_bwd_data_kernel_f32(pd()->jcp_); } + ~_jit_uni_dw_convolution_bwd_data_t() { delete kernel_; }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_bwd_data_kernel_f32 *kernel_; +}; + +using jit_avx512_common_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t; +using jit_avx2_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t; +using jit_sse42_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t; + +template +struct _jit_uni_dw_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const int max_threads = mkldnn_in_parallel() + ? 1 : mkldnn_get_max_threads(); + + status_t status = jit_uni_dw_conv_bwd_weights_kernel_f32:: + init_conf(jcp_, *desc(), *src_md(), *diff_weights_md(), + *diff_dst_md(), max_threads); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_bwd_weights_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd); + ~_jit_uni_dw_convolution_bwd_weights_t() { + delete kernel_; + delete acc_ker_; + }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + bool do_parallel_reduction() const { return false; } + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_bwd_weights_kernel_f32 *kernel_; + cpu_accumulator_1d_t *acc_ker_; +}; + +using jit_avx512_common_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t; +using jit_avx2_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t; +using jit_sse42_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp new file mode 100644 index 0000000000..2af6435871 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp @@ -0,0 +1,1142 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_uni_eltwise.hpp" + +#define GET_OFF(field) offsetof(jit_args, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +template +void jit_uni_eltwise_injector_f32::injector_preamble(size_t start_idx, + size_t end_idx) { + preserved_vecs_count = 0; + vecs_to_preserve = (size_t)aux_vecs_count(alg_); + start_idx_tail = start_idx; + + // For sse42 mask register has to be Xmm(0) + if (isa == sse42 && vecs_to_preserve > 0) { + size_t idx = 0; + assert(idx < start_idx); + preserved_vec_idxs[preserved_vecs_count++] = idx; + } + + for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) { + if (preserved_vecs_count >= vecs_to_preserve) break; + if (start_idx <= idx && idx < end_idx) continue; + + preserved_vec_idxs[preserved_vecs_count++] = idx; + } + + size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count; + for (size_t i = 0; i < preserved_vecs_count_tail; i++) { + preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++; + } + + assert(preserved_vecs_count == vecs_to_preserve); + + if (save_state_) { + h->push(p_table); + + if (preserved_vecs_count) + h->sub(h->rsp, preserved_vecs_count * vlen); + + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], + Vmm(preserved_vec_idxs[i])); + + load_table_addr(); + } + + assign_regs(); +} + +template +void jit_uni_eltwise_injector_f32::injector_preamble_tail(size_t start_idx) +{ + size_t tail_vecs_to_preserve = start_idx_tail - start_idx; + if (tail_vecs_to_preserve == 0) return; + + const int idx_off = vecs_to_preserve - tail_vecs_to_preserve; + + if (save_state_) { + if (idx_off) + h->add(h->rsp, idx_off * vlen); + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), + h->ptr[h->rsp + i * vlen]); + } + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve; + + if (save_state_) { + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], + Vmm(preserved_vec_idxs[idx_off + i])); + + if (idx_off) + h->sub(h->rsp, idx_off * vlen); + } + + assign_regs(); +} + +template +void jit_uni_eltwise_injector_f32::injector_postamble() { + if (!save_state_) return; + + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[i]), + h->ptr[h->rsp + i * vlen]); + + if (preserved_vecs_count) + h->add(h->rsp, preserved_vecs_count * vlen); + + h->pop(p_table); +} + +template +void jit_uni_eltwise_injector_f32::assign_regs() { + vmm_mask = Vmm(preserved_vec_idxs[0]); + vmm_aux0 = Vmm(preserved_vec_idxs[0]); + vmm_aux1 = Vmm(preserved_vec_idxs[1]); + vmm_aux2 = Vmm(preserved_vec_idxs[2]); + vmm_aux3 = Vmm(preserved_vec_idxs[3]); + vmm_aux4 = Vmm(preserved_vec_idxs[4]); +} + +template +void jit_uni_eltwise_injector_f32::exp_compute_vector(const Vmm &vmm_src) { + h->uni_vminps(vmm_src, vmm_src, table_val(10)); + h->uni_vmaxps(vmm_src, vmm_src, table_val(11)); + h->uni_vmovups(vmm_aux0, vmm_src); + //calculate exp(x) + // fx = x * log2ef + 0.5 + h->uni_vmulps(vmm_src, vmm_src, table_val(2)); + h->uni_vaddps(vmm_src, vmm_src, table_val(1)); + + // tmp = floorf(fx) + if (isa == avx512_common) { + h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src); + h->vcvtdq2ps(vmm_aux1, vmm_aux1); + + h->vcmpps(k_mask, vmm_aux1, vmm_src, _cmp_nle_us); + h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0)); + + h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3); + } else { + h->uni_vroundps(vmm_aux1, vmm_src, _op_floor); + } + + //keep fx for further computations + h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx + + //x = x - fx * ln2 + h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3)); + + // compute 2^n + h->uni_vcvtps2dq(vmm_aux1, vmm_src); + h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4)); + h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx + + // y = p5 + h->uni_vmovups(vmm_src, table_val(9)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0)); + // y = y * x + p0 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5)); //exp(q) + // y = y * 2^n + h->uni_vmulps(vmm_src, vmm_src, vmm_aux1); +} + +template +void jit_uni_eltwise_injector_f32::relu_compute_vector(const Vmm &vmm_src) +{ + const int alpha_off = 0, zero_off = 1; + + h->uni_vmovups(vmm_aux1, vmm_src); + if (isa == sse42) { + h->movups(vmm_mask, vmm_src); + h->mulps(vmm_src, table_val(alpha_off)); + h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us); + h->blendvps(vmm_src, vmm_aux1); + } else if (isa == avx2) { + h->vmulps(vmm_src, vmm_src, table_val(alpha_off)); + h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off)); + h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask); + } else if (isa == avx512_common) { + h->vmulps(vmm_src, vmm_src, table_val(alpha_off)); + h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1); + } +} + +template +void jit_uni_eltwise_injector_f32::relu_zero_ns_compute_vector( + const Vmm &vmm_src) { + const int zero_off = 1; + h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off)); +} + +template +void jit_uni_eltwise_injector_f32::elu_compute_vector(const Vmm &vmm_src) { + const int alpha_off = 23, zero_off = 24; + + // compute exponent + h->uni_vmovups(vmm_aux2, vmm_src); + exp_compute_vector(vmm_src); + + // alpha * (exp(x) - 1) + h->uni_vsubps(vmm_src, vmm_src, table_val(0)); + h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off)); + + // combine with mask + if (isa == sse42) { + h->pxor(vmm_mask, vmm_mask); + h->cmpps(vmm_mask, vmm_aux2, _cmp_le_os); + h->blendvps(vmm_src, vmm_aux2); + } else if (isa == avx2) { + h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off)); + h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask); + } else if (isa == avx512_common) { + h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2); + } +} + +template +void jit_uni_eltwise_injector_f32::tanh_compute_vector(const Vmm &vmm_src) +{ + // # comes from Taylor expansion error bound + // > linear_sat_point = single(sqrt(3) * 1b-12); + // # comes from the exp formula cancellation + // > exp_bound_point = (single(log(3)/2)); + // # comes from rounding accuracy in float + // > one_sat_point = round(atanh(1 - 1b-25), single, RU); + // > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... |], + // [linear_sat_point, exp_bound_point], relative, floating); + // > err_bound = D(sup(supnorm(P, tanh(x), + // [linear_sat_point, exp_bound_point], relative, theta))); + // 0x1.fffd6f00b9539p-25 + // > P; + // x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 * + // (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5 + // + x^0x1p1 * 0x1.09fa1p-6)))) + + // register mapping + // vmm_src contains input + // vmm_aux0 contains mask of currently valid results. + // 1 is need computation, 0 is already computed + // vmm_aux1 contains current output + // vmm_aux2, vmm_aux3 contains auxiliary values + // vmm_aux4 contains the original sign of inputs + + Label end_tanh_label; + + auto test_exit =[&](Xbyak::Address threshold){ + // is not necessary for >AVX, but should not matter on perf + h->uni_vmovups(vmm_aux0, vmm_src); + if (isa == avx512_common){ + h->vcmpps(k_mask, vmm_aux0, threshold, 0x5); + h->kortestw(k_mask, k_mask); + } else { + h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold); + h->uni_vtestps(vmm_aux0, vmm_aux0); + } + h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR); + }; + + auto blend_results=[&](Vmm vmm_partial_res){ + if (isa == avx512_common) + h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res); + else + h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0); + }; + + // because tanh(x) = -tanh(-x), we extract sign to make x postive + // and reapply sign at the end + // mov is not necessary for >AVX, but should not matter for performance + h->uni_vmovups(vmm_aux4, vmm_src); + h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12)); + h->uni_vandps(vmm_src, vmm_src, table_val(17)); + + // if x < linear_sat_point for all inputs, we just return the input + h->uni_vmovups(vmm_aux1, vmm_src); + test_exit(table_val(13)); + + // if one of the mask is one, we have to compute an better approx + h->uni_vmovups(vmm_aux2, vmm_src); + h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2); + h->uni_vmovups(vmm_aux3, table_val(22)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18)); + h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src); + + // we blend only the result that need update + blend_results(vmm_aux3); + + // if x < exp_bound_point, we go to return point + test_exit(table_val(14)); + + // if not we use a better approx 1 - 2 / (1 + exp(2x)) + // compute 2x + h->uni_vmovups(vmm_aux3, vmm_src); + h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3); + + // Compute exp(2x) + // We need to save kmask, vmm_aux0, vmm_aux1 and vmm_src as exp can use them + // vmm_src is not more read afterwards, so we do not have to save it + auto stack_size = 3 * vlen + (isa == avx512_common) * 4; + h->sub(h->rsp, stack_size); + h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0); + h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1); + h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src); + if (isa == avx512_common) + h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask); + + exp_compute_vector(vmm_aux3); + + h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]); + h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]); + h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]); + if (isa == avx512_common) + h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]); + h->add(h->rsp, stack_size); + + // 1 + exp(2x) + h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0)); + + // 1 - 2 / (1 + exp(2x)) + h->uni_vmovups(vmm_aux2, table_val(16)); + h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3); + h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0)); + + // we blend only the result that need update + blend_results(vmm_aux2); + + // finally, we saturate to 1 if needed + // TODO: maybe move that up if most inputs saturate in practice + if (isa == avx512_common) + h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5); + else { + h->uni_vmovups(vmm_aux0, vmm_src); + h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15)); + } + h->uni_vmovups(vmm_aux2, table_val(0)); + blend_results(vmm_aux2); + + h->L(end_tanh_label); + { + // we apply the sign of x to the result and we are done + h->uni_vmovups(vmm_src, vmm_aux1); + h->uni_vpxor(vmm_src, vmm_src, vmm_aux4); + } +} + +template +void jit_uni_eltwise_injector_f32::square_compute_vector( + const Vmm &vmm_src) { + h->uni_vmulps(vmm_src, vmm_src, vmm_src); +} + +template +void jit_uni_eltwise_injector_f32::abs_compute_vector(const Vmm &vmm_src) { + // compute abs(x) = _mm_and_ps(x, 01111..111)); + h->uni_vandps(vmm_src, vmm_src, table_val(0)); +} + +template +void jit_uni_eltwise_injector_f32::sqrt_compute_vector(const Vmm &vmm_src) +{ + if (isa == avx512_common) { + h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us); + h->uni_vsqrtps(vmm_aux1, vmm_src); + h->uni_vmovups(vmm_src, table_val(0)); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1); + } else { + h->uni_vmovups(vmm_mask, vmm_src); + h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0)); + h->uni_vsqrtps(vmm_aux1, vmm_src); + h->uni_vmovups(vmm_src, table_val(0)); + h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask); + } +} + +template +void jit_uni_eltwise_injector_f32::linear_compute_vector( + const Vmm &vmm_src) { + // compute x = alpha * x + beta; + h->uni_vmovups(vmm_aux0, table_val(0)); + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1)); +} + +template +void jit_uni_eltwise_injector_f32::bounded_relu_compute_vector( + const Vmm &vmm_src) { + // compute bounded relu */ + h->uni_vmaxps(vmm_src, vmm_src, table_val(1)); + h->uni_vminps(vmm_src, vmm_src, table_val(0)); +} + +template +void jit_uni_eltwise_injector_f32::soft_relu_compute_vector( + const Vmm &vmm_src) { + // duplicate src + h->uni_vmovups(vmm_aux2, vmm_src); + + h->uni_vminps(vmm_src, vmm_src, table_val(24)); + h->uni_vmaxps(vmm_src, vmm_src, table_val(25)); + h->uni_vmovups(vmm_aux1, vmm_src); + // calculate exp(x) + // fx = x * log2ef + 0.5 + h->uni_vmulps(vmm_src, vmm_src, table_val(2)); + h->uni_vaddps(vmm_src, vmm_src, table_val(1)); + + // tmp = floorf(fx) + if (isa == avx512_common) { + h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src); + h->vcvtdq2ps(vmm_aux0, vmm_aux0); + + h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_nle_us); + h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0)); + + h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3); + } else { + h->uni_vroundps(vmm_aux0, vmm_src, _op_floor); + } + + // keep fx for further computations + h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx + // calculation fx * ln2 + h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3)); + // x = x - fx * ln2 + h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0); + // y = p5 + h->uni_vmovups(vmm_aux3, table_val(22)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0)); + // y = y * x + p0 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17)); + + // compute 2^(-n) + if (isa == avx512_common) { + h->vmulps(vmm_aux1, vmm_src, table_val(23)); + h->vcvtps2dq(vmm_aux1, vmm_aux1); + } else { + h->uni_vcvtps2dq(vmm_aux1, vmm_src); + h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23)); + } + + h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4)); + h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx + // calculate ln(1 + y) + h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1); + // x = y; y is free; keep x for further computations + h->uni_vmovups(vmm_src, vmm_aux3); + // frexp() + h->uni_vpsrld(vmm_src, vmm_src, 23); + h->uni_vcvtdq2ps(vmm_src, vmm_src); + // got n. where n is x = 2^n * y. y = 0.5 .. 1 + h->uni_vsubps(vmm_src, vmm_src, table_val(5)); + + h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6)); + // got y. (mantisa) 0.5 < y < 1 + h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7)); + // y = y - 1 + h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0)); + // y = p8 + h->uni_vmovups(vmm_aux1, table_val(16)); + // y = y * x + p7 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15)); + // y = y * x + p6 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14)); + // y = y * x + p5 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9)); + // y = y * x + p0 ; p0 = 0 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8)); + //calculate ln(2) * n + h->uni_vmulps(vmm_src, vmm_src, table_val(3)); + h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src); + h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0); + + // get vmm_mask = src > max logf + h->uni_vmovups(vmm_mask, vmm_aux2); + if (isa == avx512_common) { + // y = (x < max log f) ? soft_relu(x) : x + h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us); + h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2); + } else { + // y = (x < max log f) ? soft_relu(x) : x + h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24)); + h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask); + } + + h->uni_vmovups(vmm_src, vmm_aux1); +} + +template +void jit_uni_eltwise_injector_f32::logistic_compute_vector( + const Vmm &vmm_src) { + // we store the original sign and make x negative + // IMPORTANT: we assume vmm_aux0 to be xmm0, as for sse4.2 path it is required + // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it. + h->uni_vmovups(vmm_aux2, vmm_src); + h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12)); + h->uni_vorps(vmm_src, vmm_src, table_val(12)); + + exp_compute_vector(vmm_src); + // dup exp(x) + h->uni_vmovups(vmm_aux1, vmm_src); + // (exp(x) + 1) + h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0)); + // y = exp(x) / (exp(x) + 1) + h->uni_vdivps(vmm_src, vmm_src, vmm_aux1); + + // Now we have to apply the "symmetry" based on original sign + h->uni_vmovups(vmm_aux3, table_val(0)); + h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src); + if (isa == avx512_common) { + h->vptestmd(k_mask, vmm_aux2, vmm_aux2); + h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src); + } else { + h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2 + h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0); + } + h->uni_vmovups(vmm_src, vmm_aux3); +} + +template +void jit_uni_eltwise_injector_f32::relu_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +void jit_uni_eltwise_injector_f32::elu_prepare_table() { + const unsigned int cvals[] = { + 0x3f800000, // [0] 1.0f + 0x3f000000, // [1] 0.5f + 0x3fb8aa3b, // [2] log2ef = 1.44269502f + 0x3f317218, // [3] ln2f = 0.69314718f + 0x0000007f, // [4] 0x7f + // exp(x) polynom + 0x3f800001, // [5] p0 = 1.0000001f + 0x3efffe85, // [6] p2 = 0.4999887f + 0x3e2aaa3e, // [7] p3 = 0.16666505f + 0x3d2bb1b1, // [8] p4 = 0.041917507f + 0x3c091ec1, // [9] p5 = 0.008369149f + 0x42b0c0a5, //[10] max logf = 88.3762589f + 0xc1766666, //[11] min logf = -14.5f + // tanh(x) constants, + 0x80000000, //[12] mask to extract sign + 0x39ddb3d7, //[13] arg below which tanh(x) = x + 0x3f0c9f54, //[14] arg below which pol approx is valid + 0x41102cb4, //[15] arg after which tanh(x) = 1 + 0xc0000000, //[16] -2.0f + 0x7fffffff, //[17] mask to make positive + // tanh pol approx + 0x3f7fffff, //[18] p0 + 0xbeaaa9cf, //[19] p1 + 0x3e085f1f, //[20] p2 + 0xbd572bda, //[21] p3 + 0x3c84fd08, //[22] p4 + }; + + for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]); + } + + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +void jit_uni_eltwise_injector_f32::soft_relu_prepare_table() { + const unsigned int cvals[] = { + 0x3f800000, // [0] 1.0f + 0x3f000000, // [1] 0.5f + 0x3fb8aa3b, // [2] log2ef = 1.44269502f + 0x3f317218, // [3] ln2f = 0.69314718f + 0x0000007f, // [4] 0x7f + 0x42fc0000, // [5] 126 + 0x807fffff, // [6] and with (to get 0.5 * mantissa) + 0x3f000000, // [7] or with (to get 0.5 * mantissa) + // ln(1 + x) polynomial + 0xb2b4637d, // [8] p0 = 0.0000000244f + 0x3f7fff8e, // [9] p1 = 0.9999976971f + 0xbf001759, //[10] p2 = -0.5002478215f + 0x3ea70608, //[11] p3 = 0.3272714505f + 0xbea3d7bf, //[12] p4 = -0.3153830071f + 0xbe361d04, //[13] p5 = -0.1701777461f + 0xbfa8f1e6, //[14] p6 = -1.3254635147f + 0xbfe1e812, //[15] p7 = -1.7971917960f + 0xbfc4d30e, //[16] p8 = -1.5652673123f + // exp(x) polynomial + 0x3f800001, //[17] p0 = 1.0000001f + 0x3f800000, //[18] p1 = 1.0f + 0x3efffe85, //[19] p2 = 0.4999887f + 0x3e2aaa3e, //[20] p3 = 0.16666505f + 0x3d2bb1b1, //[21] p4 = 0.041917507f + 0x3c091ec1, //[22] p5 = 0.008369149f + 0xbf800000, //[23] is required for sign changing + 0x42b0c0a5, //[24] max logf = 88.3762589f + 0xc1766666 //[25] min logf = -14.5f + }; + + for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) { + for (size_t d = 0; d < vlen / sizeof(float); ++d) { + h->dd(cvals[i]); + } + } +} + +template +void jit_uni_eltwise_injector_f32::abs_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff); +} + +template +void jit_uni_eltwise_injector_f32::sqrt_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +void jit_uni_eltwise_injector_f32::linear_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_)); +} + +template +void jit_uni_eltwise_injector_f32::bounded_relu_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +int jit_uni_eltwise_injector_f32::aux_vecs_count(alg_kind_t alg_) { + switch (alg_) { + case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2; + case alg_kind::eltwise_elu: return 4; + case alg_kind::eltwise_tanh: return 5; + case alg_kind::eltwise_square: return 0; + case alg_kind::eltwise_abs: return 0; + case alg_kind::eltwise_sqrt: return 2; + case alg_kind::eltwise_linear: return 1; + case alg_kind::eltwise_bounded_relu: return 0; + case alg_kind::eltwise_soft_relu: return 4; + case alg_kind::eltwise_logistic: return 4; + default: assert(!"unsupported eltwise algorithm"); + } + + return 0; +} + +template +void jit_uni_eltwise_injector_f32::compute_body(size_t start_idx, + size_t end_idx) { + using namespace alg_kind; + for (size_t idx = start_idx; idx < end_idx; idx++) { + switch (alg_) { + case eltwise_relu: + if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx)); + else relu_compute_vector(Vmm(idx)); + break; + case eltwise_elu: elu_compute_vector(Vmm(idx)); break; + case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break; + case eltwise_square: square_compute_vector(Vmm(idx)); break; + case eltwise_abs: abs_compute_vector(Vmm(idx)); break; + case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break; + case eltwise_linear: linear_compute_vector(Vmm(idx)); break; + case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break; + case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break; + case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break; + default: assert(!"unsupported eltwise algorithm"); + } + } +} + +template +void jit_uni_eltwise_injector_f32::compute_vector_range(size_t start_idx, + size_t end_idx) { + assert(start_idx < end_idx && end_idx <= vecs_count); + + injector_preamble(start_idx, end_idx); + compute_body(start_idx_tail, end_idx); + injector_preamble_tail(start_idx); + compute_body(start_idx, start_idx_tail); + injector_postamble(); +} + +template +void jit_uni_eltwise_injector_f32::prepare_table(bool gen_table) { + using namespace alg_kind; + + h->align(64); + h->L(l_table); + + if (gen_table) { + switch (alg_) { + case eltwise_relu: relu_prepare_table(); break; + case eltwise_elu: + case eltwise_tanh: + case eltwise_logistic: + elu_prepare_table(); break; + case eltwise_soft_relu: soft_relu_prepare_table(); break; + case eltwise_abs: abs_prepare_table(); break; + case eltwise_sqrt: sqrt_prepare_table(); break; + case eltwise_linear: linear_prepare_table(); break; + case eltwise_bounded_relu: bounded_relu_prepare_table(); break; + case eltwise_square: break; + default: assert(!"unsupported eltwise algorithm"); + } + } +} + +template struct jit_uni_eltwise_injector_f32; +template struct jit_uni_eltwise_injector_f32; +template struct jit_uni_eltwise_injector_f32; + + +struct jit_args { + const float *from; + const float *for_comparison; + const float *to; + size_t work_amount; +}; + +struct jit_uni_eltwise_kernel_f32 : public c_compatible { + const eltwise_desc_t &desc_; + + void (*ker_)(const jit_args *); + void operator()(const jit_args *args) { assert(ker_); ker_(args); } + + jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc) + : desc_(desc), ker_(nullptr) {} + virtual ~jit_uni_eltwise_kernel_f32() {} + +protected: + bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; } +}; + +/* jit kernels */ +namespace { + +template +struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32, + public jit_generator +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32) + + void compute_step(bool vectorize, const int uf, const int shift) { + for (int i = 0; i < uf; i++) { + if (vectorize) { + uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]); + if (is_bwd()) + uni_vmovups(Vmm(uf + i + 1), + ptr[reg_for_comparison + i * shift]); + } else { + movss(Xmm(i + 1), ptr[reg_from + i * shift]); + if (is_bwd()) + movss(Xmm(uf + i + 1), + ptr[reg_for_comparison + i * shift]); + } + } + + if (isa == sse42) { + for (int i = 0; i < uf; i++) { + movups(Vmm(2 * uf + i + 1), Vmm(i + 1)); + mulps(Vmm(2 * uf + i + 1), vmm_ns); + + Vmm mask = Vmm(0); + if (is_bwd()) { + movups(mask, Vmm(uf + i + 1)); + cmpps(mask, vmm_zero, _cmp_nle_us); + } else { + movups(mask, Vmm(i + 1)); + cmpps(mask, vmm_zero, _cmp_nle_us); + } + blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1)); + } + } else { + for (int i = 0; i < uf; i++) { + vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns); + if (isa == avx2) { + if (is_bwd()) + vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero); + else + vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero); + + vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1), + Vmm(i + 1), vmm_mask); + + } else { + if (is_bwd()) + vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us); + else + vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us); + vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1), + Vmm(i + 1)); + } + } + } + + for (int i = 0; i < uf; i++) { + if (vectorize) { + uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1)); + } else { + movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1)); + } + } + } + + jit_uni_relu_kernel_f32(const eltwise_desc_t &desc) + : jit_uni_eltwise_kernel_f32(desc), jit_generator() { + assert(desc.alg_kind == alg_kind::eltwise_relu); + assert(isa == sse42 || isa == avx2 || isa == avx512_common); + + Reg64 param = abi_param1; + + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int loop_dec[] = {simd_w, 1}; + const int uf[] = {1, 1}; + const int shift[] = {cpu_isa_traits::vlen, sizeof(float)}; + const bool loop_vectorize[] = {true, false}; + + this->preamble(); + + mov(reg_from, ptr[param + GET_OFF(from)]); + if (is_bwd()) + mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]); + mov(reg_to, ptr[param + GET_OFF(to)]); + mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); + + mov(imm_addr64, float2int(desc.alpha)); + movq(xmm_ns, imm_addr64); + uni_vbroadcastss(vmm_ns, xmm_ns); + + uni_vpxor(vmm_zero, vmm_zero, vmm_zero); + + Label loop_label[3]; + + for (int id = 0; id < 2; id++) { + L(loop_label[id]); + cmp(reg_work_amount, uf[id] * loop_dec[id] - 1); + jle(loop_label[id + 1], T_NEAR); + + compute_step(loop_vectorize[id], uf[id], shift[id]); + + add(reg_from, uf[id] * shift[id]); + add(reg_to, uf[id] * shift[id]); + if (is_bwd()) + add(reg_for_comparison, uf[id] * shift[id]); + + sub(reg_work_amount, uf[id] * loop_dec[id]); + jmp(loop_label[id]); + } + + L(loop_label[2]); + this->postamble(); + + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using Vmm = typename utils::conditional3::type; + + Reg64 reg_from = rax; + Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from; + Reg64 reg_to = r8; + Reg64 reg_work_amount = rsi; + Reg64 imm_addr64 = rbx; + + Xmm xmm_ns = Xmm(14); + + Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14); + Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15); + + Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12); + Opmask k_mask = Opmask(1); +}; + +template +struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32, + public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32) + + jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc) + : jit_uni_eltwise_kernel_f32(desc), jit_generator() { + + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1)); + + using namespace alg_kind; + + assert(is_bwd() == false); + assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); + + preamble(); + + Reg64 param = abi_param1; + mov(reg_from, ptr[param + GET_OFF(from)]); + mov(reg_to, ptr[param + GET_OFF(to)]); + mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); + eltwise_injector_->load_table_addr(); + + Label reminder_loop_start, reminder_loop_end; + Label vectorized_loop_start, vectorized_loop_end; + + cmp(reg_work_amount, simd_w); + jl(reminder_loop_start, T_NEAR); + + L(vectorized_loop_start); + + uni_vmovups(vmm_src, ptr[reg_from]); + eltwise_injector_->compute_vector(vmm_src.getIdx()); + uni_vmovups(ptr[reg_to], vmm_src); + + add(reg_from, vlen); + add(reg_to, vlen); + + sub(reg_work_amount, simd_w); + cmp(reg_work_amount, simd_w); + jge(vectorized_loop_start, T_NEAR); + + L(vectorized_loop_end); + + L(reminder_loop_start); + + cmp(reg_work_amount, 0); + jle(reminder_loop_end, T_NEAR); + + movss(xmm_src, ptr[reg_from]); + eltwise_injector_->compute_vector(xmm_src.getIdx()); + movss(ptr[reg_to], xmm_src); + + add(reg_from, sizeof(float)); + add(reg_to, sizeof(float)); + + dec(reg_work_amount); + jmp(reminder_loop_start, T_NEAR); + + L(reminder_loop_end); + + postamble(); + + eltwise_injector_->prepare_table(); + + ker_ = (decltype(ker_))this->getCode(); + } + + ~jit_uni_kernel_fwd_f32() { delete eltwise_injector_; } + +private: + using Vmm = typename utils::conditional3::type; + + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int vlen = cpu_isa_traits::vlen; + + Reg64 reg_from = rax; + Reg64 reg_to = r8; + Reg64 reg_work_amount = rsi; + Reg64 imm_addr64 = rbx; + + Xmm xmm_src = Xmm(1); + Vmm vmm_src = Vmm(1); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; +}; + +} /* namespace */ + +template +status_t jit_uni_eltwise_fwd_t::pd_t::init() { + using namespace alg_kind; + + bool ok = true + && mayiuse(isa) + && is_fwd() + && utils::everyone_is(data_type::f32, desc()->data_desc.data_type) + && !has_zero_dim_memory() + && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh, + eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, + eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu, + eltwise_logistic) + && memory_desc_wrapper(src_md()).is_dense(true) + && IMPLICATION(!memory_desc_wrapper(src_md()).is_dense(false), + math::eltwise_fwd_preserves_zero(desc()->alg_kind, true)) + && attr()->has_default_values(); + + return ok ? status::success : status::unimplemented; +} + +template +jit_uni_eltwise_fwd_t::jit_uni_eltwise_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) { + const auto &desc = *pd()->desc(); + switch (desc.alg_kind) { + case alg_kind::eltwise_relu: + kernel_ = new jit_uni_relu_kernel_f32(desc); break; + default: + kernel_ = new jit_uni_kernel_fwd_f32(desc); + } +} + +template +jit_uni_eltwise_fwd_t::~jit_uni_eltwise_fwd_t() +{ delete kernel_; } + +template +void jit_uni_eltwise_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const size_t nelems = data_d.nelems(true); + + src += data_d.offset0(); + dst += data_d.offset0(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + + const int cache_line = 16; + + balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end); + start = nstl::min(nelems, start * cache_line); + end = nstl::min(nelems, end * cache_line); + + auto arg = jit_args(); + arg.from = &src[start]; + arg.for_comparison = &src[start]; + arg.to = &dst[start]; + arg.work_amount = end - start; + if (arg.work_amount) + (*kernel_)(&arg); + }); +} + +template +status_t jit_uni_eltwise_bwd_t::pd_t::init() { + bool ok = true + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu) + && src_md()->data_type == data_type::f32 + && !has_zero_dim_memory() + && mayiuse(isa) + && memory_desc_wrapper(src_md()).is_dense() + && memory_desc_wrapper(diff_dst_md()) == memory_desc_wrapper(src_md()) + && attr()->has_default_values(); + + return ok ? status::success : status::unimplemented; +} + +template +jit_uni_eltwise_bwd_t::jit_uni_eltwise_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) { + const auto &desc = *pd()->desc(); + switch (desc.alg_kind) { + case alg_kind::eltwise_relu: + kernel_ = new jit_uni_relu_kernel_f32(desc); break; + default: assert(!"unknown eltwise alg_kind"); + } +} + +template +jit_uni_eltwise_bwd_t::~jit_uni_eltwise_bwd_t() +{ delete kernel_; } + +template +void jit_uni_eltwise_bwd_t::execute_backward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + + const size_t nelems = data_d.nelems(); + + src += data_d.offset0(); + diff_dst += diff_data_d.offset0(); + diff_src += diff_data_d.offset0(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + + const int cache_line = 16; + + balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end); + start = nstl::min(nelems, start * cache_line); + end = nstl::min(nelems, end * cache_line); + + auto arg = jit_args(); + arg.from = &diff_dst[start]; + arg.to = &diff_src[start]; + arg.for_comparison = &src[start]; + arg.work_amount = end - start; + if (arg.work_amount) + (*kernel_)(&arg); + }); +} + +template struct jit_uni_eltwise_fwd_t; +template struct jit_uni_eltwise_bwd_t; +template struct jit_uni_eltwise_fwd_t; +template struct jit_uni_eltwise_bwd_t; +template struct jit_uni_eltwise_fwd_t; +template struct jit_uni_eltwise_bwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp new file mode 100644 index 0000000000..45436b9f46 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp @@ -0,0 +1,193 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_ELTWISE_HPP +#define CPU_JIT_UNI_ELTWISE_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_eltwise_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_eltwise_injector_f32 { + using Vmm = typename utils::conditional3::type; + + jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg, + float alpha, float beta, bool save_state = true, + Xbyak::Reg64 p_table = Xbyak::util::rax, + Xbyak::Opmask k_mask = Xbyak::Opmask(1)) + : alg_(alg), alpha_(alpha), beta_(beta), h(host) + , save_state_(save_state), p_table(p_table), k_mask(k_mask) + { + using namespace alg_kind; + assert(utils::one_of(isa, sse42, avx2, avx512_common)); + assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); + } + + // note that eltwise.scale is ignored + jit_uni_eltwise_injector_f32(jit_generator *host, + const post_ops_t::entry_t::eltwise_t &eltwise, + bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax, + Xbyak::Opmask k_mask = Xbyak::Opmask(1)) + : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha, + eltwise.beta, save_state, p_table, k_mask) {} + + void compute_vector_range(size_t start_idx, size_t end_idx); + void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); } + void prepare_table(bool gen_table=true); + void load_table_addr() { h->mov(p_table, l_table); } + + const alg_kind_t alg_; + const float alpha_; + const float beta_; + + jit_generator * const h; + + const bool save_state_; + const Xbyak::Reg64 p_table; + const Xbyak::Opmask k_mask; + Xbyak::Label l_table; + +private: + // if only the injector was inherited from jit_generator... + enum { + _cmp_le_os = jit_generator::_cmp_le_os, + _cmp_nle_us = jit_generator::_cmp_nle_us, + _op_floor = jit_generator::_op_floor, + }; + + size_t vlen = cpu_isa_traits::vlen; + + const static size_t preserved_vecs_max = 5; + + size_t vecs_to_preserve = 0; + size_t vecs_count = isa == avx512_common ? 32 : 16; + size_t preserved_vecs_count = 0; + size_t preserved_vec_idxs[preserved_vecs_max] = {0}; + size_t start_idx_tail = 0; + + Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4; + + Xbyak::Address table_val(int index) + { return h->ptr[p_table + index * vlen]; } + + int aux_vecs_count(alg_kind_t alg); + + void compute_body(size_t start_idx, size_t end_idx); + void injector_preamble(size_t start_idx, size_t end_idx); + void injector_preamble_tail(size_t start_idx); + void injector_postamble(); + void assign_regs(); + + void exp_compute_vector(const Vmm &vmm_src); + void relu_compute_vector(const Vmm &vmm_src); + void relu_zero_ns_compute_vector(const Vmm &vmm_src); + void elu_compute_vector(const Vmm &vmm_src); + void tanh_compute_vector(const Vmm &vmm_src); + void square_compute_vector(const Vmm &vmm_src); + void abs_compute_vector(const Vmm &vmm_src); + void sqrt_compute_vector(const Vmm &vmm_src); + void linear_compute_vector(const Vmm &vmm_src); + void bounded_relu_compute_vector(const Vmm &vmm_src); + void soft_relu_compute_vector(const Vmm &vmm_src); + void logistic_compute_vector(const Vmm &vmm_src); + + void relu_prepare_table(); + void elu_prepare_table(); + void soft_relu_prepare_table(); + void abs_prepare_table(); + void sqrt_prepare_table(); + void linear_prepare_table(); + void bounded_relu_prepare_table(); +}; + +struct jit_uni_eltwise_kernel_f32; + +template +struct jit_uni_eltwise_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_eltwise_fwd_pd_t { + using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_eltwise_fwd_t); + + status_t init(); + }; + + jit_uni_eltwise_fwd_t(const pd_t *apd); + ~jit_uni_eltwise_fwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_eltwise_kernel_f32 *kernel_; +}; + +template +struct jit_uni_eltwise_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_eltwise_bwd_pd_t { + using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_eltwise_bwd_t); + + status_t init(); + }; + + jit_uni_eltwise_bwd_t(const pd_t *apd); + ~jit_uni_eltwise_bwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_eltwise_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp new file mode 100644 index 0000000000..a3ca6273a0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp @@ -0,0 +1,949 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_uni_i8i8_pooling.hpp" + +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::types; +using namespace alg_kind; + +template +struct jit_uni_i8i8_pooling_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t) + + struct call_params_t { + const char *src_i8; + const char *dst_i8; + size_t kw_range; + size_t kh_range; + float idivider; + }; + + using Vmm = typename cpu_isa_traits::Vmm; + Xmm xreg(int idx) const { return Xmm(idx); } + Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); } + Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); } + + // In case of avx2 with data type i8 we need to use + // maskmovdqu instruction which has its destination hardcoded in rdi. + // Windows ABI: abi_param1 is rcx - nothing to do else + // Unix ABI: abi_param1 is rdi - copy it to rcx and use it as abi_param1 + Reg64 reg_param = rcx; // Our "unified abi_param1" + Reg64 reg_ptr_src_i8 = r8; + Reg64 reg_ptr_dst_i8 = r9; + Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi + + Reg64 ki = r10; + Reg64 kj = r11; + Reg64 reg_kw = r12; + Reg64 reg_kh = r13; + Reg64 c_iter = r14; + + Reg64 aux_reg_src_h = rax; + Reg64 aux_reg_src_w = rbx; + + Reg64 reg_tmp = rdx; + + Reg64 reg_mask = r15; + + Opmask k_cmp_mask = Opmask(7); + + Opmask mask(int idx) { + return Opmask(6 - idx); + } + + // ref to any of XYZ-regs via xreg/yreg/vreg functions + Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp + Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type + Vmm vreg_zeros = vreg(1); + + // only in case of == avx2 + Vmm vreg_mask = vreg(2); // full byte-mask + Xmm xreg_mask_lo = xreg(2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask) + Xmm xreg_mask_hi = xreg(3); // "max" - high 128-bits part of byte-mask (stored separately) + Xmm xreg_mask_q = xreg(3); // "avg" - 1/4 part of the mask for s8/u8 operations + Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails + + enum:int {vidx_base = isa == avx2 ? 4 : 2}; + Vmm base_vr(int idx) const { return vreg(vidx_base + idx); } + + size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); } + size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); } + + /* max pooling */ + Vmm vreg_src(int idx) const { return base_vr(idx); } // [0 .. ur_c-1] + Vmm vreg_dst(int idx) const { return base_vr(jpp.ur_c + idx); } // [ur_c .. 2*ur_c-1] + + /* avg pooling */ + // s32 used for processing of s8/u8 data + // thus we need to take into account ratio of sizes s32/i8 = 4 + static constexpr data_type_t avg_proc_dt = data_type::s32; + enum:int { + s32_to_i8_ratio = sizeof(typename prec_traits::type) + / sizeof(typename prec_traits::type), + max_num_ll = s32_to_i8_ratio + }; + Vmm vreg_src_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 0*max_num_ll); } // ll: 0..4 [0..3] + Vmm vreg_dst_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 1*max_num_ll); } // ll: 0..4 [4..7] + Vmm vreg_dst_f32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 2*max_num_ll); } // ll: 0..4 [8..11] + + void (*ker_)(const call_params_t *); + jit_pool_conf_t jpp; + + void init_tmp_reg(); + void init_mask(); + + void load_vreg_mask_q(int ll) {}; + + void load_src_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void load_src_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void load_src(int jj, int ll, int c_tail); + + void store_dst_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void store_dst_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void store_dst(int jj, int ll, int c_tail); + + void compute_avg_step(int ur_c, int c_tail); + void compute_max_op(const int jj); + void compute_max_step(int ur_c, int c_tail); + void compute_step(int ur_c, int c_tail); + + void compute_c_block(); + void generate(); + + static status_t init_conf(jit_pool_conf_t &jpp, const pooling_pd_t *ppd); + + jit_uni_i8i8_pooling_fwd_ker_t(const jit_pool_conf_t &jpp_) + : jpp(jpp_) { + generate(); + ker_ = reinterpret_cast(const_cast( + getCode())); + } +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_vreg_mask_q(int ll) { + + // extract ll-th part of mask (ll-th QWORD) + vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, 0x3 << ll); // 0x3 - mask for 2 x DWORD + + // Move mask from ll-th pos to 0-th pos + if (ll>0) + vpermq(vreg_mask_q, vreg_mask_q, ll); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + if (jpp.src_dt == s32) { + vpblendd(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], static_cast(msk)); + } else { + vpblendvb(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], vreg_mask); + } + } else + vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + if (jpp.src_dt == s32) + vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); + else + vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); + } else + vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + auto load_i8 = [&](bool is_signed, const Vmm& vr_src) { + + // Need to use mask of tail? + if (masked) { + + // load ll-th part of mask into vreg_mask_q + load_vreg_mask_q(ll); + + // Load by mask from mem into register vr_src + vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w + offset], vreg_mask_q); + + // Conversion s8/u8 -> s32 + if (is_signed) + vpmovsxbd(vr_src, vr_src); + else + vpmovzxbd(vr_src, vr_src); + } else { + + // Load from mem into vr_src with conversion + if (is_signed) + vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); + else + vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); + } + }; + + switch (jpp.src_dt) { + case s32: + if (masked) + vpblendd(vreg_src_s32(jj, ll), vreg_zeros, ptr[aux_reg_src_w + offset], + static_cast(msk)); + else + vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]); + break; + case s8: + load_i8(true, vreg_src_s32(jj, ll)); + break; + case u8: + load_i8(false, vreg_src_s32(jj, ll)); + break; + default: assert(!"unsupported src data type"); + } +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + const Vmm& vr_src = masked ? + vreg_src_s32(jj, ll) | mask(ll) : + vreg_src_s32(jj, ll); + + switch (jpp.src_dt) { + case s32: + vmovups(vr_src, ptr[aux_reg_src_w + offset]); + break; + case s8: + vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); + break; + case u8: + vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); + break; + default: assert(!"unsupported src data type"); + } +}; + +template +void jit_uni_i8i8_pooling_fwd_ker_t::load_src(int jj, int ll, int c_tail) { + using namespace data_type; + + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + + switch (jpp.alg) { + case pooling_max: { + auto offset = jj*c_block*sizeof_src_dt(); + bool masked = jj == ur_c - 1 && c_tail; + load_src_max_op(jj, ll, offset, masked, jpp.tail[0]); + break; + } + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_src_dt(); + bool masked = jj == ur_c - 1 && c_tail; + load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + default: assert(!"unsupported algorithm"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + int c_block = jpp.c_block; + + if (masked) { + switch (jpp.src_dt) { + case s32: + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj)); + break; + case s8: + case u8: { + // Store low half by mask (bytes 0...15) + lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); + maskmovdqu(vreg_dst(jj), xreg_mask_lo); + + // Do we need to store high half (bytes 16...31) ? + const uint64_t low_mask = (1ULL << (c_block/2))-1; + if (msk & ~low_mask) { + vextracti128(Xmm(vreg_dst(jj).getIdx()), vreg_dst(jj), 1); + add(reg_ptr_maskmovdqu_dst, c_block / 2); + maskmovdqu(vreg_dst(jj), xreg_mask_hi); + } + } break; + default: assert(!"unsupported src data type"); + } + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + switch (jpp.src_dt) { + case s32: + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); + break; + case s8: + case u8: + vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); + break; + default: assert(!"unsupported src data type"); + } + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk){ + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + auto s32_to_i8 = [&](bool is_signed, const Vmm& vr_dst) { + + // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16} + // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0} + if (is_signed) + vpackssdw(vr_dst, vr_dst, vreg_zeros); + else + vpackusdw(vr_dst, vr_dst, vreg_zeros); + + // Permute qwords to restore original order + // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0} + vpermq(vr_dst, vr_dst, 0x58); + + // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8} + // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx} + if (is_signed) + vpacksswb(vr_dst, vr_dst, vreg_zeros); + else + vpackuswb(vr_dst, vr_dst, vreg_zeros); + + }; + + auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm& vr_dst) { + + // Conversion s32 -> s8/u8 + s32_to_i8(is_signed, vr_dst); + + // Need to use mask of tail? + if (is_masked) { + // load ll-th part of mask into vreg_mask_q + load_vreg_mask_q(ll); + } + + // store 8 bytes + lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); + maskmovdqu(vr_dst, xreg_mask_q); + }; + + switch (jpp.dst_dt) { + case s32: + if (masked) { + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_s32(jj, ll)); + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll)); + break; + case s8: + store_i8(true, masked, vreg_dst_s32(jj, ll)); + break; + case u8: + store_i8(false, masked, vreg_dst_s32(jj, ll)); + break; + default: assert(!"unsuppotred dst data_type"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + const Vmm& vr_dst = masked ? + vreg_dst_s32(jj, ll) | mask(ll) : + vreg_dst_s32(jj, ll); + + switch (jpp.dst_dt) { + case s32: + vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + case s8: + vpmovdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + case u8: + vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + default: assert(!"unsupported dst data_type"); + } +} + + +template +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst(int jj, int ll, + int c_tail) { + using namespace data_type; + + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + + switch(jpp.alg) { + case pooling_max: { + auto offset = jj*c_block*sizeof_dst_dt(); + bool masked = jj == ur_c - 1 && c_tail; + store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_dst_dt(); + bool masked = jj == ur_c - 1 && c_tail; + store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + default: assert(!"unsupported pooling algorithm"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_op(const int jj) +{ + using namespace data_type; + switch (jpp.src_dt) { + case s32: + vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + case s8: + vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + case u8: + vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + default: assert(!"unsupported src data type"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_op(const int jj) +{ + using namespace data_type; + + // Compare + switch (jpp.src_dt) { + case s32: + vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + case s8: + vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + case u8: + vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + default: assert(!"unsupported src data type"); + } + + // move max values into vreg_dst + if (jpp.src_dt == s32) + vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); + else + vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); +} + + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_step(int ur_c, int c_tail) +{ + Label l_kw, l_kh; + + int iw = jpp.iw; + int c = jpp.c; + + for (int jj = 0; jj < ur_c; jj++) + vmovups(vreg_dst(jj), vreg_tmp); + + mov(aux_reg_src_h, reg_ptr_src_i8); + + xor_(kj, kj); + L(l_kh); + { + mov(aux_reg_src_w, aux_reg_src_h); + xor_(ki, ki); + L(l_kw); + { + for (int jj = 0; jj < ur_c; jj++) { + load_src(jj, 0, c_tail); + compute_max_op(jj); + } + add(aux_reg_src_w, c * sizeof_src_dt()); + inc(ki); + cmp(ki, reg_kw); + jl(l_kw, T_NEAR); + } + add(aux_reg_src_h, iw * c * sizeof_src_dt()); + inc(kj); + cmp(kj, reg_kh); + jl(l_kh, T_NEAR); + } + + for (int jj = 0; jj < ur_c; jj++) + store_dst(jj, 0, c_tail); +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step(int ur_c, int c_tail) +{ + using namespace data_type; + + Label l_kw, l_kh; + + int iw = jpp.iw; + int c = jpp.c; + + const int num_ll = data_type_size(avg_proc_dt)/data_type_size(jpp.src_dt); + + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + uni_vpxor(vreg_src_s32(jj, ll), vreg_src_s32(jj, ll), vreg_src_s32(jj, ll)); + uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll)); + } + } + } + + mov(aux_reg_src_h, reg_ptr_src_i8); + + xor_(kj, kj); + L(l_kh); + { + mov(aux_reg_src_w, aux_reg_src_h); + xor_(ki, ki); + L(l_kw); + { + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + load_src(jj, ll, c_tail); + vpaddd(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), + vreg_src_s32(jj, ll)); + } + } + } + add(aux_reg_src_w, c * sizeof_src_dt()); + inc(ki); + cmp(ki, reg_kw); + jl(l_kw, T_NEAR); + } + add(aux_reg_src_h, iw * c * sizeof_src_dt()); + inc(kj); + cmp(kj, reg_kh); + jl(l_kh, T_NEAR); + } + + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll)); + vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp); + vcvtps2dq(vreg_dst_s32(jj, ll), vreg_dst_f32(jj, ll)); + store_dst(jj, ll, c_tail); + } + } + } +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_step(int ur_c, int c_tail) { + switch (jpp.alg) { + case pooling_max: + compute_max_step(ur_c, c_tail); break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + compute_avg_step(ur_c, c_tail); break; + default: assert(!"unsupported pooling algorithm"); + } +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_c_block(){ + Label l_main_loop; + + int nb_c = jpp.nb_c; + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + int ur_c_tail = jpp.ur_c_tail; + int c_steps = nb_c / ur_c; + int c_tail = jpp.c_tail; + + xor_(c_iter, c_iter); + if (c_steps > 0) { + L(l_main_loop); { + compute_step(ur_c, 0); + add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt()); + add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt()); + inc(c_iter); + cmp(c_iter, c_steps); + jl(l_main_loop, T_NEAR); + } + } + + if (ur_c_tail != 0) { + compute_step(ur_c_tail, c_tail); + } +} + +template<> +void jit_uni_i8i8_pooling_fwd_ker_t::init_mask() { + using namespace data_type; + using cpu_isa = cpu_isa_traits; + + // AVX2 mask initialization: mask stored in Ymm-regs + auto init = [&](uint64_t bit_mask, bool init_mask_q) { + const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t); + + uint64_t vmask[QW_PER_VREG]; + for (size_t i = 0; i < QW_PER_VREG; i++){ + + uint64_t qw_vmask=0ULL; + const size_t DBITS = 8*sizeof_src_dt(); + const uint64_t VMSK = 1ULL << (DBITS-1); + const size_t D_PER_QW = (8*sizeof(qw_vmask))/DBITS; + for (size_t j = 0; j < D_PER_QW; j++) { + if (bit_mask & 1) + qw_vmask |= VMSK << DBITS * j; + bit_mask >>= 1; + } + vmask[i] = qw_vmask; + } + + // Put QWORDS with target mask into xmm regs + const int xdst_i[QW_PER_VREG] = { + xreg_mask_lo.getIdx(), + xreg_mask_lo.getIdx(), + xreg_mask_hi.getIdx(), + xreg_mask_hi.getIdx() + }; + const int xsrc_i[QW_PER_VREG] = { + vreg_zeros.getIdx(), // 0-th qword insert in zeros -> {qw0, 0} + xreg_mask_lo.getIdx(), // 1-st and 0-th merge -> {qw0,qw1} + vreg_zeros.getIdx(), + xreg_mask_hi.getIdx() + }; + const uint8 qw_dst_idx[QW_PER_VREG] = {0, 1, 0, 1}; // qword index in 128-bit xreg + + for (size_t i = 0; i < QW_PER_VREG; i++) { + mov(reg_mask, vmask[i]); + vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, qw_dst_idx[i]); + } + + // Merge Low (xreg_mask_lo alias for vreg_mask.xreg) + // and High (xreg_mask_hi) into full vreg_mask + // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg} + vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1); + + // Keep only low qword of mask in xreg_mask_q + if (init_mask_q) { + mov(reg_mask, vmask[0]); + vpinsrq(xreg_mask_q, Xmm(vreg_zeros.getIdx()), reg_mask, 0); + } + }; + + uint64_t tail_mask = (1ULL << jpp.c_tail) - 1; + switch (jpp.alg) { + case pooling_max: + // For "max" we need mask only in case of non-zero tail + if (tail_mask) + init(tail_mask, false); + break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + // For "avg" we need mask: + // - s32 - in case of the non-zero tail + // - s8/u8 - irrespective of the tail + switch (jpp.src_dt) { + case s32: + if (tail_mask) + init(tail_mask, false); + break; + case s8: + case u8: + init(tail_mask ? tail_mask : ~0ULL, tail_mask == 0); + break; + default: assert(!"unsupported src data type"); + } + break; + default: assert(!"unsupported pooling algorithm"); + } +} + +template<> +void jit_uni_i8i8_pooling_fwd_ker_t::init_mask() { + + for (int ll = 0; ll < max_num_ll; ll++) { + mov(reg_mask, jpp.tail[ll]); + kmovq(mask(ll), reg_mask); + } +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::init_tmp_reg() { + using namespace data_type; + + switch (jpp.alg) { + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]); + movq(xmm_tmp, reg_tmp); + vpbroadcastd(vreg_tmp, xmm_tmp); + break; + case pooling_max: + switch (jpp.src_dt) { + case s32: + mov(reg_tmp, nstl::numeric_limits::lowest()); + break; + case s8: + mov(reg_tmp, nstl::numeric_limits::lowest()); + break; + case u8: + mov(reg_tmp, nstl::numeric_limits::lowest()); + break; + default: assert(!"unsupported src data_type"); + } + + movq(xmm_tmp, reg_tmp); + if (jpp.src_dt == s32) + vpbroadcastd(vreg_tmp, xmm_tmp); + else + vpbroadcastb(vreg_tmp, xmm_tmp); + break; + default: assert(!"unsupported pooling algorithm"); + } + +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::generate() { + preamble(); + +#if !defined(_WIN32) + // Always use rcx as abi_param1 - + // see the note about maskmovdqu near reg_param. + mov(rcx, rdi); +#endif + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[reg_param + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src_i8, src_i8); + READ_PARAM(reg_ptr_dst_i8, dst_i8); + READ_PARAM(reg_kw, kw_range); + READ_PARAM(reg_kh, kh_range); + +# undef READ_PARAM + + uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros); + + init_mask(); + + init_tmp_reg(); + + compute_c_block(); + + postamble(); +} + +template +status_t jit_uni_i8i8_pooling_fwd_ker_t::init_conf(jit_pool_conf_t &jpp, + const pooling_pd_t *ppd) { + if (!mayiuse(isa)) + return status::unimplemented; + + const auto &pd = *ppd->desc(); + const memory_desc_wrapper src_d(ppd->src_md()); + const memory_desc_wrapper dst_d(ppd->dst_md()); + + jpp.mb = src_d.dims()[0]; + jpp.c = src_d.dims()[1]; + jpp.ih = src_d.dims()[2]; + jpp.iw = src_d.dims()[3]; + jpp.oh = dst_d.dims()[2]; + jpp.ow = dst_d.dims()[3]; + + jpp.stride_h = pd.strides[0]; + jpp.stride_w = pd.strides[1]; + jpp.kh = pd.kernel[0]; + jpp.kw = pd.kernel[1]; + + jpp.t_pad = pd.padding[0][0]; + jpp.l_pad = pd.padding[0][1]; + + jpp.alg = pd.alg_kind; + + jpp.src_dt = pd.src_desc.data_type; + jpp.dst_dt = pd.dst_desc.data_type; + + // data_type items per one vreg on the + // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32 + // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32 + int simd_w = cpu_isa_traits::vlen / data_type_size(jpp.src_dt); + + jpp.c_block = simd_w; + jpp.c_tail = jpp.c % jpp.c_block; + jpp.nb_c = jpp.c / jpp.c_block; + jpp.ur_c = 1; + jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c + + (jpp.c_tail != 0); + + size_t tail_mask = (1ULL << jpp.c_tail) - 1; + + switch (jpp.alg) { + case pooling_max: + jpp.tail[0] = tail_mask; + jpp.tail[1] = 0; + jpp.tail[2] = 0; + jpp.tail[3] = 0; + break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32) + // avx2 : 8, avx512 : 16 + const size_t msk_gran = cpu_isa_traits::vlen / data_type_size(avg_proc_dt); + const size_t msk_msk = (1ULL << msk_gran) - 1; + size_t m = tail_mask; + for (size_t ll = 0; ll < max_num_ll; ll++) { + jpp.tail[ll] = m & msk_msk; + m = m >> msk_gran; + } + break; + } + default: return status::unimplemented; + } + + return status::success; +} + +template +status_t jit_uni_i8i8_pooling_fwd_t::pd_t::jit_conf() { + return jit_uni_i8i8_pooling_fwd_ker_t::init_conf(jpp_, this); +} + +template +jit_uni_i8i8_pooling_fwd_t:: +jit_uni_i8i8_pooling_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), ker_(nullptr) +{ ker_ = new jit_uni_i8i8_pooling_fwd_ker_t(pd()->jpp_); } + +template +jit_uni_i8i8_pooling_fwd_t:: +~jit_uni_i8i8_pooling_fwd_t() { delete ker_; } + +template +void jit_uni_i8i8_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src_i8 = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC); + auto dst_i8 = CTX_OUT_MEM(char *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const auto &jpp = pd()->jpp_; + + parallel_nd(jpp.mb, jpp.oh, jpp.ow, + [&](int n, int oh, int ow) { + const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0); + const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0); + + const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h); + const int kh_end = nstl::min(jpp.kh, + jpp.ih + jpp.t_pad - oh * jpp.stride_h); + const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w); + const int kw_end = nstl::min(jpp.kw, + jpp.iw + jpp.l_pad - ow * jpp.stride_w); + + auto p = typename jit_uni_i8i8_pooling_fwd_ker_t::call_params_t(); + p.src_i8 = &src_i8[ + src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()]; + p.dst_i8 = &dst_i8[ + dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()]; + p.kw_range = (size_t)(kw_end - kw_start); + p.kh_range = (size_t)(kh_end - kh_start); + p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ? + p.kh_range*p.kw_range : jpp.kw*jpp.kh); + + ker_->ker_(&p); + }); +} + +// Explicit instantiation only for supported values. +// +template struct jit_uni_i8i8_pooling_fwd_ker_t; +template struct jit_uni_i8i8_pooling_fwd_t; + +template struct jit_uni_i8i8_pooling_fwd_ker_t; +template struct jit_uni_i8i8_pooling_fwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp new file mode 100644 index 0000000000..d757679df5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_I8I8_POOLING_HPP +#define CPU_JIT_UNI_I8I8_POOLING_HPP + +#include "c_types_map.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +#include "cpu_isa_traits.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_i8i8_pooling_fwd_ker_t; + +template +struct jit_uni_i8i8_pooling_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_i8i8_pooling_fwd_t); + + status_t init() { + bool ok = true + && mayiuse(isa) + && ndims() == 4 + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::forward_inference + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::one_of(src_md()->data_type, data_type::s32, + data_type::s8, data_type::u8) + && src_md()->data_type == dst_md()->data_type + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && memory_desc_matches_tag(*dst_md(), format_tag::nhwc); + if (!ok) return status::unimplemented; + + return jit_conf(); + } + + jit_pool_conf_t jpp_; + + protected: + status_t jit_conf(); + }; + + jit_uni_i8i8_pooling_fwd_t(const pd_t *apd); + ~jit_uni_i8i8_pooling_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_i8i8_pooling_fwd_ker_t *ker_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp new file mode 100644 index 0000000000..2c5a8e8973 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp @@ -0,0 +1,305 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_uni_lrn_kernel_f32.hpp" +#include "jit_uni_lrn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +template +jit_uni_lrn_fwd_t::jit_uni_lrn_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), ker_(nullptr) + , ker_first_(nullptr), ker_last_(nullptr) +{ + using namespace alg_kind; + + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + float A = pd()->desc()->lrn_alpha / ls; + float K = pd()->desc()->lrn_k; + + auto pk = pd()->desc()->prop_kind; + auto ak = pd()->desc()->alg_kind; + auto dat_tag = pd()->dat_tag_; + + if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { + ker_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_across(H, W, 0), A, K, pk); + ker_first_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_across(H, W, -1), A, K, pk); + ker_last_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_across(H, W, +1), A, K, pk); + } else if (dat_tag == nChw8c && ak == lrn_within_channel) { + /* within channel, local_size (x) local_size */ + A /= ls; /* XXX: why? */ + ker_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_within(H, W, ls), A, K, pk); + } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { + ker_ = new jit_uni_lrn_fwd_kernel_f32( + nchw_across(C, H*W, 0), A, K, pk); + int remind = (H*W) % VECTOR_LENGTH; + if (remind != 0) { + ker_last_ = new jit_uni_lrn_fwd_kernel_f32( + nchw_across(C, H*W, remind), A, K, pk); + } + } else if (true /* XXX: why */) { + ker_ = new jit_uni_lrn_fwd_kernel_f32(nhwc_across(C), A, K, pk); + } +} + +template +jit_uni_lrn_fwd_t::~jit_uni_lrn_fwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +template +void jit_uni_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int HW = pd()->H() * pd()->W(); + const int ls = pd()->desc()->local_size; + + auto ak = pd()->desc()->alg_kind; + auto dat_tag = pd()->dat_tag_; + + if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH]; + if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else if (dat_tag == nChw8c && ak == lrn_within_channel) { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH]; + (*ker_)(&args); + }); + } + else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { + parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH, + [&](int n, int hw8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + hw8 * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + hw8 * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + hw8 * VECTOR_LENGTH]; + if ((hw8 + 1)*VECTOR_LENGTH > HW) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else { // nhwc + parallel_nd(N, HW, [&](int n, int hw) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + hw * C]; + args.dst = &dst[n*HW*C + hw * C]; + args.scratch = &ws[n*HW*C + hw * C]; + (*ker_)(&args); + }); + } +} + +template +status_t jit_uni_lrn_fwd_t::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(isa) + && is_fwd() + && everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % VECTOR_LENGTH == 0 + && data_d.dims()[1] >= 2 * VECTOR_LENGTH + && desc()->lrn_beta == 0.75 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + if (desc_.prop_kind == forward_training) ws_md_ = *src_md(); + + dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c, nchw, nhwc); + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && one_of(dat_tag_, nChw8c, nchw, nhwc); + + const int jit_max_local_size = 5; // bigger size triggers too big code size + bool args_ok_within = true + && desc()->alg_kind == lrn_within_channel + && desc()->local_size <= ( jit_max_local_size <= MAX_LOCAL_SIZE + ? jit_max_local_size : MAX_LOCAL_SIZE) + && data_d.dims()[2] >= desc()->local_size + && data_d.dims()[3] >= desc()->local_size + && one_of(dat_tag_, nChw8c); + + return args_ok_across || args_ok_within ? success : unimplemented; +} + +template +jit_uni_lrn_bwd_t::jit_uni_lrn_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr) +{ + using namespace alg_kind; + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + float A = pd()->desc()->lrn_alpha / ls; + float B = pd()->desc()->lrn_beta; + + int use_h_parallelizm = 0;// XXX + if (C / VECTOR_LENGTH == 1) { + ker_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, 3), A, B, use_h_parallelizm); + } + else { + ker_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, 0), A, B, use_h_parallelizm); + ker_first_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, -1), A, B, use_h_parallelizm); + ker_last_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, +1), A, B, use_h_parallelizm); + } +} + +template +jit_uni_lrn_bwd_t::~jit_uni_lrn_bwd_t() +{ + delete ker_; delete ker_first_; delete ker_last_; +} + +template +void jit_uni_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + int use_h_parallelizm = 0; // XXX + if (use_h_parallelizm) { + parallel_nd(N, C / VECTOR_LENGTH, H, [&](int n, int c8, int h) { + auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH + + h*W*VECTOR_LENGTH; + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.scratch = &ws[offset]; + args.diff_src = &diff_src[offset]; + if (C / VECTOR_LENGTH == 1) + (*ker_)(&args); + else if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH; + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.scratch = &ws[offset]; + args.diff_src = &diff_src[offset]; + if (C / VECTOR_LENGTH == 1) + (*ker_)(&args); + else if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } +} + +template +status_t jit_uni_lrn_bwd_t::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(isa) + && !is_fwd() + && utils::everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % VECTOR_LENGTH == 0 + && desc()->lrn_beta == 0.75 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + ws_md_ = *src_md(); + if (!compare_ws(hint_fwd_pd_)) return unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c); + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && utils::one_of(dat_tag_, nChw8c); + + return args_ok_across ? success : unimplemented; +} + +template struct jit_uni_lrn_fwd_t; +template struct jit_uni_lrn_fwd_t; +template struct jit_uni_lrn_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp new file mode 100644 index 0000000000..333cd3396d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_LRN_HPP +#define CPU_JIT_UNI_LRN_HPP + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_isa_traits.hpp" +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template struct jit_uni_lrn_fwd_kernel_f32; +template struct jit_uni_lrn_bwd_kernel_f32; + +template +struct jit_uni_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_lrn_fwd_t); + + status_t init(); + + format_tag_t dat_tag_; + }; + + jit_uni_lrn_fwd_t(const pd_t *apd); + ~jit_uni_lrn_fwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_lrn_fwd_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +template +struct jit_uni_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_lrn_bwd_t); + + status_t init(); + + format_tag_t dat_tag_; + }; + + jit_uni_lrn_bwd_t(const pd_t *apd); + ~jit_uni_lrn_bwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_lrn_bwd_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp new file mode 100644 index 0000000000..89af47272c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp @@ -0,0 +1,1487 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_uni_lrn_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +////////////////////////////////////////////////////////////////////////////// +// forward kernel +template +void jit_uni_lrn_fwd_kernel_f32::within_body( + int hoff, int Hoff, int woff, int Woff, int stride, + Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2, + prop_kind_t pk) +{ + vxorps(ysum, ysum, ysum); + for (int i = hoff; i <= Hoff; ++i) + { + for (int j = woff; j <= Woff; ++j) + { + if (i == 0 && j == 0) + { + vmovups(ydst, ptr[src]); + vfmadd231ps(ysum, ydst, ydst); + } + else + { + vmovups(ytmp, ptr[src + (i*stride + j)*VECTOR_LENGTH*4]); + vfmadd231ps(ysum, ytmp, ytmp); + } + } + } + vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk + vmovaps(ytmp, ysum); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ytmp); + vmulps(ysum2, ysum, ysum); + vmulps(ysum, ysum, ysum2); // ysum = (ysum*yalpha+yk)^3; + vsqrtps(ysum, ysum); + vsqrtps(ysum, ysum); // ysum = (ysum*yalpha+yk)^0.75 + vdivps(ydst, ydst, ysum); // ydst <- ydst / ysum + vmovups(ptr[dst], ydst); + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); +} + +template +void jit_uni_lrn_fwd_kernel_f32::within_body_sse42( + int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk) +{ + Xbyak::Xmm xtmp_lo = xmm12; + Xbyak::Xmm xtmp_hi = xmm13; + Xbyak::Xmm xsum_lo = xmm8; + Xbyak::Xmm xsum_hi = xmm9; + Xbyak::Xmm xdst_lo = xmm10; + Xbyak::Xmm xdst_hi = xmm11; + Xbyak::Xmm xsum2_lo = xmm14; + Xbyak::Xmm xsum2_hi = xmm15; + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + for (int i = hoff; i <= Hoff; ++i) + { + for (int j = woff; j <= Woff; ++j) + { + if (i == 0 && j == 0) + { + movups(xdst_lo, ptr[src]); + movups(xdst_hi, ptr[src + 4 * sizeof(float)]); + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + addps(xsum_lo, xdst_lo); + addps(xsum_hi, xdst_hi); + } + else + { + movups(xtmp_lo, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4]); + movups(xtmp_hi, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4 + 4 * sizeof(float)]); + mulps(xtmp_lo, xtmp_lo); + mulps(xtmp_hi, xtmp_hi); + addps(xsum_lo, xtmp_lo); + addps(xsum_hi, xtmp_hi); + } + } + } + mulps(xsum_lo, xalpha); + mulps(xsum_hi, xalpha); + addps(xsum_lo, xk); + addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk + movaps(xtmp_lo, xsum_lo); + movaps(xtmp_hi, xsum_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xtmp_lo); + movups(ptr[scratch + 4 * sizeof(float)], xtmp_hi); + } + movaps(xsum2_lo, xsum_lo); + movaps(xsum2_hi, xsum_hi); + mulps(xsum2_lo, xsum_lo); + mulps(xsum2_hi, xsum_hi); + mulps(xsum_lo, xsum2_lo); + mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha+xk)^3; + + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); // xsum = (xsum*xalpha+xk)^0.75 + + movups(xdst_lo, ptr[src]); + movups(xdst_hi, ptr[src + 4 * sizeof(float)]); + divps(xdst_lo, xsum_lo); + divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum + + movups(ptr[dst], xdst_lo); + movups(ptr[dst + 4 * sizeof(float)], xdst_hi); + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); +} + +template +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_within &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 h = r9; + Xbyak::Reg64 w = r10; + Vmm ysum = Vmm(isa == avx2 ? 9 : 9); + Vmm ysum2 = Vmm(isa == avx2 ? 10 : 10); + Vmm ydst = Vmm(isa == avx2 ? 11 : 11); + Vmm ytmp = Vmm(isa == avx2 ? 12 : 12); + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + if (isa == avx2) { + vbroadcastss(yalpha, xalpha); + } else { + shufps(xalpha, xalpha, 0); + } + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + if (isa == avx2) { + vbroadcastss(yk, xk); + } else { + shufps(xk, xk, 0); + } + + int s2 = (J.size - 1) / 2, S2 = J.size - s2 - 1; + + for (int i = 0; i < s2; ++i) + { + Label label_t; + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-i, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } + else { + within_body_sse42(-i, S2, -j, S2, J.W, pk); + } + } + mov(w, J.W - J.size + 1); + L(label_t); + if (isa == avx2) { + within_body(-i, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-i, S2, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(label_t, T_NEAR); + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-i, S2, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-i, S2, -s2, J.W - 1 - j, J.W, pk); + } + } + } + + mov(h, J.H - J.size + 1); + Label lrn_loop_h; + L(lrn_loop_h); + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-s2, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -j, S2, J.W, pk); + } + } + mov(w, J.W - J.size + 1); + Label lrn_loop_w; + L(lrn_loop_w); + if (isa == avx2) { + within_body(-s2, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(lrn_loop_w, T_NEAR); + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-s2, S2, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -s2, J.W - 1 - j, J.W, pk); + } + } + dec(h); + cmp(h, 0); + jne(lrn_loop_h, T_NEAR); + + for (int i = J.H - S2; i < J.H; ++i) + { + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -j, S2, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, J.H - 1 - i, -j, S2, J.W, pk); + } + } + + mov(w, J.W - J.size + 1); + Label label_b; + L(label_b); + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -s2, S2, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, J.H - 1 - i, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(label_b, T_NEAR); + + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, pk); + } + } + } + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r9; + Xbyak::Xmm xsrc_prev = xmm2; + Xbyak::Ymm ysrc = ymm3; + Xbyak::Ymm yc = ymm3; + Xbyak::Xmm xsrc_next = xmm4; + Xbyak::Ymm ya = ymm5; + Xbyak::Ymm yb = ymm6; + Xbyak::Ymm yd = ymm7; + Xbyak::Ymm ye = ymm8; + Xbyak::Ymm ysum = ymm9; + Xbyak::Ymm ysum2 = ymm10; + Xbyak::Ymm ydst = ymm11; + Xbyak::Ymm ybase = ymm12; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + sub(t, 64); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + if (J.version == -1) + { + vxorps(xsrc_prev, xsrc_prev, xsrc_prev); + vmovups(ptr[t + 0], xsrc_prev); + } + if (J.version == +1) + { + vxorps(xsrc_next, xsrc_next, xsrc_next); + vmovups(ptr[t + 48], xsrc_next); + } + + mov(hw, J.H*J.W); + + Label lrn_loop; + L(lrn_loop); + + if (J.version != -1) vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + vmovups(ysrc, ptr[src]); + if (J.version != +1) vmovups(xsrc_next, ptr[src + J.H*J.W * 32]); + + if (J.version != -1) vmovups(ptr[t + 0], xsrc_prev); + vmovups(ptr[t + 16], ysrc); + if (J.version != +1) vmovups(ptr[t + 48], xsrc_next); + + vmovups(ya, ptr[t + 16 - 8]); + vmovups(yb, ptr[t + 16 - 4]); + vmovups(yd, ptr[t + 16 + 4]); + vmovups(ye, ptr[t + 16 + 8]); + vmulps(ysum, yc, yc); + vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya + vfmadd231ps(ysum, yb, yb); + vfmadd231ps(ysum, yd, yd); + vfmadd231ps(ysum, ye, ye); + vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk + + vmovaps(ybase, ysum); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ysum2, ysum, ysum); + vmulps(ysum, ysum, ysum2); // ysum = ybase^3; + vsqrtps(ysum, ysum); + vsqrtps(ysum, ysum); // ysum = ybase^0.75 + vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum + vmovups(ptr[dst], ydst); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r9; + + Xbyak::Xmm xsrc_lo = xmm2; + Xbyak::Xmm xsrc_hi = xmm3; + Xbyak::Xmm xc_lo = xmm4; + Xbyak::Xmm xc_hi = xmm5; + Xbyak::Xmm xsum_lo = xc_lo; + Xbyak::Xmm xsum_hi = xc_hi; + Xbyak::Xmm xsrc_prev = xmm6; + Xbyak::Xmm xsrc_next = xmm7; + Xbyak::Xmm xa_lo = xmm8; + Xbyak::Xmm xa_hi = xmm9; + Xbyak::Xmm xb_lo = xmm10; + Xbyak::Xmm xb_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + Xbyak::Xmm xe_lo = xmm14; + Xbyak::Xmm xe_hi = xmm15; + Xbyak::Xmm xbase_lo = xmm14; + Xbyak::Xmm xbase_hi = xmm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + sub(t, 64); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + if (J.version == -1) + { + xorps(xsrc_prev, xsrc_prev); + movups(ptr[t + 0], xsrc_prev); + } + if (J.version == +1) + { + xorps(xsrc_next, xsrc_next); + movups(ptr[t + 48], xsrc_next); + } + + mov(hw, J.H*J.W); + Label lrn_loop; + L(lrn_loop); + + if (J.version != -1) movups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + movups(xsrc_lo, ptr[src]); + movups(xsrc_hi, ptr[src + 4 * sizeof(float)]); + if (J.version != +1) movups(xsrc_next, ptr[src + J.H*J.W * 32]); + + if (J.version != -1) movups(ptr[t + 0], xsrc_prev); + movups(ptr[t + 16], xsrc_lo); + movups(ptr[t + 16 + 4 * sizeof(float)], xsrc_hi); + if (J.version != +1) movups(ptr[t + 48], xsrc_next); + + movups(xa_lo, ptr[t + 16 - 8]); + movups(xa_hi, ptr[t + 16 - 8 + 4 * sizeof(float)]); + movups(xb_lo, ptr[t + 16 - 4]); + movups(xb_hi, ptr[t + 16 - 4 + 4 * sizeof(float)]); + movups(xd_lo, ptr[t + 16 + 4]); + movups(xd_hi, ptr[t + 16 + 4 + 4 * sizeof(float)]); + movups(xe_lo, ptr[t + 16 + 8]); + movups(xe_hi, ptr[t + 16 + 8 + 4 * sizeof(float)]); + movaps(xc_lo, xsrc_lo); + movaps(xc_hi, xsrc_hi); + mulps(xsum_lo, xc_lo); + mulps(xsum_hi, xc_hi); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + mulps(xsum_lo, xalpha); + mulps(xsum_hi, xalpha); + addps(xsum_lo, xk); + addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk + + movaps(xbase_lo, xsum_lo); + movaps(xbase_hi, xsum_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xsum_lo, xsum_lo); + mulps(xsum_hi, xsum_hi); + mulps(xsum_lo, xbase_lo); + mulps(xsum_hi, xbase_hi); // xsum = xbase^3; + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75 + divps(xsrc_lo, xsum_lo); + divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum + movups(ptr[dst], xsrc_lo); + movups(ptr[dst + 4 * sizeof(float)], xsrc_hi); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0, 0, 0x80000000, 0x80000000, 0x80000000, 0x80000000, + 0x80000000, 0x80000000, 0x80000000, 0, 0 + }; + + Xbyak::Reg64 c = r9; + Xbyak::Ymm ya = ymm2; + Xbyak::Ymm yb = ymm3; + Xbyak::Ymm yc = ymm4; + Xbyak::Ymm yd = ymm5; + Xbyak::Ymm ye = ymm6; + Xbyak::Ymm ysum = ymm7; + Xbyak::Ymm ydst = ymm8; + Xbyak::Ymm ybase = ymm9; + Xbyak::Ymm ymask = ymm10; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + vxorps(ysum, ysum, ysum); + + mov(imm_addr64, reinterpret_cast(&mask[0])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(ya, ymask, ptr[src - 8]); + vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + + mov(imm_addr64, reinterpret_cast(&mask[1])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(yb, ymask, ptr[src - 4]); + vfmadd231ps(ysum, yb, yb); + + mov(c, J.C / 8 - 1); + Label lrn_loop; + L(lrn_loop); + + vmovups(yc, ptr[src]); + vmovups(yd, ptr[src + 4]); + vmovups(ye, ptr[src + 8]); + vfmadd231ps(ysum, yc, yc); + vfmadd231ps(ysum, yd, yd); + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + vmovups(ptr[dst], ydst); + + vxorps(ysum, ysum, ysum); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + + vmovups(ya, ptr[src - 8]); + vfmadd231ps(ysum, ya, ya); + vmovups(yb, ptr[src - 4]); + vfmadd231ps(ysum, yb, yb); + + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + vmovups(yc, ptr[src]); + vfmadd231ps(ysum, yc, yc); + + mov(imm_addr64, reinterpret_cast(&mask[2])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(yd, ymask, ptr[src + 4]); + vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + + mov(imm_addr64, reinterpret_cast(&mask[3])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(ye, ymask, ptr[src + 8]); + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + + vmovups(ptr[dst], ydst); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0xffffffff, 0, 0 + }; + + static uint32_t store[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }; + Xbyak::Reg64 c = r9; + + Xbyak::Xmm xdst_lo = xmm0; + Xbyak::Xmm xdst_hi = xmm1; + Xbyak::Xmm xa_lo = xmm2; + Xbyak::Xmm xa_hi = xmm3; + Xbyak::Xmm xb_lo = xmm2; + Xbyak::Xmm xb_hi = xmm3; + Xbyak::Xmm xc_lo = xmm4; + Xbyak::Xmm xc_hi = xmm5; + Xbyak::Xmm xd_lo = xmm6; + Xbyak::Xmm xd_hi = xmm7; + Xbyak::Xmm xe_lo = xmm8; + Xbyak::Xmm xe_hi = xmm9; + Xbyak::Xmm xsum_lo = xmm10; + Xbyak::Xmm xsum_hi = xmm11; + Xbyak::Xmm xmask_lo = xmm12; + Xbyak::Xmm xmask_hi = xmm13; + Xbyak::Xmm xbase_lo = xmm14; + Xbyak::Xmm xbase_hi = xmm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + mov(store_addr, reinterpret_cast(&store[0])); + and_(store_addr, -15); + movups(ptr[store_addr], xalpha); + movups(ptr[store_addr + 4 * sizeof(float)], xk); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + + mov(imm_addr64, reinterpret_cast(&mask[0])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xa_lo, ptr[src - 8]); + movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]); + andps(xa_lo, xmask_lo); + andps(xa_hi, xmask_hi); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(imm_addr64, reinterpret_cast(&mask[1])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xb_lo, ptr[src - 4]); + movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]); + andps(xb_lo, xmask_lo); + andps(xb_hi, xmask_hi); + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + + mov(c, J.C / 8 - 1); + Label lrn_loop; + L(lrn_loop); + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + movups(xd_lo, ptr[src + 4]); + movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]); + movups(xe_lo, ptr[src + 8]); + movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]); + mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + movaps(xdst_lo, xsum_lo); + movaps(xdst_hi, xsum_hi); + // xdst <- xsum*xalpha+xk + mulps(xdst_lo, ptr[store_addr]); + mulps(xdst_hi, ptr[store_addr]); + addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + divps(xc_lo, xdst_lo); + divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + movups(ptr[dst], xc_lo); + movups(ptr[dst + 4 * sizeof(float)], xc_hi); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + + movups(xa_lo, ptr[src - 8]); + movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); + movups(xb_lo, ptr[src - 4]); + movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]); + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + + mov(imm_addr64, reinterpret_cast(&mask[2])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xd_lo, ptr[src + 4]); + movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]); + andps(xd_lo, xmask_lo); + andps(xd_hi, xmask_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(imm_addr64, reinterpret_cast(&mask[3])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xe_lo, ptr[src + 8]); + movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]); + andps(xe_lo, xmask_lo); + andps(xe_hi, xmask_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + movups(xdst_lo, xsum_lo); + movups(xdst_hi, xsum_hi); + // xdst <- xsum*xalpha+xk + mulps(xdst_lo, ptr[store_addr]); + mulps(xdst_hi, ptr[store_addr]); + addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + divps(xc_lo, xdst_lo); + divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + + movups(ptr[dst], xc_lo); + movups(ptr[dst + 4 * sizeof(float)], xc_hi); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body( + int tail, int HW, prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum) {} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body( + int tail, int HW, prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum) +{ + Xbyak::Ymm ydst = ymm14; + Xbyak::Ymm ybase = ymm15; + + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + { + if (tail != 0) + vmaskmovps(ptr[scratch], ymask, ybase); + else + vmovups(ptr[scratch], ybase); + } + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + + if (tail != 0) + vmaskmovps(ptr[dst], ymask, ydst); + else + vmovups(ptr[dst], ydst); + + + vfnmadd231ps(ysum, ya, ya); + vmovups(ya, yb); + vmovups(yb, yc); + vmovups(yc, yd); + vmovups(yd, ye); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_tail_sse42( + int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) +{} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_tail_sse42( + int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) +{ + Xbyak::Xmm xmm_tmp = xmm10; + movaps(xmm_tmp, xtail_lo); + size_t offset = 0; + + if (tail > 4) { + movups(ptr[reg_dst], xtail_lo); + movaps(xmm_tmp, xtail_hi); + offset += 4 * sizeof(float); + tail -= 4; + } + movss(ptr[reg_dst + offset], xmm_tmp); + for (int i = 1; i < tail; i++) + { + psrldq(xmm_tmp, 4); + movss(ptr[reg_dst + offset + i * sizeof(float)], xmm_tmp); + } +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body_sse42( + int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) +{ + Xbyak::Xmm xdst_lo = xmm0; + Xbyak::Xmm xdst_hi = xmm1; + Xbyak::Xmm xbase_lo = xmm6; + Xbyak::Xmm xbase_hi = xmm7; + Xbyak::Xmm xtmp_lo = xmm8; + Xbyak::Xmm xtmp_hi = xmm9; + Xbyak::Xmm xa_lo = xmm6; + Xbyak::Xmm xa_hi = xmm7; + Xbyak::Xmm xb_lo = xmm8; + Xbyak::Xmm xb_hi = xmm9; + Xbyak::Xmm xc_lo = xmm10; + Xbyak::Xmm xc_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + + // store xe + movaps(ptr[store_addr + 10 * 4 * sizeof(float)], xe_lo); + movaps(ptr[store_addr + 11 * 4 * sizeof(float)], xe_hi); + + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + // xdst <- xsum*xalpha+xk + movaps(xdst_lo, xsum_lo); + movaps(xdst_hi, xsum_hi); + mulps(xdst_lo, ptr[store_addr + 0 * 4 * sizeof(float)]); + mulps(xdst_hi, ptr[store_addr + 0 * 4 * sizeof(float)]); + addps(xdst_lo, ptr[store_addr + 1 * 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 1 * 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) + { + if (tail != 0) { + nchw_tail_sse42(tail, scratch, xbase_lo, xbase_hi); + } + else { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + movaps(xtmp_lo, ptr[store_addr + 6 * 4 * sizeof(float)]); + movaps(xtmp_hi, ptr[store_addr + 7 * 4 * sizeof(float)]); + divps(xtmp_lo, xdst_lo); + divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + movaps(xdst_lo, xtmp_lo); + movaps(xdst_hi, xtmp_hi); + + if (tail != 0) { + nchw_tail_sse42(tail, dst, xdst_lo, xdst_hi); + } + else { + movups(ptr[dst], xdst_lo); + movups(ptr[dst + 4 * sizeof(float)], xdst_hi); + } + + movaps(xa_lo, ptr[store_addr + 2 * 4 * sizeof(float)]); + movaps(xa_hi, ptr[store_addr + 3 * 4 * sizeof(float)]); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + subps(xsum_lo, xa_lo); + subps(xsum_hi, xa_hi); + + // xa <- xb + movaps(xb_lo, ptr[store_addr + 4 * 4 * sizeof(float)]); + movaps(xb_hi, ptr[store_addr + 5 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xb_lo); + movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xb_hi); + + // xb <- xc + movaps(xc_lo, ptr[store_addr + 6 * 4 * sizeof(float)]); + movaps(xc_hi, ptr[store_addr + 7 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xc_lo); + movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xc_hi); + + // xc <- xd + movaps(xd_lo, ptr[store_addr + 8 * 4 * sizeof(float)]); + movaps(xd_hi, ptr[store_addr + 9 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xd_lo); + movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xd_hi); + + // xd <- xe + movaps(xe_lo, ptr[store_addr + 10 * 4 * sizeof(float)]); + movaps(xe_hi, ptr[store_addr + 11 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xe_lo); + movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xe_hi); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body_sse42( + int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) {} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, + 0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0 + }; + Xbyak::Reg64 c = r10; + Xbyak::Ymm ymask = ymm2; + Xbyak::Ymm ye = ymm3; + Xbyak::Ymm ya = ymm4; + Xbyak::Ymm yb = ymm5; + Xbyak::Ymm yc = ymm6; + Xbyak::Ymm yd = ymm7; + Xbyak::Ymm ysum = ymm8; + + this->preamble(); + + if (J.tail != 0) + { + mov(imm_addr64, reinterpret_cast(&mask[7 - J.tail])); + vmovups(ymask, ptr[imm_addr64]); + } + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + vxorps(ya, ya, ya); + vxorps(yb, yb, yb); + if (J.tail != 0) + vmaskmovps(yc, ymask, ptr[src + J.HW * 0]); + else + vmovups(yc, ptr[src + J.HW * 0]); + if (J.tail != 0) + vmaskmovps(yd, ymask, ptr[src + J.HW * 4]); + else + vmovups(yd, ptr[src + J.HW * 4]); + + vxorps(ysum, ysum, ysum); + vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + vfmadd231ps(ysum, yd, yd); + + mov(c, J.C - 2); + Label lrn_loop; + L(lrn_loop); + + if (J.tail != 0) + vmaskmovps(ye, ymask, ptr[src + J.HW * 8]); + else + vmovups(ye, ptr[src + J.HW * 8]); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + vxorps(ye, ye, ye); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0 + }; + + Xbyak::Reg64 c = r10; + + Xbyak::Xmm xmask_lo = xmm2; + Xbyak::Xmm xmask_hi = xmm3; + Xbyak::Xmm xsum_lo = xmm4; + Xbyak::Xmm xsum_hi = xmm5; + Xbyak::Xmm xa_lo = xmm6; + Xbyak::Xmm xa_hi = xmm7; + Xbyak::Xmm xb_lo = xmm8; + Xbyak::Xmm xb_hi = xmm9; + Xbyak::Xmm xc_lo = xmm10; + Xbyak::Xmm xc_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + Xbyak::Xmm xe_lo = xmm14; + Xbyak::Xmm xe_hi = xmm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + sub(rsp, stack_space_needed); + mov(store_addr, rsp); + and_(store_addr, -15); + + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + // put alpha and k into store (free up regs) + movaps(ptr[store_addr + 0 * 4 * sizeof(float)], xalpha); + movaps(ptr[store_addr + 1 * 4 * sizeof(float)], xk); + + if (J.tail != 0) + { + mov(imm_addr64, reinterpret_cast(&mask[7 - J.tail])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + } + // init xa, xb + xorps(xa_lo, xa_lo); + xorps(xa_hi, xa_hi); + xorps(xb_lo, xb_lo); + xorps(xb_hi, xb_hi); + + // read xc, xd + if (J.tail != 0) { + movups(xc_lo, ptr[src + J.HW * 0]); + movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]); + andps(xc_lo, xmask_lo); + andps(xc_hi, xmask_hi); + } + else { + movups(xc_lo, ptr[src + J.HW * 0]); + movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]); + } + if (J.tail != 0) { + movups(xd_lo, ptr[src + J.HW * 4]); + movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]); + andps(xd_lo, xmask_lo); + andps(xd_hi, xmask_hi); + } + else { + movups(xd_lo, ptr[src + J.HW * 4]); + movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]); + } + + // put xa, xb, xc, xd into store to free-up regs + movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xa_lo); + movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xa_hi); + movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xb_lo); + movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xb_hi); + movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xc_lo); + movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xc_hi); + movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xd_lo); + movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xd_hi); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(c, J.C - 2); + Label lrn_loop; + L(lrn_loop); + + if (J.tail != 0) { + movups(xe_lo, ptr[src + J.HW * 8]); + movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]); + andps(xe_lo, xmask_lo); + andps(xe_hi, xmask_hi); + } + else { + movups(xe_lo, ptr[src + J.HW * 8]); + movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]); + } + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + xorps(xe_lo, xe_lo); + xorps(xe_hi, xe_hi); + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + + add(rsp, stack_space_needed); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +////////////////////////////////////////////////////////////////////////////// +// backward kernel +template +jit_uni_lrn_bwd_kernel_f32::jit_uni_lrn_bwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , nalphabeta(-2 * A*B) + , use_h_parallelizm(use_h_parallel) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r10; + + Xbyak::Xmm xsrc_prev = xmm1; + Xbyak::Xmm xws_prev = xmm2; + Xbyak::Xmm xdiffdst_prev = xmm3; + Xbyak::Ymm ysrc = ymm4; + Xbyak::Ymm yws = ymm5; + Xbyak::Ymm ydiffdst = ymm6; + Xbyak::Xmm xsrc_next = xmm7; + Xbyak::Xmm xws_next = xmm8; + Xbyak::Xmm xdiffdst_next = xmm9; + Xbyak::Ymm ya = ymm10; + Xbyak::Xmm xa = xmm10; + Xbyak::Ymm yb = ymm11; + Xbyak::Ymm yd = ymm12; + Xbyak::Ymm ye = ymm13; + Xbyak::Ymm ysum = ymm14; + Xbyak::Ymm ydiffsrc = ymm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(diffdst, ptr[this->param1 + 8]); + mov(workspace, ptr[this->param1 + 16]); + mov(diffsrc, ptr[this->param1 + 24]); + + sub(t, 64); + mov(imm_addr64, float2int(this->nalphabeta)); + movq(xnalphabeta, imm_addr64); + vbroadcastss(ynalphabeta, xnalphabeta); + + bool is_single = J.version == 3; + bool is_first = J.version == -1 || J.version == -2; + bool is_last = J.version == +1 || J.version == -2; + + if (is_first || is_single) { + vxorps(xsrc_prev, xsrc_prev, xsrc_prev); + vmovups(ptr[t + 0], xsrc_prev); + } + if (is_last || is_single) { + vxorps(xsrc_next, xsrc_next, xsrc_next); + vmovups(ptr[t + 48], xsrc_next); + } + mov(hw, this->use_h_parallelizm ? J.W : J.H*J.W); + Label lrn_loop; + L(lrn_loop); + { + if (!is_first && !is_single) { + vmovups(xws_prev, ptr[workspace - J.H*J.W * 32 + 16]); + vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + vmovups(xdiffdst_prev, ptr[diffdst - J.H*J.W * 32 + 16]); + vmulps(xa, xws_prev, xws_prev); + vmulps(xa, xa, xws_prev); + vsqrtps(xa, xa); + vsqrtps(xa, xa); + vmulps(xa, xa, xws_prev); + vdivps(xsrc_prev, xsrc_prev, xa); + vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev); + } + + vmovups(ysrc, ptr[src]); + vmovups(yws, ptr[workspace]); + vmovups(ydiffdst, ptr[diffdst]); + vmulps(ya, yws, yws); + vmulps(ya, ya, yws); + vsqrtps(ya, ya); + vsqrtps(ya, ya); + vdivps(ydiffsrc, ydiffdst, ya); + vdivps(ysum, ydiffsrc, yws); + vmulps(ysum, ysum, ysrc); + + if (!is_last && !is_single) { + vmovups(xws_next, ptr[workspace + J.H*J.W * 32]); + vmovups(xsrc_next, ptr[src + J.H*J.W * 32]); + vmovups(xdiffdst_next, ptr[diffdst + J.H*J.W * 32]); + vmulps(xa, xws_next, xws_next); + vmulps(xa, xa, xws_next); + vsqrtps(xa, xa); + vsqrtps(xa, xa); + vmulps(xa, xa, xws_next); + vdivps(xsrc_next, xsrc_next, xa); + vdivps(xsrc_next, xsrc_next, xws_next); + vmulps(xdiffdst_next, xdiffdst_next, xsrc_next); + } + + if (!is_first && !is_single) vmovups(ptr[t + 0], xdiffdst_prev); + vmovups(ptr[t + 16], ysum); + if (!is_last && !is_single) vmovups(ptr[t + 48], xdiffdst_next); + + vmovups(ya, ptr[t + 16 - 8]); + vmovups(yb, ptr[t + 16 - 4]); + vaddps(ysum, ysum, ya); + vmulps(ysrc, ysrc, ynalphabeta); + vaddps(ysum, ysum, yb); + + vmovups(yd, ptr[t + 16 + 4]); + vmovups(ye, ptr[t + 16 + 8]); + vaddps(ysum, ysum, yd); + vaddps(ysum, ysum, ye); + + vfmadd231ps(ydiffsrc, ysum, ysrc); + + vmovups(ptr[diffsrc], ydiffsrc); + + add(src, 32); + add(diffsrc, 32); + add(diffdst, 32); + add(workspace, 32); + + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template struct jit_uni_lrn_fwd_kernel_f32; +template struct jit_uni_lrn_fwd_kernel_f32; +template struct jit_uni_lrn_bwd_kernel_f32; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp new file mode 100644 index 0000000000..2b3ed43cd4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp @@ -0,0 +1,183 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_LRN_KERNEL_F32_HPP +#define CPU_JIT_UNI_LRN_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +enum params { VECTOR_LENGTH = 8, MAX_LOCAL_SIZE = 32 }; + +typedef struct { + const float *src; + float *dst, *scratch; +} jit_args_fwd_t; + +typedef struct { + const float *src, *diff_dst, *scratch; + float *diff_src; +} jit_args_bwd_t; + +struct nchw8c_across { + /* version: + * -1: channels 0..7, + * 1: channels C-8 .. C-1, + * 0: other channels + * 3: channels only for this kernel(without prev and next) + */ + int H, W, version; + nchw8c_across(int h, int w, int v) : H(h), W(w), version(v) {} +}; + +struct nchw8c_within { + int H, W, size; + nchw8c_within(int h, int w, int s) : H(h), W(w), size(s) {} +}; + +struct nchw_across { + int C, HW, tail; + nchw_across(int c, int hw, int t) : C(c), HW(hw), tail(t) {} +}; + +struct nhwc_across { + int C; + nhwc_across(int c) : C(c) {} +}; + +template +struct jit_uni_lrn_fwd_kernel_f32 : public jit_generator { + Xbyak::Reg64 src = rax; + Xbyak::Reg64 dst = r8; + Xbyak::Reg64 scratch = rdx; + Xbyak::Reg64 imm_addr64 = rbx; + Xbyak::Reg64 store_addr = rbp; + + Xbyak::Xmm xalpha = xmm0; + Xbyak::Ymm yalpha = ymm0; + Xbyak::Xmm xk = xmm1; + Xbyak::Ymm yk = ymm1; + + float alpha; + float k; + + int stack_space_needed = 11 * 4 * sizeof(float) + 16; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_fwd_kernel_f32) + + /* cpu specific part */ + using Vmm = typename utils::conditional::type; + + jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_within &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr = nullptr, + size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE); + + void within_body( + int hoff, int Hoff, int woff, int Woff, int stride, + Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2, + prop_kind_t pk); + void within_body_sse42( + int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk); + + + void nchw_body(int tail, int HW, prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum); + void nchw_body_sse42(int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi); + void nchw_tail_sse42(int tail, Xbyak::Reg64 reg_dst, + Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi); + + void operator()(jit_args_fwd_t *arg) { ker(arg); } + void(*ker)(jit_args_fwd_t *); +}; + +template +struct jit_uni_lrn_bwd_kernel_f32 : public jit_generator { + Xbyak::Reg64 src = rax; + Xbyak::Reg64 diffsrc = r8; + Xbyak::Reg64 diffdst = r9; + Xbyak::Reg64 workspace = rdx; + Xbyak::Reg64 imm_addr64 = rsi; + + Xbyak::Xmm xnalphabeta = xmm0; + Xbyak::Ymm ynalphabeta = ymm0; + + float nalphabeta; + + int use_h_parallelizm; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_bwd_kernel_f32) + + jit_uni_lrn_bwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + + void operator()(jit_args_bwd_t *arg) { ker(arg); } + void(*ker)(jit_args_bwd_t *); +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp new file mode 100644 index 0000000000..bf8e609d23 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp @@ -0,0 +1,699 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "cpu_pooling_pd.hpp" + +#include "jit_uni_pool_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; +using namespace alg_kind; + +#define GET_OFF(field) offsetof(jit_pool_call_s, field) + +template +status_t jit_uni_pool_kernel_f32::init_conf(jit_pool_conf_t &jpp, + const pooling_pd_t *ppd) { + const auto &pd = *ppd->desc(); + const memory_desc_wrapper src_d( + ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md()); + const memory_desc_wrapper dst_d( + ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md()); + + bool args_ok = true + && mayiuse(isa) + && utils::one_of(pd.alg_kind, pooling_max, + pooling_avg_include_padding, + pooling_avg_exclude_padding); + if (!args_ok) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 16 : 8; + const int ndims = src_d.ndims(); + + jpp.ndims = ndims; + jpp.mb = src_d.dims()[0]; + + jpp.c = utils::rnd_up(src_d.dims()[1], simd_w); + if (jpp.c > src_d.padded_dims()[1]) + return status::unimplemented; + + jpp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jpp.ih = src_d.dims()[ndims-2]; + jpp.iw = src_d.dims()[ndims-1]; + jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jpp.oh = dst_d.dims()[ndims-2]; + jpp.ow = dst_d.dims()[ndims-1]; + + jpp.stride_d = (ndims == 5 ) ? pd.strides[0] : 1; + jpp.stride_h = pd.strides[ndims-4]; + jpp.stride_w = pd.strides[ndims-3]; + jpp.kd = (ndims == 5) ? pd.kernel[0] : 1; + jpp.kh = pd.kernel[ndims-4]; + jpp.kw = pd.kernel[ndims-3]; + + jpp.f_pad = (ndims == 5 ) ? pd.padding[0][0] : 0; + jpp.t_pad = pd.padding[0][ndims-4]; + jpp.l_pad = pd.padding[0][ndims-3]; + + jpp.alg = pd.alg_kind; + + jpp.is_training = pd.prop_kind == prop_kind::forward_training; + jpp.is_backward = pd.prop_kind == prop_kind::backward_data; + jpp.ind_dt = ppd->workspace_md() + ? ppd->workspace_md()->data_type : data_type::undef; + + jpp.simple_alg = jpp.is_training + || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d); + + jpp.c_block = simd_w; + + jpp.nb_c = jpp.c / jpp.c_block; + if (jpp.alg == pooling_max) { + jpp.ur_w = isa == avx512_common ? 16 : 4; + if (jpp.is_training) + jpp.ur_w = isa == avx512_common ? 9 : 3; + else if (jpp.is_backward) + jpp.ur_w = isa == avx512_common ? 6 : 3; + } else { + if (jpp.is_backward) + jpp.ur_w = isa == avx512_common ? 12 : 6; + else + jpp.ur_w = isa == avx512_common ? 24 : 12; + } + if (jpp.ow < jpp.ur_w) jpp.ur_w = jpp.ow; + if (jpp.l_pad > jpp.ur_w) return status::unimplemented; + + jpp.ur_w_tail = jpp.ow % jpp.ur_w; + + return status::success; +} + +template +inline void jit_uni_pool_kernel_f32::maybe_recalculate_divisor(int jj, + int ur_w, int pad_l, int pad_r) { + if (jpp.alg == pooling_avg_exclude_padding) { + int kw = jpp.kw; + int stride_w = jpp.stride_w; + + int non_zero_kw = kw; + non_zero_kw -= nstl::max(0, pad_l - jj*stride_w); + non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj)*stride_w); + + if (non_zero_kw != prev_kw) { + mov(tmp_gpr, float2int((float)non_zero_kw)); + movq(xmm_tmp, tmp_gpr); + uni_vbroadcastss(vmm_tmp, xmm_tmp); + uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h); + prev_kw = non_zero_kw; + } + } +} + +template +inline void jit_uni_pool_kernel_f32::avg_step(int ur_w, int pad_l, + int pad_r) { + + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + for (int jj = 0; jj < ur_w; jj++) { + if (jpp.is_backward) { + uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]); + maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r); + uni_vdivps(vreg(jj), vreg(jj), vmm_tmp); + } else { + uni_vpxor(vreg(jj), vreg(jj), vreg(jj)); + } + } + + if (jpp.simple_alg && jpp.ndims == 5) { + push(reg_input); + push(reg_output); + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + if (jpp.is_backward) { + uni_vmovups(vreg(ur_w+jj), + ptr[aux_reg_input + input_offset]); + uni_vaddps(vreg(ur_w+jj), vreg(ur_w+jj), vreg(jj)); + uni_vmovups(vmmword[aux_reg_input + input_offset], + vreg(ur_w+jj)); + } else { + uni_vaddps(vreg(jj), vreg(jj), + ptr[aux_reg_input + input_offset]); + } + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + + if (jpp.simple_alg && jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + pop(reg_output); + pop(reg_input); + } + + if (!jpp.is_backward) { + for (int jj = 0; jj < ur_w; jj++) { + maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r); + uni_vdivps(vreg(jj), vreg(jj), vmm_tmp); + uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block], + vreg(jj)); + } + } +} + +template +inline void jit_uni_pool_kernel_f32::max_step_fwd(int ur_w, int pad_l, + int pad_r) { + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + mov(tmp_gpr, float2int(nstl::numeric_limits::lowest())); + movq(xmm_tmp, tmp_gpr); + uni_vbroadcastss(vmm_tmp, xmm_tmp); + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vreg(jj), vmm_tmp); + if (jpp.is_training) + uni_vpxor(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(2*ur_w+jj)); + } + if (jpp.is_training) + { + movq(xmm_tmp, reg_k_shift); + uni_vpbroadcastd(vmm_k_offset, xmm_tmp); + } + + if (jpp.ndims == 5) { + push(reg_input); + push(reg_output); + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + uni_vmovups(vreg(ur_w+jj), ptr[aux_reg_input + input_offset]); + if (isa == sse42) { + movups(vmm_mask, vreg(jj)); + cmpps(vmm_mask, vreg(ur_w+jj), _cmp_lt_os); + blendvps(vreg(jj), vreg(ur_w+jj)); + if (jpp.is_training) + blendvps(vreg(2*ur_w+jj), vmm_k_offset); + } else if (isa == avx) { + vcmpps(vreg(3*ur_w+jj), vreg(jj), vreg(ur_w+jj), + _cmp_lt_os); + vblendvps(vreg(jj), vreg(jj), vreg(ur_w+jj), + vreg(3*ur_w+jj)); + if (jpp.is_training) + vblendvps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), + vmm_k_offset, vreg(3*ur_w+jj)); + } else { + vcmpps(k_store_mask, vreg(jj), vreg(ur_w+jj), _cmp_lt_os); + vblendmps(vreg(jj) | k_store_mask, vreg(jj), vreg(ur_w+jj)); + if (jpp.is_training) + vblendmps(vreg(2*ur_w+jj) | k_store_mask, + vreg(2*ur_w+jj), vmm_k_offset); + } + } + if (jpp.is_training) { + if (isa == avx && !mayiuse(avx2)) { + avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one); + } + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + + if (jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + if (jpp.is_training) { + mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + if (isa == avx && !mayiuse(avx2)) { + Xmm t(vmm_mask.getIdx()); + avx_vpadd1(vmm_k_offset, xmm_tmp, t); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp); + } + } + + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + pop(reg_output); + pop(reg_input); + } + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block], vreg(jj)); + if (jpp.is_training) { + const size_t step_index + = jj * c_block * types::data_type_size(jpp.ind_dt); + + auto x = xreg(2 * ur_w + jj); + if (jpp.ind_dt == data_type::u8) { + if (isa == sse42) { + for (int i = 0; i < 4; ++i) + pextrb(ptr[reg_index + step_index + i], x, 4*i); + } else if (isa == avx) { + auto y = yreg(2 * ur_w + jj); + if (jj == 0) { + movd(xmm_tmp, reg_shuf_mask); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + } + if (mayiuse(avx2)) { + vpshufb(y, y, vmm_tmp); + movd(ptr[reg_index + step_index], x); + vperm2i128(y, y, y, 0x1u); + movd(ptr[reg_index + step_index + 4], x); + } else { + Xmm t(vmm_mask.getIdx()); + vextractf128(t, y, 0); + vpshufb(t, t, xmm_tmp); + movd(ptr[reg_index + step_index], t); + vextractf128(t, y, 1); + vpshufb(t, t, xmm_tmp); // ymm_tmp[:128]==ymm_tmp[127:0] + movd(ptr[reg_index + step_index + 4], t); + } + } else { + auto v = vreg(2 * ur_w + jj); + vpmovusdb(x, v); + vmovups(ptr[reg_index + step_index], v | k_index_mask); + } + } else { + uni_vmovups(ptr[reg_index + step_index], vreg(2*ur_w+jj)); + } + } + } +} + +template +inline void jit_uni_pool_kernel_f32::max_step_bwd(int ur_w, int pad_l, + int pad_r) { + + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]); + + const size_t step_index + = jj * c_block * types::data_type_size(jpp.ind_dt); + if (jpp.ind_dt == data_type::u8) { + if (isa == sse42) { + movd(xreg(ur_w+jj), ptr[reg_index + step_index]); + pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } else if (isa == avx) { + movq(xreg(ur_w+jj), ptr[reg_index + step_index]); + if (!mayiuse(avx2)) { + avx_pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj), xmm_tmp); + } else { + vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } + } else { + vmovups(vreg(ur_w+jj) | k_index_mask, + ptr[reg_index + step_index]); + vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } + } else { + uni_vmovups(vreg(ur_w+jj), ptr[reg_index + step_index]); + } + } + movq(xmm_tmp, reg_k_shift); + uni_vpbroadcastd(vmm_k_offset, xmm_tmp); + + if (jpp.simple_alg && jpp.ndims == 5) { + push(reg_input); + push(reg_output); + if (isa == sse42) { + // Save rdi since it is used in maskmovdqu + assert(dst_ptr == rdi); + push(dst_ptr); + } + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + uni_vmovups(vreg(2*ur_w+jj), ptr[aux_reg_input + input_offset]); + if (isa == sse42) { + mov(dst_ptr, aux_reg_input); + add(dst_ptr, input_offset); + + movups(vreg(3*ur_w+jj), vreg(ur_w+jj)); + pcmpeqd(vreg(3*ur_w+jj), vmm_k_offset); + addps(vreg(2*ur_w+jj), vreg(jj)); + maskmovdqu(vreg(2*ur_w+jj), vreg(3*ur_w+jj)); + } else if (isa == avx) { + if (mayiuse(avx2)) { + vpcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset); + } else { + avx_pcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset, xmm_tmp); + } + vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(jj)); + vmaskmovps(vmmword[aux_reg_input + input_offset], + vreg(3*ur_w+jj), vreg(2*ur_w+jj)); + } else { + vpcmpeqd(k_store_mask, vreg(ur_w+jj), vmm_k_offset); + vblendmps(vmm_tmp | k_store_mask | T_z, vreg(jj), vreg(jj)); + vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vmm_tmp); + vmovups(vmmword[aux_reg_input + + sizeof(float)*aux_input_offset], vreg(2*ur_w+jj)); + } + } + if (isa == avx && !mayiuse(avx2)) { + avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one); + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + if (jpp.simple_alg && jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + + mov(tmp_gpr, reg_kd_pad_shift); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + if (isa == avx && !mayiuse(avx2)) { + Xmm t(vmm_mask.getIdx()); + avx_vpadd1(vmm_k_offset, vmm_tmp, t); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp); + } + + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + if (isa == sse42) { + // Save rdi since it is used in maskmovdqu + assert(dst_ptr == rdi); + pop(dst_ptr); + } + pop(reg_output); + pop(reg_input); + } +} + +template +void jit_uni_pool_kernel_f32::maybe_zero_diff_src() { + assert(jpp.c_block * sizeof(float) % cpu_isa_traits::vlen == 0); + Label l_skip, l_zero; + + auto reg_oh = tmp_gpr; + mov(reg_oh, ptr[reg_param + GET_OFF(oh)]); + cmp(reg_oh, 0); + jz(l_skip, T_NEAR); + + if (jpp.ndims == 5) { + mov(zero_size, ptr[reg_param + GET_OFF(oh)]); + mov(tmp_gpr, jpp.ih * jpp.iw * jpp.c_block * sizeof(float)); + imul(zero_size, tmp_gpr); + } + + auto vzero = vmm_tmp; + uni_vpxor(vzero, vzero, vzero); + + auto reg_off = tmp_gpr; + xor_(reg_off, reg_off); + + L(l_zero); + { + const int dim = jpp.iw * jpp.c_block * sizeof(float); + for (int i = 0; i < dim; i += cpu_isa_traits::vlen) + uni_vmovups(ptr[reg_input + reg_off + i], vzero); + add(reg_off, dim); + if (jpp.ndims == 5) cmp(reg_off, zero_size); + else cmp(reg_off, jpp.ih * dim); + jl(l_zero, T_NEAR); + } + + L(l_skip); +} + +template +void jit_uni_pool_kernel_f32::generate() { + + this->preamble(); + + int ow = jpp.ow; + int iw = jpp.iw; + int kw = jpp.kw; + int kh = jpp.kh; + int ur_w = jpp.ur_w; + int c_block = jpp.c_block; + int stride_w = jpp.stride_w; + int l_pad = jpp.l_pad; + int ur_w_tail = jpp.ur_w_tail; + + int n_oi = ow / ur_w; + + prev_kw = 0; + + int vlen = cpu_isa_traits::vlen; + +#if defined(_WIN32) + // Always mimic the Unix ABI (see the note about maskmovdqu in the header + // file). + xor_(rdi, rcx); + xor_(rcx, rdi); + xor_(rdi, rcx); +#endif + + mov(reg_input, ptr[reg_param + GET_OFF(src)]); + mov(reg_output, ptr[reg_param + GET_OFF(dst)]); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + mov(reg_index, ptr[reg_param + GET_OFF(indices)]); + mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]); + mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]); + mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]); + + if (jpp.is_backward) + maybe_zero_diff_src(); + + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) { + mov(tmp_gpr, 1); + movq(xmm_one, tmp_gpr); + uni_vpbroadcastd(vmm_one, xmm_one); + + if (isa == avx) { + mov(reg_shuf_mask, 0x0c080400); + } else if (isa >= avx512_common) { + mov(tmp_gpr.cvt32(), 0x000f); + kmovw(k_index_mask, tmp_gpr.cvt32()); + } + } + + int r_pad = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad - 1)); + int r_pad1 = (ur_w*n_oi - 1)*stride_w + kw - 1 - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (jpp.alg == pooling_avg_exclude_padding) { + movq(xmm_ker_area_h, reg_ker_area_h); + uni_vpbroadcastd(vmm_ker_area_h, xmm_ker_area_h); + } + + if (jpp.alg == pooling_avg_include_padding) { + mov(tmp_gpr, float2int((float)(kw * kh * jpp.kd))); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + } + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) { + step(ur_w, l_pad, r_pad1); + } else { + step(ur_w, l_pad, 0); + } + + if (isa == sse42) { + if (n_oi < 0 && r_pad1 > 0) { + step_high_half(ur_w, l_pad, r_pad1); + } else { + step_high_half(ur_w, l_pad, 0); + } + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*(ur_w*stride_w-l_pad)*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*(ur_w*stride_w - l_pad)*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + } + + xor_(oi_iter, oi_iter); + if (n_oi > 0) { + Label ow_loop; + L(ow_loop); { + step(ur_w, 0, 0); + + if (isa == sse42) { + step_high_half(ur_w, 0, 0); + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + } + + if (r_pad1 > 0 && n_oi >= 0) { + step(ur_w, 0, r_pad1); + + if (isa == sse42) { + step_high_half(ur_w, 0, r_pad1); + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + } + + if (ur_w_tail != 0) { + step(ur_w_tail, 0, r_pad); + + if (isa == sse42) { + step_high_half(ur_w_tail, 0, r_pad); + } + } + + this->postamble(); +} + +template struct jit_uni_pool_kernel_f32; +template struct jit_uni_pool_kernel_f32; // implements both and +template struct jit_uni_pool_kernel_f32; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp new file mode 100644 index 0000000000..992b526587 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp @@ -0,0 +1,192 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_POOL_KERNEL_F32_HPP +#define JIT_UNI_POOL_KERNEL_F32_HPP + +#include + +#include "c_types_map.hpp" +#include "pooling_pd.hpp" +#include "type_helpers.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +template +struct jit_uni_pool_kernel_f32: public jit_generator { + jit_uni_pool_kernel_f32(jit_pool_conf_t ajpp): jpp(ajpp) + { + this->generate(); + jit_ker = (decltype(jit_ker))this->getCode(); + } + + jit_pool_conf_t jpp; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel_f32) + + void operator()(jit_pool_call_s *arg) { jit_ker(arg); } + static status_t init_conf(jit_pool_conf_t &jbp, const pooling_pd_t *ppd); + +private: + using Vmm = typename utils::conditional3::type; + Xmm xreg(int idx) { return Xmm((isa == avx512_common ? 31 : 15) - idx); } + Ymm yreg(int idx) { return Ymm(xreg(idx).getIdx()); } + Vmm vreg(int idx) { return Vmm(xreg(idx).getIdx()); } + + const AddressFrame &vmmword = (isa == sse42) ? xword : + (isa == avx) ? yword : zword; + + Xmm vmm_mask = Xmm(0); + Xmm xmm_ker_area_h = Xmm(2); + Xmm xmm_one = Xmm(2); + Xmm xmm_tmp = Xmm(3); + + Vmm vmm_ker_area_h = Vmm(2); + Vmm vmm_one = Vmm(2); + Vmm vmm_tmp = Vmm(3); + + Vmm vmm_k_offset = Vmm(1); + + Opmask k_index_mask = Opmask(6); + Opmask k_store_mask = Opmask(7); + + // Here be some (tame) dragons. This kernel does not follow the regular + // OS-agnostic ABI pattern because when isa is sse42 it uses maskmovdqu + // instruction which has its destination hardcoded in rdi. Therefore: + // - all registers are hardcoded + // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI + // + // While this is only required by the backward pass, the quirk above + // is applied to the forward pass as well to keep things simpler. + + using reg64_t = const Xbyak::Reg64; + reg64_t reg_param = rdi; // Always mimic the Unix ABI + reg64_t reg_input = r8; + reg64_t aux_reg_input = r9; + reg64_t reg_index = r10; + reg64_t reg_output = r12; + reg64_t reg_kd_pad_shift = r13; + reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu + + reg64_t kj = r14; + reg64_t oi_iter = r15; + reg64_t reg_kh = rax; + reg64_t reg_k_shift = rbx; + reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above + reg64_t reg_ker_area_h = rdx; + + reg64_t zero_size = r15; + reg64_t ki = r12; + reg64_t aux_reg_input_d = r8; + + Xbyak::Reg32 reg_shuf_mask = esi; + + int prev_kw; + void (*jit_ker)(jit_pool_call_s *); + + void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r); + void avg_step(int ur_w, int pad_l, int pad_r); + void max_step_fwd(int ur_w, int pad_l, int pad_r); + void max_step_bwd(int ur_w, int pad_l, int pad_r); + + void maybe_zero_diff_src(); + + void step(int ur_w, int pad_l, int pad_r) { + if (jpp.alg == alg_kind::pooling_max) { + if(jpp.is_backward) + max_step_bwd(ur_w, pad_l, pad_r); + else + max_step_fwd(ur_w, pad_l, pad_r); + } + else + avg_step(ur_w, pad_l, pad_r); + } + + void step_high_half(int ur_w, int pad_l, int pad_r) { + add(reg_input, sizeof(float) * 4); + add(reg_output, sizeof(float) * 4); + if (jpp.alg == alg_kind::pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, types::data_type_size(jpp.ind_dt) * 4); + + step(ur_w, pad_l, pad_r); + } + + void generate(); + + void avx_vpadd1(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { + assert(y0.getIdx() != x1.getIdx()); + vextractf128(xtmp, y0, 0); + vpaddd(xtmp, xtmp, x1); + vinsertf128(y0, y0, xtmp, 0); + vextractf128(xtmp, y0, 1); + vpaddd(xtmp, xtmp, x1); + vinsertf128(y0, y0, xtmp, 1); + } + + void avx_vpadd1(const Xmm& x0, const Xmm& x1, const Xmm&) { + assert(false /*function should not be used*/); + paddd(x0, x1); + } + + void avx_pmovzxbd(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { + Xmm x0(y0.getIdx()); + pshufd(xmm_tmp, x1, 1); + pmovzxbd(x0, x1); + pmovzxbd(xmm_tmp, xmm_tmp); + vinsertf128(y0, y0, xmm_tmp, 1); + } + + void avx_pmovzxbd(const Xmm& x0, const Xmm& x1, const Xmm&) { + assert(false /*function should not be used*/); + pmovzxbd(x0, x1); + } + + void avx_pcmpeqd(const Ymm& y0, const Ymm& y1, const Ymm& y2, const Xmm& xtmp) { + assert(y0.getIdx() != y1.getIdx()); + assert(y0.getIdx() != y2.getIdx()); + Xmm x0(y0.getIdx()); + Xmm x2(y2.getIdx()); + vextractf128(x0, y1, 1); + vextractf128(xtmp, y2, 1); + pcmpeqd(xtmp, x0); + vextractf128(x0, y1, 0); + pcmpeqd(x0, x2); + vinsertf128(y0, y0, xtmp, 1); + } + + void avx_pcmpeqd(const Xmm& x0, const Xmm& x1, const Xmm&, const Xmm&) { + assert(false /*function should not be used*/); + pcmpeqd(x0, x1); + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp new file mode 100644 index 0000000000..afbcf996d8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp @@ -0,0 +1,264 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "nstl.hpp" + +#include "jit_uni_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, + data_t *dst, char *indices) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int oh) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &src[src_d.blk_off(n, b_c, ih)]; + arg.dst = &dst[dst_d.blk_off(n, b_c, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = oh == 0; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)); + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, jpp.oh, + [&](int n, int b_c, int oh) { + ker(n, b_c, oh); + }); +} + +template +void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, + data_t *dst, char *indices) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow, + int d_b_overflow) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &src[src_d.blk_off(n, b_c, id, ih)]; + arg.dst = &dst[dst_d.blk_off(n, b_c, od, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, od, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = (oh + od == 0); + arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh; + arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd - + nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) - + nstl::max(0, jpp.f_pad - od*jpp.stride_d)); + + + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, jpp.od, + [&](int n, int b_c, int od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad-ik); + const int d_b_overflow = nstl::max(jpp.id, ik+jpp.kd-jpp.f_pad) + -jpp.id; + const int id = nstl::max(ik - jpp.f_pad, 0); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow); + } + }); +} + +template +void jit_uni_pooling_bwd_t::execute_backward(const data_t *diff_dst, + const char *indices, data_t *diff_src) const { + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int oh) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &diff_src[diff_src_d.blk_off(n, b_c, ih)]; + arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = (oh == 0); + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)); + + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) { + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, oh); + } + }); +} + +template +void jit_uni_pooling_bwd_t::execute_backward_3d(const data_t *diff_dst, + const char *indices, data_t *diff_src) const { + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow, + int d_b_overflow, int zero_size, int kd) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &diff_src[diff_src_d.blk_off(n, b_c, id + kd, ih)]; + arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, od, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, od, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = zero_size; + arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh + + kd * jpp.kw * jpp.kh; + arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd - + nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) - + nstl::max(0, jpp.f_pad - od*jpp.stride_d)); + + (*kernel_)(&arg); + }; + + if (jpp.simple_alg) { + + parallel_nd(jpp.mb, jpp.nb_c, jpp.od, + [&](int n, int b_c, int od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad - ik); + const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd + - jpp.f_pad) - jpp.id; + const int id = nstl::max(ik - jpp.f_pad, 0); + int zero_s = jpp.stride_d - d_t_overflow - (nstl::max( + jpp.id, ik + jpp.stride_d - jpp.f_pad) - jpp.id); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, + (oh == 0) ? zero_s : 0, 0); + } + }); + } else { + ptrdiff_t nelems = (ptrdiff_t)jpp.mb * (ptrdiff_t)jpp.c + * (ptrdiff_t)jpp.id * (ptrdiff_t)jpp.ih * (ptrdiff_t)jpp.iw; + + parallel_nd(nelems, [&](ptrdiff_t i) { diff_src[i] = 0.f; }); + + for (int kd = 0; kd < jpp.kd; ++kd) { + parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) { + for (int od = 0; od < jpp.od; ++od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad-ik); + const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd + - jpp.f_pad) - jpp.id; + if (kd >= jpp.kd - d_t_overflow - d_b_overflow) + continue; + const int id = nstl::max(ik - jpp.f_pad, 0); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, + 0, kd); + } + } + }); + } + } +} + + +template struct jit_uni_pooling_fwd_t; +template struct jit_uni_pooling_bwd_t; +template struct jit_uni_pooling_fwd_t; +template struct jit_uni_pooling_bwd_t; +template struct jit_uni_pooling_fwd_t; +template struct jit_uni_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp new file mode 100644 index 0000000000..57bebacdee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp @@ -0,0 +1,182 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_POOLING_HPP +#define CPU_JIT_UNI_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_uni_pool_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_pooling_fwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, + src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag()) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag()); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return jit_uni_pool_kernel_f32::init_conf(jpp_, this); + } + + format_tag_t desired_fmt_tag() { + using namespace format_tag; + return ndims() == 4 + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? nCdhw16c : nCdhw8c; + } + + jit_pool_conf_t jpp_; + }; + + jit_uni_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_pool_kernel_f32(pd()->jpp_); } + + ~jit_uni_pooling_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE); + + if (pd()->ndims() == 5) + execute_forward_3d(src, dst, ws); + else + execute_forward(src, dst, ws); + + return status::success; + } + +private: + void execute_forward(const data_t *src, data_t *dst, char *indices) const; + void execute_forward_3d(const data_t *src, data_t *dst, + char *indices) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_pool_kernel_f32 *kernel_; +}; + +template +struct jit_uni_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_pooling_bwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, + diff_src_md()->data_type, + diff_dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag()) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag()); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return jit_uni_pool_kernel_f32::init_conf(jpp_, this); + } + + format_tag_t desired_fmt_tag() { + using namespace format_tag; + return ndims() + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? nCdhw16c : nCdhw8c; + } + + jit_pool_conf_t jpp_; + }; + + jit_uni_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_pool_kernel_f32(pd()->jpp_); } + + ~jit_uni_pooling_bwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + if (pd()->ndims() == 5) + execute_backward_3d(diff_dst, ws, diff_src); + else + execute_backward(diff_dst, ws, diff_src); + + return status::success; + } + +private: + void execute_backward(const data_t *diff_dst, const char *indices, + data_t *diff_src) const; + void execute_backward_3d(const data_t *diff_dst, const char *indices, + data_t *diff_src) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_pool_kernel_f32 *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp new file mode 100644 index 0000000000..98796503b7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp @@ -0,0 +1,1006 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "mkldnn_debug.h" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" +#include "jit_uni_reorder.hpp" + +#include "jit_generator.hpp" + +// #define TR_DEBUG +#if defined(TR_DEBUG) +#define DEBUg(...) do { __VA_ARGS__ } while (0) +#else +#define DEBUg(...) +#endif +#define DEBUG(...) DEBUg(__VA_ARGS__) + +#ifdef _WIN32 +/* seems like s_addr is a reserved macro on Windows */ +#undef s_addr +#endif + +using namespace Xbyak; +using namespace mkldnn::impl::types; + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +/** Minimal reasonable/desirable kernel size. + * The constant might be used to determine how a problem should be split + * between kernel and threading driver. */ +const size_t ker_prb_size_min = 64; + +/* kernel */ +struct jit_uni_reorder_kernel_f32: public kernel_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32) + + enum { + len_unroll_max = 256, + ndims_jit_loop_max = 3, + }; + + struct simple_impl_desc_t { + int ndims_full_unroll; + int len_last_dim_unroll; + int len_unroll; + }; + + static bool simple_impl_desc_init(const prb_t &prb, + simple_impl_desc_t *desc) { + const int ndims = prb.ndims; + + int ndims_full_unroll = 0; + int len_last_dim_unroll = 1; + int len_unroll = 1; + + for (int d = 0; d < ndims; ++d) { + auto &node = prb.nodes[d]; + if (len_unroll * node.n <= len_unroll_max) { + ndims_full_unroll++; + len_unroll *= node.n; + } else { + len_last_dim_unroll = len_unroll_max / len_unroll; + while (node.n % len_last_dim_unroll) + --len_last_dim_unroll; + len_unroll *= len_last_dim_unroll; + break; + } + } + + if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) + return false; + + if (desc) { + desc->ndims_full_unroll = ndims_full_unroll; + desc->len_last_dim_unroll = len_last_dim_unroll; + desc->len_unroll = len_unroll; + } + + return true; + } + + static bool applicable(const prb_t &p) { + using namespace data_type; + + bool ok = true + && p.ndims > 0 + && utils::one_of(p.itype, f32, s32, s8, u8) + && utils::one_of(p.otype, f32, s32, s8, u8) + && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ + && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ + && simple_impl_desc_init(p, nullptr) + && mayiuse(sse42) + && IMPLICATION(!utils::everyone_is(f32, p.itype, p.otype), + mayiuse(avx)); + if (!ok) return false; + + const ptrdiff_t max_stride = (1LL<<31) - 1; + for (int d = 0; d < p.ndims; ++d) { + const ptrdiff_t cms = max_stride / p.nodes[d].n; + bool strides_ok = true + && p.nodes[d].is < cms / (int)data_type_size(p.itype) + && p.nodes[d].os < cms / (int)data_type_size(p.otype); + if (!strides_ok) return false; + } + + return true; + } + + int n(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].n; } + int is(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].is; } + int os(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].os; } + int ss(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; } + + Address i_addr(int i_off) + { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; } + + Address o_addr(int o_off) + { return ptr[reg_ptr_out + reg_off_out + o_off * otype_sz]; } + + Address s_addr(int s_off) + { return ptr[reg_ptr_scale + reg_off_scale + s_off * stype_sz]; } + + void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, + int &i_off, int &o_off, int &s_off, int step_size = 1) { + i_off = prev_i_off; + o_off = prev_o_off; + s_off = prev_s_off; + + if (off == 0) return; + + int start_dim = 0, dims_prod = 1; + for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim) + dims_prod *= n(start_dim); + assert(start_dim < prb_.ndims); + off /= step_size; + + for (int d = start_dim; d < prb_.ndims; ++d) { + i_off += is(d); + o_off += os(d); + s_off += ss(d); + + if (off % n(d)) break; + + i_off += - n(d) * is(d); + o_off += - n(d) * os(d); + s_off += - n(d) * ss(d); + off /= n(d); + + if (off == 0) break; /* FIXME: is it really required? */ + } + } + + void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off, + int step_size = 1) { + int dummy = 0; + step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy, + step_size); + } + + void tr8x8_avx2(int i_off, int o_off) { + for (int i = 0; i < 8; i++) + vmovups(Ymm(i), i_addr(i_off + i * 8)); + + for (int i = 0; i < 8 / 2; i++) { + vunpcklps(Ymm(8 + i), Ymm(2 * i), Ymm(2 * i + 1)); + vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1)); + } + + const unsigned int lfloat = 0x44; + const unsigned int ufloat = 0xee; + for (int i = 0; i < 8 / 2; i++) { + int j = i % 2 == 0 ? 8 + i : i - 1; + vshufps(Ymm(8 / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat); + vshufps(Ymm(8 / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat); + } + + const unsigned int lquad = 0x20; + for (int i = 0; i < 8 / 2; i++) + vperm2f128(Ymm(i), Ymm(8 / 2 + i), Ymm(8 + i), lquad); + + const unsigned int uquad = 0x31; + for (int i = 8 / 2; i < 8; i++) + vperm2f128(Ymm(i), Ymm(i), Ymm(8 / 2 + i), uquad); + + for (int i = 0; i < 8; i++) + vmovups(o_addr(o_off + i * 8), Ymm(i)); + } + + bool process_unroll_tr8x8(int len) { + bool can_do = true + && mayiuse(avx2) + && prb_.ndims >= 2 + && utils::everyone_is(4, itype_sz, otype_sz) + && utils::everyone_is(8, n(0), n(1)) + && utils::everyone_is(1, os(0), is(1)) + && utils::everyone_is(8, os(1), is(0)) + && prb_.scale_type == scale_type_t::NONE + && prb_.beta == 0.f; + if (!can_do) return false; + + const int step_size = n(0) * n(1); + int i_off = 0, o_off = 0; + for (int off = 0; off < len; off += step_size) { + step(off, i_off, o_off, i_off, o_off, step_size); + tr8x8_avx2(i_off, o_off); + } + + return true; + } + + template + bool process_direct_copy(int len) { + using namespace data_type; + + using Vmm = typename cpu_isa_traits::Vmm; + const int simd_w = cpu_isa_traits::vlen / itype_sz; + + bool can_do = true + && mayiuse(isa) + && utils::everyone_is(1, os(0), is(0)) + && (false + || prb_.itype == prb_.otype + || (prb_.itype == s32 && prb_.otype == f32) + || (prb_.itype == f32 && prb_.otype == s32) + ) + && len % simd_w == 0 + && n(0) % len == 0 + && prb_.scale_type == scale_type_t::NONE + && prb_.beta == 0.f; + if (!can_do) return false; + + for (int off = 0; off < len;) { + const int unroll = nstl::min(16, (len - off) / simd_w); + + for (int ur = 0; ur < unroll; ++ur) + uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w)); + + if (prb_.itype != prb_.otype) { + for (int ur = 0; ur < unroll; ++ur) { + if (prb_.itype == s32 && prb_.otype == f32) + uni_vcvtdq2ps(Vmm(ur), Vmm(ur)); + else if (prb_.itype == f32 && prb_.otype == s32) + uni_vcvtps2dq(Vmm(ur), Vmm(ur)); + else assert(!"unreachable"); + } + } + + for (int ur = 0; ur < unroll; ++ur) + uni_vmovups(o_addr(off + ur * simd_w), Vmm(ur)); + + off += unroll * simd_w; + } + + return true; + } + + void process_unroll_generic_step(int reg_unroll, const int *i_off, + const int *o_off, const int *s_off) { + using namespace data_type; + + auto cvt2ps = [=](const Xmm &dst, const Operand &src, data_type_t idt) { + Xmm dst_pure = Xmm(dst.getIdx()); + switch (idt) { + case f32: + if (src.isMEM() || src.getIdx() != dst.getIdx()) + vmovups(dst, src); + break; + case s32: vcvtdq2ps(dst, src); break; + case s8: vpmovsxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; + case u8: vpmovzxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; + default: assert(!"unreachable"); + } + }; + + auto cvt2int = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) { + switch (odt) { + case s32: + if (idt == f32) vcvtps2dq(xmm, xmm); + else if (idt == s8) vpmovsxbd(xmm, xmm); + else if (idt == u8) vpmovzxbd(xmm, xmm); + break; + case s8: + if (idt == f32) vcvtps2dq(xmm, xmm); + if (idt == f32 || idt == s32) { + if (mayiuse(avx512_core)) { + vpmovsdb(xmm, xmm); + } else { + vpackssdw(xmm, xmm, xmm_zero); + vpacksswb(xmm, xmm, xmm_zero); + } + } + if (idt == u8) vpminub(xmm, xmm, xmm_4x127b); + break; + case u8: + if (idt == f32) vcvtps2dq(xmm, xmm); + if (idt == f32 || idt == s32) { + if (mayiuse(avx512_core)) { + vpmaxsd(xmm, xmm, xmm_zero); + vpmovusdb(xmm, xmm); + } else { + vpackssdw(xmm, xmm, xmm_zero); + vpackuswb(xmm, xmm, xmm_zero); + } + } + if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero); + break; + default: assert(!"unreachable"); + } + }; + + auto load = [=](const Xmm &xmm, const Address &addr, int size) { + switch (size) { + case 16: movups(xmm, addr); break; + case 4: movss(xmm, addr); break; + case 1: pinsrb(xmm, addr, 0x0); break; + default: assert(!"unreachable"); + } + }; + + auto store = [=](const Address &addr, const Xmm &xmm, int size) { + switch (size) { + case 16: movups(addr, xmm); break; + case 4: movss(addr, xmm); break; + case 1: pextrb(addr, xmm, 0x0); break; + default: assert(!"unreachable"); + } + }; + + /* check whether loading 4 values at once is possible */ + bool can_load_xmm = mayiuse(avx) && reg_unroll % 4 == 0; + for (int ur = 1; ur < reg_unroll; ++ur) + if (i_off[ur] != i_off[ur - 1] + 1) + can_load_xmm = false; + const int load_step = can_load_xmm ? 4 : 1; + + /* check whether storing 4 values at once is possible */ + bool can_store_xmm = reg_unroll % 4 == 0; + for (int ur = 1; ur < reg_unroll; ++ur) + if (o_off[ur] != o_off[ur - 1] + 1) + can_store_xmm = false; + const int ur_step = can_store_xmm ? 4 : 1; + + const bool interim_f32 = false + || utils::one_of(f32, prb_.itype, prb_.otype) + || prb_.scale_type != scale_type_t::NONE + || prb_.beta != 0.f; + + if (!can_load_xmm && can_store_xmm) { + assert(ur_step == 4); + /* load with stride */ + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + for (int r = 0; r < ur_step; ++r) { + if (itype_sz == 4) + pinsrd(Xmm(ur), i_addr(i_off[ur + r]), r); + else + pinsrb(Xmm(ur), i_addr(i_off[ur + r]), r); + } + } + } else { + for (int ur = 0; ur < reg_unroll; ur += load_step) + load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz); + } + + /* xmm[:] <-- (f32)xmm[:] */ + if (interim_f32) { + const int cvt_step = nstl::max(load_step, ur_step); + for (int ur = 0; ur < reg_unroll; ur += cvt_step) + cvt2ps(Xmm(ur), Xmm(ur), prb_.itype); + } + + if (can_load_xmm && !can_store_xmm) { + const bool fast_return = true // transposition on the fly + && prb_.scale_type != scale_type_t::MANY + && prb_.beta == 0.f; + if (fast_return) { + for (int ur = 0; ur < reg_unroll; ur += load_step) { + if (prb_.scale_type == scale_type_t::COMMON) + mulps(Xmm(ur), xmm_scale); + if (prb_.otype != f32) + cvt2int(Xmm(ur), prb_.otype, + interim_f32 ? f32 : prb_.itype); + for (int r = 0; r < load_step; ++r) { + if (otype_sz == 4) + pextrd(o_addr(o_off[ur + r]), Xmm(ur), r); + else + pextrb(o_addr(o_off[ur + r]), Xmm(ur), r); + } + } + return; + } + + /* scatter elements of xmm into 4 xmms */ + if (itype_sz == 4 || interim_f32) { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) + vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r); + } else { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) + vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), r); + } + } + + /* scale and beta processing */ + if (can_store_xmm) { + /* xmm <-- scale * xmm[:] */ + if (prb_.scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) + mulps(Xmm(ur), xmm_scale); + } else if (prb_.scale_type == scale_type_t::MANY) { + enum class scale_load_type_t { bcast, load, gather }; + + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + scale_load_type_t scale_load_type = + scale_load_type_t::bcast; // the best case + + for (int r = ur + 1; r < ur + ur_step; ++r) + if (s_off[r] != s_off[r - 1] + 0) + scale_load_type = scale_load_type_t::load; + + if (scale_load_type == scale_load_type_t::bcast) { + movss(xmm_scale, s_addr(s_off[ur])); + shufps(xmm_scale, xmm_scale, 0x0); + mulps(Xmm(ur), xmm_scale); + continue; + } + + // bcast doesn't work, the next try -- load + for (int r = ur + 1; r < ur + ur_step; ++r) + if (s_off[r] != s_off[r - 1] + 1) + scale_load_type = scale_load_type_t::gather; + + if (scale_load_type == scale_load_type_t::load) { + movups(xmm_scale, s_addr(s_off[ur])); + mulps(Xmm(ur), xmm_scale); + continue; + } + + // load doesn't work as well + // so gather the scale factors one by one + for (int r = ur; r < ur + ur_step; ++r) + pinsrd(xmm_scale, s_addr(s_off[r]), r - ur); + mulps(Xmm(ur), xmm_scale); + } + } + + /* dst <-- beta * dst + xmm[:] */ + assert(prb_.beta == 0.f || prb_.beta == 1.f); + if (prb_.beta == 1.f) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype == f32) { + /* non VEX instructions do not support unaligned + * memory for instructions other than movups. */ + if (mayiuse(avx)) { + vaddps(Xmm(ur), o_addr(o_off[ur])); + } else { + /* register xmm(1) is unused */ + movups(Xmm(1), o_addr(o_off[ur])); + addps(Xmm(ur), Xmm(1)); + } + } else { + cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype); + vaddps(Xmm(ur), Xmm(1)); + } + } + } + } else { + /* xmm[0] <-- scale * xmm[0] */ + if (prb_.scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) + mulss(Xmm(ur), xmm_scale); + } else if (prb_.scale_type == scale_type_t::MANY) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + mulss(Xmm(ur), s_addr(s_off[ur])); + } + } + + /* dst <-- beta * dst + xmm[0] */ + assert(prb_.beta == 0.f || prb_.beta == 1.f); + if (prb_.beta == 1.f) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype == f32) { + addss(Xmm(ur), o_addr(o_off[ur])); + } else { + if (prb_.otype == s32) { + vmovss(xmm_tmp, o_addr(o_off[ur])); + } else if (utils::one_of(prb_.otype, s8, u8)) { + pinsrb(xmm_tmp, o_addr(o_off[ur]), 0x0); + } else { + assert(!"unsupported o_type"); + } + cvt2ps(xmm_tmp, xmm_tmp, prb_.otype); + addps(Xmm(ur), xmm_tmp); + } + } + } + } + + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype != f32) + cvt2int(Xmm(ur), prb_.otype, interim_f32 ? f32 : prb_.itype); + store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz); + } + } + + void process_unroll_generic(int len) { + const int blk = 8; + + int i_off[2 * blk] = {0}; + int o_off[2 * blk] = {0}; + int s_off[2 * blk] = {0}; + + int curr = 0; // will switch between 0 and 1 + + for (int off = 0; off < len; off += blk) { + const int reg_unroll = nstl::min(off + blk, len) - off; + + /* compute offsets */ + for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { + const int ur_c = curr * blk + ur; + const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur + step(off + ur, + i_off[ur_p], o_off[ur_p], s_off[ur_p], + i_off[ur_c], o_off[ur_c], s_off[ur_c]); + } + + process_unroll_generic_step(reg_unroll, i_off + curr * blk, + o_off + curr * blk, s_off + curr * blk); + + curr = 1 - curr; + } + } + + void loop_begin(Label &l, Reg64 reg_cnt, int len) { + mov(reg_cnt, len); + L(l); + } + + void loop_end(Label &l, Reg64 reg_cnt, int len, + int i_step, int o_step, int s_step) { + add(reg_off_in, i_step * itype_sz); + add(reg_off_out, o_step * otype_sz); + if (prb_.scale_type == scale_type_t::MANY) + add(reg_off_scale, s_step * stype_sz); + dec(reg_cnt); + jnz(l); + + sub(reg_off_in, len * i_step * itype_sz); + sub(reg_off_out, len * o_step * otype_sz); + if (prb_.scale_type == scale_type_t::MANY) + sub(reg_off_scale, len * s_step * stype_sz); + } + + bool simple_impl() { + simple_impl_desc_t d; + if (!simple_impl_desc_init(prb_, &d)) return false; + + const int nfu = d.ndims_full_unroll; + const int ldu = d.len_last_dim_unroll; + const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; + assert(n_jit_loops <= ndims_jit_loop_max); + + xor_(reg_off_in, reg_off_in); + xor_(reg_off_out, reg_off_out); + if (prb_.scale_type == scale_type_t::MANY) + xor_(reg_off_scale, reg_off_scale); + + Label l_loop[3]; + Reg64 reg_cnt[3] = {r15, r14, r13}; + + if (n_jit_loops > 2) + loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2)); + + if (n_jit_loops > 1) + loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1)); + + if (n_jit_loops > 0) + loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu); + + const bool optimized = false + || process_direct_copy(d.len_unroll) + || process_direct_copy(d.len_unroll) + || process_unroll_tr8x8(d.len_unroll); + if (!optimized) + process_unroll_generic(d.len_unroll); + + if (n_jit_loops > 0) + loop_end(l_loop[0], reg_cnt[0], + n(nfu + 0) / ldu, is(nfu + 0) * ldu, os(nfu + 0) * ldu, + ss(nfu + 0) * ldu); + + if (n_jit_loops > 1) + loop_end(l_loop[1], reg_cnt[1], + n(nfu + 1), is(nfu + 1), os(nfu + 1), ss(nfu + 1)); + + if (n_jit_loops > 2) + loop_end(l_loop[2], reg_cnt[2], + n(nfu + 2), is(nfu + 2), os(nfu + 2), ss(nfu + 2)); + + return true; + } + + void impl() { + if (simple_impl()) return; + assert(!"no implementation available"); + } + + jit_uni_reorder_kernel_f32(const desc_t &desc) + : kernel_t(desc), jit_generator() { + itype_sz = data_type_size(prb_.itype); + otype_sz = data_type_size(prb_.otype); + stype_sz = sizeof(float); + + preamble(); +# define PARAM(x) ptr[abi_param1 + offsetof(call_param_t, x)] + if (prb_.scale_type == scale_type_t::COMMON) { + auto reg_ptr_scale_tmp = reg_ptr_in; + mov(reg_ptr_scale_tmp, PARAM(scale)); + movups(xmm_scale, ptr[reg_ptr_scale_tmp]); + } else if (prb_.scale_type == scale_type_t::MANY) { + mov(reg_ptr_scale, PARAM(scale)); + } + mov(reg_ptr_in, PARAM(in)); + mov(reg_ptr_out, PARAM(out)); +# undef PARAM + + if (mayiuse(avx)) { + vxorps(xmm_zero, xmm_zero, xmm_zero); + + if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { + mov(reg_tmp.cvt32(), 0x7f7f7f7f); + movd(xmm_4x127b, reg_tmp.cvt32()); + } + } + + impl(); + postamble(); + ker_ = (void (*)(const call_param_t *))getCode(); + } + +private: + int itype_sz; + int otype_sz; + int stype_sz; + + Reg64 reg_ptr_in = rsi; + Reg64 reg_ptr_out = rdx; + Reg64 reg_ptr_scale = abi_not_param1; + + Reg64 reg_off_in = r8; + Reg64 reg_off_out = r9; + Reg64 reg_off_scale = r10; + + Reg64 reg_tmp = rax; + + Xmm xmm_scale = xmm15; + Xmm xmm_zero = xmm14; + Xmm xmm_4x127b = xmm13; // TODO: unite with xmm_zero + Xmm xmm_tmp = xmm12; +}; + +status_t kernel_t::desc_init(kernel_t::desc_t &desc, const prb_t &prb, + int ndims_ker_max) { + desc.prb = prb; + desc.prb.ioff = desc.prb.ooff = 0; + + if (ndims_ker_max > prb.ndims) + return status::invalid_arguments; + + auto ndims_ker_max_f = [&]() { + size_t cur_size = 1; + for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n) + if (cur_size >= ker_prb_size_min) return d; + return prb.ndims; + }; + + if (ndims_ker_max <= 0) + ndims_ker_max = ndims_ker_max_f(); + + /* traverse through kernel implementations */ + /* TODO: find a better way to do that... */ + desc.id = 0; + for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) { + desc.prb.ndims = ndims_ker; + if (jit_uni_reorder_kernel_f32::applicable(desc.prb)) + return status::success; + } + + return status::unimplemented; +} + +kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { + switch (desc.id) { + case 0: return new jit_uni_reorder_kernel_f32(desc); + default: assert(!"unknown kernel id"); return nullptr; + } + + return nullptr; +} + +} + +static void prb_block_for_cache(tr::prb_t &prb) { + if (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) { + /** an attempt to use caches more efficient and + * address the 4K-aliasing issue */ + /* TODO: improve the logic around here */ + int j = 1; + for (; j < prb.ndims && prb.nodes[j].is != 1; ++j); + if (j == prb.ndims) return; + + /* it makes sense to re-prioritize sequential read over + * sequential write if the former would not trash the + * cache, i.e. is == 1 and os % 2^smth != 0. Smth is + * set to 2 at the moment */ + const int move_to = prb.nodes[j].os % 4 != 0 ? 0 : 1; + if (j == move_to) return; + + if (prb.nodes[j].n > 16 && prb.nodes[j].n % 16 == 0) + prb_node_split(prb, j, 16); + + prb_node_move(prb, j, move_to); + DEBUG({ printf("cache: "); prb_dump(prb); }); + } +} + +/** finds the maximum number of dimension the kernel should process and + * optionally splits one of the dimension to achieve better balance between + * parallel driver and the kernel. */ +static void prb_thread_kernel_balance(tr::prb_t &prb, int &ndims_ker_max) { + size_t sz_total = 1; + for (int d = 0; d < prb.ndims; ++d) + sz_total *= prb.nodes[d].n; + + /* sz_drv_min is the minimal size for the parallel + * driver required for good parallelization */ + const size_t sz_drv_min = nstl::min( + 16 * mkldnn_get_max_threads(), + utils::div_up(sz_total, 1024)); + + /* kdims -- # of dimensions processed by a kernel + * sz_ker_cur -- product of the dimension processed by a kernel + * sz_drv_cur -- product of the dimension processed by a driver */ + + int kdims = prb.ndims; + size_t sz_drv_cur = 1; + for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims) + sz_drv_cur *= prb.nodes[kdims - 1].n; + + size_t sz_ker_cur = 1; + for (int d = 0; d < kdims; ++d) + sz_ker_cur *= prb.nodes[d].n; + + /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min. + * + * It might happen that for chosen kdims the sz_ker_cur is too small + * (less than tr::ker_prb_size_min). In that case try to split the + * innermost driver dimension into two, to increase sz_ker_cur. */ + bool want_borrow_ker_from_drv = true + && kdims < prb.ndims + && sz_ker_cur < tr::ker_prb_size_min + && sz_drv_cur > sz_drv_min; + if (want_borrow_ker_from_drv) { + /* sz_want_borrow is the minimal sz, so that: + * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min + * o) current innermost driver dimension is divisible by + * sz_want_borrow (so that we can evenly split that + * dimension into two) + * + * In the worst case the minimal sz_want_borrow is equal + * to the innermost driver dimension itself. In that case + * we will sacrifice it in favor of kernel (is it fine?). */ + size_t sz_want_borrow + = utils::div_up(tr::ker_prb_size_min, sz_ker_cur); + for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow); + if (sz_want_borrow != prb.nodes[kdims].n) + prb_node_split(prb, kdims, sz_want_borrow); + kdims += 1; + } + + /* On the other hand it might happen that for chosen kdims + * the sz_drv_cur is too small (less than sz_drv_min). In that case + * try to split the outermost kernel dimension into two, to increase + * sz_drv_cur. */ + bool want_borrow_drv_from_ker = true + && sz_ker_cur > tr::ker_prb_size_min + && sz_drv_cur < sz_drv_min; + if (want_borrow_drv_from_ker) { + size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur); + for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow); + if (sz_want_borrow != prb.nodes[kdims - 1].n) + prb_node_split(prb, kdims - 1, + prb.nodes[kdims - 1].n / sz_want_borrow); + } + + ndims_ker_max = kdims; + + if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) { + DEBUG({ printf("split: "); prb_dump(prb); + printf("ndims_ker_max = %d\n", ndims_ker_max); }); + } +} + +struct jit_uni_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + auto prb = tr::prb_t(); + + status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); + if (prb_init_status != status::success) return prb_init_status; + + DEBUG({ printf("init : "); prb_dump(prb); }); + prb_normalize(prb); + DEBUG({ printf("norm : "); prb_dump(prb); }); + prb_simplify(prb); + DEBUG({ printf("smpl : "); prb_dump(prb); }); + + prb_block_for_cache(prb); + + int ndims_ker_max; + prb_thread_kernel_balance(prb, ndims_ker_max); + + tr::kernel_t::desc_t ker_desc; + status_t ker_init_status + = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max); + if (ker_init_status != status::success) return ker_init_status; + + const int ndims_driver = prb.ndims - ker_desc.prb.ndims; + if (ndims_driver > jit_uni_reorder_t::ndims_driver_max) + return status::unimplemented; + + DEBUG({ printf("ker : "); prb_dump(ker_desc.prb); }); + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + _pd->prb_ = prb; + _pd->ker_desc_ = ker_desc; + return safe_ptr_assign(*reorder_pd, _pd); + } + + tr::prb_t prb_; + tr::kernel_t::desc_t ker_desc_; + }; + + jit_uni_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { + kernel_ = tr::kernel_t::create(pd()->ker_desc_); + assert(kernel_); + } + ~jit_uni_reorder_t() { delete kernel_; } + + void omp_driver_0d(int off, const char *in, char *out, + const float *scale) const { + tr::call_param_t c{in, out, scale}; + (*kernel_)(&c); + } + + void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype); + c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is) + * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os) + * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, + (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is) + * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) + * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, + (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is + + d3 * ns[3].is) * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os + + d3 * ns[3].os) * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss + + d3 * ns[3].ss; + (*kernel_)(&c); + }); + } + + void omp_driver(const char *in, char *out, const float *scale) const { + in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype); + out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); + + DEBUG({ printf("prb : "); tr::prb_dump(pd()->prb_); }); + DEBUG({ printf("ker : "); tr::prb_dump(pd()->ker_desc_.prb); }); + + int ndims = pd()->prb_.ndims; + int ndims_ker = pd()->ker_desc_.prb.ndims; + assert(ndims - ndims_ker <= ndims_driver_max); + + if (ndims - ndims_ker == 0) { + omp_driver_0d(ndims_ker, in, out, scale); + } else { + parallel(0, [&](const int ithr, const int nthr) { + switch (ndims - ndims_ker) { + case 1: omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); break; + case 2: omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); break; + case 3: omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); break; + case 4: omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); break; + default: assert(!"unimplemented"); + } + }); + } + } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto in = CTX_IN_MEM(const char *, MKLDNN_ARG_FROM); + auto out = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); + + omp_driver(in, out, pd()->attr()->output_scales_.scales_); + + return status::success; + } + + enum { ndims_driver_max = 4 }; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + tr::kernel_t *kernel_; +}; + +status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + return jit_uni_reorder_t::pd_t::create(reorder_pd, engine, attr, + src_engine, src_md, dst_engine, dst_md); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp new file mode 100644 index 0000000000..0746ea61d3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp @@ -0,0 +1,127 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef _JIT_UNI_REORDER_HPP +#define _JIT_UNI_REORDER_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +constexpr int max_ndims = MKLDNN_MAX_NDIMS; + +struct node_t { + size_t n; + ptrdiff_t is; // input stride + ptrdiff_t os; // output stride + ptrdiff_t ss; // scale stride +}; + +enum class scale_type_t { NONE, COMMON, MANY }; + +struct prb_t { + data_type_t itype; + data_type_t otype; + int ndims; + node_t nodes[max_ndims]; + ptrdiff_t ioff; + ptrdiff_t ooff; + scale_type_t scale_type; + float beta; +}; + +status_t prb_init(prb_t &prb, const memory_desc_t &imd, + const memory_desc_t &omd, const primitive_attr_t *attr); + +/** sorts the problem nodes so that output strides come in ascending order */ +void prb_normalize(prb_t &p); + +/** folds nodes together if possible */ +void prb_simplify(prb_t &p); + +/** splits the node dim into two of sizes n1 and n / n1 + * @warning n must be multiple of n1 */ +void prb_node_split(prb_t &p, int dim, size_t n1); + +/** swaps d0 and d1 nodes */ +void prb_node_swap(prb_t &p, int d0, int d1); + +/** moves node d0 to the d1 position. + * nodes (d0, d1] are shifted to the left if d0 < d1 or + * to the right if d0 > d1 */ +void prb_node_move(prb_t &p, int d0, int d1); + +/** dumps the problem to stdout */ +void prb_dump(const prb_t &p); + +struct call_param_t { + const void *in; + void *out; + const float *scale; +}; + +struct kernel_t { + struct desc_t { + int id; + prb_t prb; + }; + + kernel_t(const desc_t &desc): desc_(desc), ker_(nullptr) {} + void operator()(const call_param_t *c) const { assert(ker_); ker_(c); } + virtual ~kernel_t() {} + + /** inits kernel descriptor: + * desc -- kernel descriptor (output) + * prb -- transposition problem (input) + * ndims_ker_max -- limit the maximum number of dimensions kernel + * will process (optional, 0 -- no limitation) */ + static status_t desc_init(desc_t &desc, const prb_t &prb, + int ndims_ker_max = 0); + + /** creates kernel for the problem described in desc */ + static kernel_t *create(const desc_t &desc); + +protected: + const desc_t desc_; + const prb_t &prb_ = desc_.prb; + void (*ker_)(const call_param_t *); +}; + +/* TODO: add trans_t class */ + +} + +/* for cpu reorder list */ +status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp new file mode 100644 index 0000000000..69b7a33604 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp @@ -0,0 +1,313 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "mkldnn_debug.h" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_uni_reorder.hpp" + +using namespace mkldnn::impl::types; +using namespace mkldnn::impl::status; + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +/** ad-hoc structure to describe blocked memory layout */ +struct layout_desc_t { + data_type_t dt; + int ndims; + dims_t id; + dims_t dims; + strides_t strides; +}; + +status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, + layout_desc_t &ld) { + const auto md = memory_desc_wrapper(md_); + + bool ok = true + && md.is_blocking_desc() + && md.extra().flags == 0; + if (!ok) return invalid_arguments; + + const auto &bd = md.blocking_desc(); + + ld.ndims = 0; + ld.dt = md.data_type(); + + auto P = [&ld](int id, int dim, ptrdiff_t stride) { + assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0])); + ld.id[ld.ndims] = id; + ld.dims[ld.ndims] = dim; + ld.strides[ld.ndims] = stride; + ++ld.ndims; + }; + + dims_t blocks; + md.compute_blocks(blocks); + + for (int d = 0; d < md.ndims(); ++d) { + const int ld_ndims_start = ld.ndims; + if (blocks[d] != 1) { + stride_t stride = 1; + for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) { + if (bd.inner_idxs[iblk] == d) + P(d, bd.inner_blks[iblk], stride); + stride *= bd.inner_blks[iblk]; + } + } + P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]); + + // TODO: NOW: revisit, do we need a reverse? + // TODO: NOW: consider using strides instead of block sizes in md + // reverse the order of dims + for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) { + const int idx0 = ld_ndims_start + ld_d; + const int idx1 = ld.ndims - 1 - ld_d; + nstl::swap(ld.dims[idx0], ld.dims[idx1]); + nstl::swap(ld.strides[idx0], ld.strides[idx1]); + } + } + + return success; +} + +status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + const primitive_attr_t *attr) { + auto im_d = memory_desc_wrapper(imd); + auto om_d = memory_desc_wrapper(omd); + + bool ok = true + && im_d.is_blocking_desc() + && om_d.is_blocking_desc() + && !im_d.has_zero_dim() + && !om_d.has_zero_dim(); + if (!ok) + return unimplemented; + + dims_t iblocks, oblocks; + im_d.compute_blocks(iblocks); + om_d.compute_blocks(oblocks); + + /* padding_dim consistency check */ + for (int d = 0; d < im_d.ndims(); ++d) { + const auto pdim = im_d.padded_dims()[d]; + bool ok = true + && pdim == om_d.padded_dims()[d] + && pdim % iblocks[d] == 0 + && pdim % oblocks[d] == 0; + if (!ok) return unimplemented; + } + + layout_desc_t ild, old; + status_t status = cvt_mem_desc_to_layout_desc(imd, ild); + if (status != success) return status; + status = cvt_mem_desc_to_layout_desc(omd, old); + if (status != success) return status; + + p.itype = ild.dt; + p.otype = old.dt; + + p.scale_type = attr->output_scales_.has_default_values() + ? scale_type_t::NONE + : (attr->output_scales_.mask_ == 0 + ? scale_type_t::COMMON + : scale_type_t::MANY); + + ptrdiff_t ss[max_ndims] = {0}; + if (p.scale_type == scale_type_t::MANY) { + ptrdiff_t last_ss = 1; + for (int d = old.ndims - 1; d >=0; --d) { + assert((d == 0 || old.id[d - 1] <= old.id[d]) + && "logical dimensions should be in ascending order"); + if (attr->output_scales_.mask_ & (1 << old.id[d])) { + ss[d] = last_ss; + last_ss *= old.dims[d]; + } + } + } + + int ndims = 0; + + int i_pos = 0; /* state for input -- current dimension */ + int o_pos = 0; /* state for output -- current dimension */ + + while (i_pos < ild.ndims && o_pos < old.ndims) { + assert(ild.id[i_pos] == old.id[o_pos]); + if (ild.id[i_pos] != old.id[o_pos]) + return runtime_error; + + assert(ndims < max_ndims); + if (ndims == max_ndims) + return runtime_error; + + if (ild.dims[i_pos] == old.dims[o_pos]) { + p.nodes[ndims].n = ild.dims[i_pos]; + p.nodes[ndims].is = ild.strides[i_pos]; + p.nodes[ndims].os = old.strides[o_pos]; + p.nodes[ndims].ss = ss[o_pos]; + ++ndims; + ++i_pos; + ++o_pos; + } else if (ild.dims[i_pos] < old.dims[o_pos]) { + assert(old.dims[o_pos] % ild.dims[i_pos] == 0); + int factor = old.dims[o_pos] / ild.dims[i_pos]; + p.nodes[ndims].n = ild.dims[i_pos]; + p.nodes[ndims].is = ild.strides[i_pos]; + p.nodes[ndims].os = old.strides[o_pos] * factor; + p.nodes[ndims].ss = ss[o_pos] * factor; + ++ndims; + ++i_pos; + old.dims[o_pos] = factor; + } else if (ild.dims[i_pos] > old.dims[o_pos]) { + assert(ild.dims[i_pos] % old.dims[o_pos] == 0); + int factor = ild.dims[i_pos] / old.dims[o_pos]; + p.nodes[ndims].n = old.dims[o_pos]; + p.nodes[ndims].is = ild.strides[i_pos] * factor; + p.nodes[ndims].os = old.strides[o_pos]; + p.nodes[ndims].ss = ss[o_pos]; + ++ndims; + ++o_pos; + ild.dims[i_pos] = factor; + } + } + p.ndims = ndims; + + dims_t zero_pos = {0}; + p.ioff = memory_desc_wrapper(imd).off_v(zero_pos); + p.ooff = memory_desc_wrapper(omd).off_v(zero_pos); + + const int sum_idx = attr->post_ops_.find(primitive_kind::sum); + p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale; + + return success; +} + +void prb_normalize(prb_t &p) { + for (int d = 0; d < p.ndims; ++d) { + int min_pos = d; + for (int j = d + 1; j < p.ndims; ++j) { + bool new_min = false + || p.nodes[j].os < p.nodes[min_pos].os + || (true + && p.nodes[j].os == p.nodes[min_pos].os + && p.nodes[j].n < p.nodes[min_pos].n); + if (new_min) min_pos = j; + } + if (min_pos != d) + nstl::swap(p.nodes[d], p.nodes[min_pos]); + } +} + +void prb_simplify(prb_t &p) { +#if defined(__GNUC__) && __GNUC__ >= 4 +/* GCC produces bogus array subscript is above array bounds warning for + * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Warray-bounds" +#endif + for (int d = 0; d < p.ndims - 1; ++d) { + auto &this_node = p.nodes[d + 0]; + auto &next_node = p.nodes[d + 1]; + const bool fold = false + || next_node.n == (size_t)1 // trivial case, just drop next node + || (true // or real folding if possible + && next_node.is == (ptrdiff_t)this_node.n * this_node.is + && next_node.os == (ptrdiff_t)this_node.n * this_node.os + && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss); + if (fold) { + this_node.n *= next_node.n; + for (int j = d + 2; j < p.ndims; ++j) + p.nodes[j - 1] = p.nodes[j]; + --p.ndims; + --d; // make another try + } + } +#if defined(__GNUC__) && __GNUC__ >= 4 +#pragma GCC diagnostic pop +#endif +} + +void prb_node_split(prb_t &p, int dim, size_t n1) { + assert(dim < p.ndims); + assert(p.ndims < max_ndims); + assert(p.nodes[dim].n % n1 == 0); + + p.ndims += 1; + + for (int d = p.ndims; d > dim + 1; --d) + p.nodes[d] = p.nodes[d - 1]; + + p.nodes[dim + 1].n = p.nodes[dim].n / n1; + p.nodes[dim + 1].is = p.nodes[dim].is * n1; + p.nodes[dim + 1].os = p.nodes[dim].os * n1; + p.nodes[dim + 1].ss = p.nodes[dim].ss * n1; + + p.nodes[dim].n = n1; +} + +void prb_node_swap(prb_t &p, int d0, int d1) { + assert(d0 < p.ndims); + assert(d1 < p.ndims); + assert(p.ndims < max_ndims); + + if (d0 == d1) return; + + nstl::swap(p.nodes[d0], p.nodes[d1]); +} + +void prb_node_move(prb_t &p, int d0, int d1) { + assert(d0 < p.ndims); + assert(d1 < p.ndims); + assert(p.ndims < max_ndims); + + if (d0 == d1) return; + + node_t node = p.nodes[d0]; + + if (d0 < d1) + for (int d = d0; d < d1; ++d) + p.nodes[d] = p.nodes[d + 1]; + else + for (int d = d0; d > d1; --d) + p.nodes[d] = p.nodes[d - 1]; + + p.nodes[d1] = node; +} + +void prb_dump(const prb_t &p) { + printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype), + mkldnn_dt2str(p.otype), p.ndims); + for (int d = 0; d < p.ndims; ++d) + printf("[%zu:%td:%td:%td]", + p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss); + printf(" off:%zu:%zu\n", p.ioff, p.ooff); +} + +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp new file mode 100644 index 0000000000..08747aa89c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "utils.hpp" + +#ifndef MKLDNN_ENABLE_JIT_PROFILING +#define MKLDNN_ENABLE_JIT_PROFILING 1 +#endif + +#ifndef MKLDNN_ENABLE_JIT_DUMP +#define MKLDNN_ENABLE_JIT_DUMP 1 +#endif + +#if MKLDNN_ENABLE_JIT_PROFILING +#include "jitprofiling/jitprofiling.h" +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace jit_utils { + +// WARNING: These functions are not thread safe and must be protected by a +// mutex + +void dump_jit_code(const void *code, size_t code_size, const char *code_name) +{ +#if MKLDNN_ENABLE_JIT_DUMP + if (code && jit_dump_enabled()) { + static int counter = 0; +#define MAX_FNAME_LEN 256 + char fname[MAX_FNAME_LEN + 1]; + // TODO (Roma): support prefix for code / linux perf dumps + snprintf(fname, MAX_FNAME_LEN, "mkldnn_dump_%s.%d.bin", code_name, + counter); + counter++; + + FILE *fp = fopen(fname, "w+"); + // Failure to dump code is not fatal + if (fp) { + size_t unused = fwrite(code, code_size, 1, fp); + UNUSED(unused); + fclose(fp); + } + } +#undef MAX_FNAME_LEN +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); +#endif +} + +void register_jit_code_vtune(const void *code, size_t code_size, + const char *code_name, const char *source_file_name) +{ +#if MKLDNN_ENABLE_JIT_PROFILING + if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) { + auto jmethod = iJIT_Method_Load(); + jmethod.method_id = iJIT_GetNewMethodID(); // XXX: not thread-safe + jmethod.method_name = (char *)code_name; // XXX: dropping const + jmethod.class_file_name = NULL; + jmethod.source_file_name = (char *)source_file_name; // XXX: dropping const + jmethod.method_load_address = (void *)code; + jmethod.method_size = (unsigned int)code_size; + + iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, + (void*)&jmethod); + } +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); + UNUSED(source_file_name); +#endif +} + +void register_jit_code(const void *code, size_t code_size, + const char *code_name, const char *source_file_name) +{ + // The #ifdef guards are required to avoid generating a function that only + // consists of lock and unlock code +#if MKLDNN_ENABLE_JIT_PROFILING || MKLDNN_ENABLE_JIT_DUMP + static std::mutex m; + std::lock_guard guard(m); + + dump_jit_code(code, code_size, code_name); + register_jit_code_vtune(code, code_size, code_name, source_file_name); +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); + UNUSED(source_file_name); +#endif +} + +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp new file mode 100644 index 0000000000..2f52dba4ac --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_SUPPORT_HPP +#define JIT_SUPPORT_HPP + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace jit_utils { + +void register_jit_code(const void *code, size_t code_size, + const char *code_name, const char *source_file_name); + +} +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD new file mode 100644 index 0000000000..4fd21cea57 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD @@ -0,0 +1,27 @@ +Copyright (c) 2011, Intel Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md new file mode 100644 index 0000000000..fc67c4f134 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md @@ -0,0 +1 @@ +This code is from [Intel SEAPI library](https://github.com/intel/IntelSEAPI) diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h new file mode 100644 index 0000000000..edbf4a15f0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h @@ -0,0 +1,595 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +#ifndef _ITTNOTIFY_CONFIG_H_ +#define _ITTNOTIFY_CONFIG_H_ + +/** @cond exclude_from_documentation */ +#ifndef ITT_OS_WIN +# define ITT_OS_WIN 1 +#endif /* ITT_OS_WIN */ + +#ifndef ITT_OS_LINUX +# define ITT_OS_LINUX 2 +#endif /* ITT_OS_LINUX */ + +#ifndef ITT_OS_MAC +# define ITT_OS_MAC 3 +#endif /* ITT_OS_MAC */ + +#ifndef ITT_OS_FREEBSD +# define ITT_OS_FREEBSD 4 +#endif /* ITT_OS_FREEBSD */ + +#ifndef ITT_OS +# if defined WIN32 || defined _WIN32 +# define ITT_OS ITT_OS_WIN +# elif defined( __APPLE__ ) && defined( __MACH__ ) +# define ITT_OS ITT_OS_MAC +# elif defined( __FreeBSD__ ) +# define ITT_OS ITT_OS_FREEBSD +# else +# define ITT_OS ITT_OS_LINUX +# endif +#endif /* ITT_OS */ + +#ifndef ITT_PLATFORM_WIN +# define ITT_PLATFORM_WIN 1 +#endif /* ITT_PLATFORM_WIN */ + +#ifndef ITT_PLATFORM_POSIX +# define ITT_PLATFORM_POSIX 2 +#endif /* ITT_PLATFORM_POSIX */ + +#ifndef ITT_PLATFORM_MAC +# define ITT_PLATFORM_MAC 3 +#endif /* ITT_PLATFORM_MAC */ + +#ifndef ITT_PLATFORM_FREEBSD +# define ITT_PLATFORM_FREEBSD 4 +#endif /* ITT_PLATFORM_FREEBSD */ + +#ifndef ITT_PLATFORM +# if ITT_OS==ITT_OS_WIN +# define ITT_PLATFORM ITT_PLATFORM_WIN +# elif ITT_OS==ITT_OS_MAC +# define ITT_PLATFORM ITT_PLATFORM_MAC +# elif ITT_OS==ITT_OS_FREEBSD +# define ITT_PLATFORM ITT_PLATFORM_FREEBSD +# else +# define ITT_PLATFORM ITT_PLATFORM_POSIX +# endif +#endif /* ITT_PLATFORM */ + +#if defined(_UNICODE) && !defined(UNICODE) +#define UNICODE +#endif + +#include +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#include +#if defined(UNICODE) || defined(_UNICODE) +#include +#endif /* UNICODE || _UNICODE */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#ifndef ITTAPI_CDECL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define ITTAPI_CDECL __cdecl +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define ITTAPI_CDECL __attribute__ ((cdecl)) +# else /* _M_IX86 || __i386__ */ +# define ITTAPI_CDECL /* actual only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* ITTAPI_CDECL */ + +#ifndef STDCALL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define STDCALL __stdcall +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define STDCALL __attribute__ ((stdcall)) +# else /* _M_IX86 || __i386__ */ +# define STDCALL /* supported only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* STDCALL */ + +#define ITTAPI ITTAPI_CDECL +#define LIBITTAPI ITTAPI_CDECL + +/* TODO: Temporary for compatibility! */ +#define ITTAPI_CALL ITTAPI_CDECL +#define LIBITTAPI_CALL ITTAPI_CDECL + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +/* use __forceinline (VC++ specific) */ +#define ITT_INLINE __forceinline +#define ITT_INLINE_ATTRIBUTE /* nothing */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/* + * Generally, functions are not inlined unless optimization is specified. + * For functions declared inline, this attribute inlines the function even + * if no optimization level was specified. + */ +#ifdef __STRICT_ANSI__ +#define ITT_INLINE static +#define ITT_INLINE_ATTRIBUTE __attribute__((unused)) +#else /* __STRICT_ANSI__ */ +#define ITT_INLINE static inline +#define ITT_INLINE_ATTRIBUTE __attribute__((always_inline, unused)) +#endif /* __STRICT_ANSI__ */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/** @endcond */ + +#ifndef ITT_ARCH_IA32 +# define ITT_ARCH_IA32 1 +#endif /* ITT_ARCH_IA32 */ + +#ifndef ITT_ARCH_IA32E +# define ITT_ARCH_IA32E 2 +#endif /* ITT_ARCH_IA32E */ + +#ifndef ITT_ARCH_ARM +# define ITT_ARCH_ARM 4 +#endif /* ITT_ARCH_ARM */ + +#ifndef ITT_ARCH_PPC64 +# define ITT_ARCH_PPC64 5 +#endif /* ITT_ARCH_PPC64 */ + +#ifndef ITT_ARCH +# if defined _M_IX86 || defined __i386__ +# define ITT_ARCH ITT_ARCH_IA32 +# elif defined _M_X64 || defined _M_AMD64 || defined __x86_64__ +# define ITT_ARCH ITT_ARCH_IA32E +# elif defined _M_IA64 || defined __ia64__ +# define ITT_ARCH ITT_ARCH_IA64 +# elif defined _M_ARM || defined __arm__ +# define ITT_ARCH ITT_ARCH_ARM +# elif defined __powerpc64__ +# define ITT_ARCH ITT_ARCH_PPC64 +# endif +#endif + +#ifdef __cplusplus +# define ITT_EXTERN_C extern "C" +# define ITT_EXTERN_C_BEGIN extern "C" { +# define ITT_EXTERN_C_END } +#else +# define ITT_EXTERN_C /* nothing */ +# define ITT_EXTERN_C_BEGIN /* nothing */ +# define ITT_EXTERN_C_END /* nothing */ +#endif /* __cplusplus */ + +#define ITT_TO_STR_AUX(x) #x +#define ITT_TO_STR(x) ITT_TO_STR_AUX(x) + +#define __ITT_BUILD_ASSERT(expr, suffix) do { \ + static char __itt_build_check_##suffix[(expr) ? 1 : -1]; \ + __itt_build_check_##suffix[0] = 0; \ +} while(0) +#define _ITT_BUILD_ASSERT(expr, suffix) __ITT_BUILD_ASSERT((expr), suffix) +#define ITT_BUILD_ASSERT(expr) _ITT_BUILD_ASSERT((expr), __LINE__) + +#define ITT_MAGIC { 0xED, 0xAB, 0xAB, 0xEC, 0x0D, 0xEE, 0xDA, 0x30 } + +/* Replace with snapshot date YYYYMMDD for promotion build. */ +#define API_VERSION_BUILD 20151119 + +#ifndef API_VERSION_NUM +#define API_VERSION_NUM 0.0.0 +#endif /* API_VERSION_NUM */ + +#define API_VERSION "ITT-API-Version " ITT_TO_STR(API_VERSION_NUM) \ + " (" ITT_TO_STR(API_VERSION_BUILD) ")" + +/* OS communication functions */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include +typedef HMODULE lib_t; +typedef DWORD TIDT; +typedef CRITICAL_SECTION mutex_t; +#define MUTEX_INITIALIZER { 0 } +#define strong_alias(name, aliasname) /* empty for Windows */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#include +#if defined(UNICODE) || defined(_UNICODE) +#include +#endif /* UNICODE */ +#ifndef _GNU_SOURCE +#define _GNU_SOURCE 1 /* need for PTHREAD_MUTEX_RECURSIVE */ +#endif /* _GNU_SOURCE */ +#ifndef __USE_UNIX98 +#define __USE_UNIX98 1 /* need for PTHREAD_MUTEX_RECURSIVE, on SLES11.1 with gcc 4.3.4 wherein pthread.h missing dependency on __USE_XOPEN2K8 */ +#endif /*__USE_UNIX98*/ +#include +typedef void* lib_t; +typedef pthread_t TIDT; +typedef pthread_mutex_t mutex_t; +#define MUTEX_INITIALIZER PTHREAD_MUTEX_INITIALIZER +#define _strong_alias(name, aliasname) \ + extern __typeof (name) aliasname __attribute__ ((alias (#name))); +#define strong_alias(name, aliasname) _strong_alias(name, aliasname) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_get_proc(lib, name) GetProcAddress(lib, name) +#define __itt_mutex_init(mutex) InitializeCriticalSection(mutex) +#define __itt_mutex_lock(mutex) EnterCriticalSection(mutex) +#define __itt_mutex_unlock(mutex) LeaveCriticalSection(mutex) +#define __itt_load_lib(name) LoadLibraryA(name) +#define __itt_unload_lib(handle) FreeLibrary(handle) +#define __itt_system_error() (int)GetLastError() +#define __itt_fstrcmp(s1, s2) lstrcmpA(s1, s2) +#define __itt_fstrnlen(s, l) strnlen_s(s, l) +#define __itt_fstrcpyn(s1, b, s2, l) strncpy_s(s1, b, s2, l) +#define __itt_fstrdup(s) _strdup(s) +#define __itt_thread_id() GetCurrentThreadId() +#define __itt_thread_yield() SwitchToThread() +#ifndef ITT_SIMPLE_INIT +ITT_INLINE long +__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __itt_interlocked_increment(volatile long* ptr) +{ + return InterlockedIncrement(ptr); +} +#endif /* ITT_SIMPLE_INIT */ + +#define DL_SYMBOLS (1) +#define PTHREAD_SYMBOLS (1) + +#else /* ITT_PLATFORM!=ITT_PLATFORM_WIN */ +#define __itt_get_proc(lib, name) dlsym(lib, name) +#define __itt_mutex_init(mutex) {\ + pthread_mutexattr_t mutex_attr; \ + int error_code = pthread_mutexattr_init(&mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_init", \ + error_code); \ + error_code = pthread_mutexattr_settype(&mutex_attr, \ + PTHREAD_MUTEX_RECURSIVE); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_settype", \ + error_code); \ + error_code = pthread_mutex_init(mutex, &mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutex_init", \ + error_code); \ + error_code = pthread_mutexattr_destroy(&mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_destroy", \ + error_code); \ +} +#define __itt_mutex_lock(mutex) pthread_mutex_lock(mutex) +#define __itt_mutex_unlock(mutex) pthread_mutex_unlock(mutex) +#define __itt_load_lib(name) dlopen(name, RTLD_LAZY) +#define __itt_unload_lib(handle) dlclose(handle) +#define __itt_system_error() errno +#define __itt_fstrcmp(s1, s2) strcmp(s1, s2) + +/* makes customer code define safe APIs for SDL_STRNLEN_S and SDL_STRNCPY_S */ +#ifdef SDL_STRNLEN_S +#define __itt_fstrnlen(s, l) SDL_STRNLEN_S(s, l) +#else +#define __itt_fstrnlen(s, l) strlen(s) +#endif /* SDL_STRNLEN_S */ +#ifdef SDL_STRNCPY_S +#define __itt_fstrcpyn(s1, b, s2, l) SDL_STRNCPY_S(s1, b, s2, l) +#else +#define __itt_fstrcpyn(s1, b, s2, l) strncpy(s1, s2, l) +#endif /* SDL_STRNCPY_S */ + +#define __itt_fstrdup(s) strdup(s) +#define __itt_thread_id() pthread_self() +#define __itt_thread_yield() sched_yield() +#if ITT_ARCH==ITT_ARCH_IA64 +#ifdef __INTEL_COMPILER +#define __TBB_machine_fetchadd4(addr, val) __fetchadd4_acq((void *)addr, val) +#else /* __INTEL_COMPILER */ +/* TODO: Add Support for not Intel compilers for IA-64 architecture */ +#endif /* __INTEL_COMPILER */ +#elif ITT_ARCH==ITT_ARCH_IA32 || ITT_ARCH==ITT_ARCH_IA32E /* ITT_ARCH!=ITT_ARCH_IA64 */ +ITT_INLINE long +__TBB_machine_fetchadd4(volatile void* ptr, long addend) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __TBB_machine_fetchadd4(volatile void* ptr, long addend) +{ + long result; + __asm__ __volatile__("lock\nxadd %0,%1" + : "=r"(result),"=m"(*(int*)ptr) + : "0"(addend), "m"(*(int*)ptr) + : "memory"); + return result; +} +#elif ITT_ARCH==ITT_ARCH_ARM || ITT_ARCH==ITT_ARCH_PPC64 +#define __TBB_machine_fetchadd4(addr, val) __sync_fetch_and_add(addr, val) +#endif /* ITT_ARCH==ITT_ARCH_IA64 */ +#ifndef ITT_SIMPLE_INIT +ITT_INLINE long +__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __itt_interlocked_increment(volatile long* ptr) +{ + return __TBB_machine_fetchadd4(ptr, 1) + 1L; +} +#endif /* ITT_SIMPLE_INIT */ + +void* dlopen(const char*, int) __attribute__((weak)); +void* dlsym(void*, const char*) __attribute__((weak)); +int dlclose(void*) __attribute__((weak)); +#define DL_SYMBOLS (dlopen && dlsym && dlclose) + +int pthread_mutex_init(pthread_mutex_t*, const pthread_mutexattr_t*) __attribute__((weak)); +int pthread_mutex_lock(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutex_unlock(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutex_destroy(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutexattr_init(pthread_mutexattr_t*) __attribute__((weak)); +int pthread_mutexattr_settype(pthread_mutexattr_t*, int) __attribute__((weak)); +int pthread_mutexattr_destroy(pthread_mutexattr_t*) __attribute__((weak)); +pthread_t pthread_self(void) __attribute__((weak)); +#define PTHREAD_SYMBOLS (pthread_mutex_init && pthread_mutex_lock && pthread_mutex_unlock && pthread_mutex_destroy && pthread_mutexattr_init && pthread_mutexattr_settype && pthread_mutexattr_destroy && pthread_self) + +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +typedef enum { + __itt_collection_normal = 0, + __itt_collection_paused = 1 +} __itt_collection_state; + +typedef enum { + __itt_thread_normal = 0, + __itt_thread_ignored = 1 +} __itt_thread_state; + +#pragma pack(push, 8) + +typedef struct ___itt_thread_info +{ + const char* nameA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* nameW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* nameW; +#endif /* UNICODE || _UNICODE */ + TIDT tid; + __itt_thread_state state; /*!< Thread state (paused or normal) */ + int extra1; /*!< Reserved to the runtime */ + void* extra2; /*!< Reserved to the runtime */ + struct ___itt_thread_info* next; +} __itt_thread_info; + +#include "ittnotify_types.h" /* For __itt_group_id definition */ + +typedef struct ___itt_api_info_20101001 +{ + const char* name; + void** func_ptr; + void* init_func; + __itt_group_id group; +} __itt_api_info_20101001; + +typedef struct ___itt_api_info +{ + const char* name; + void** func_ptr; + void* init_func; + void* null_func; + __itt_group_id group; +} __itt_api_info; + +typedef struct __itt_counter_info +{ + const char* nameA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* nameW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* nameW; +#endif /* UNICODE || _UNICODE */ + const char* domainA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* domainW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* domainW; +#endif /* UNICODE || _UNICODE */ + int type; + long index; + int extra1; /*!< Reserved to the runtime */ + void* extra2; /*!< Reserved to the runtime */ + struct __itt_counter_info* next; +} __itt_counter_info_t; + +struct ___itt_domain; +struct ___itt_string_handle; + +typedef struct ___itt_global +{ + unsigned char magic[8]; + unsigned long version_major; + unsigned long version_minor; + unsigned long version_build; + volatile long api_initialized; + volatile long mutex_initialized; + volatile long atomic_counter; + mutex_t mutex; + lib_t lib; + void* error_handler; + const char** dll_path_ptr; + __itt_api_info* api_list_ptr; + struct ___itt_global* next; + /* Joinable structures below */ + __itt_thread_info* thread_list; + struct ___itt_domain* domain_list; + struct ___itt_string_handle* string_list; + __itt_collection_state state; + __itt_counter_info_t* counter_list; +} __itt_global; + +#pragma pack(pop) + +#define NEW_THREAD_INFO_W(gptr,h,h_tail,t,s,n) { \ + h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \ + if (h != NULL) { \ + h->tid = t; \ + h->nameA = NULL; \ + h->nameW = n ? _wcsdup(n) : NULL; \ + h->state = s; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->thread_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_THREAD_INFO_A(gptr,h,h_tail,t,s,n) { \ + h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \ + if (h != NULL) { \ + h->tid = t; \ + h->nameA = n ? __itt_fstrdup(n) : NULL; \ + h->nameW = NULL; \ + h->state = s; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->thread_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_DOMAIN_W(gptr,h,h_tail,name) { \ + h = (__itt_domain*)malloc(sizeof(__itt_domain)); \ + if (h != NULL) { \ + h->flags = 1; /* domain is enabled by default */ \ + h->nameA = NULL; \ + h->nameW = name ? _wcsdup(name) : NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->domain_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_DOMAIN_A(gptr,h,h_tail,name) { \ + h = (__itt_domain*)malloc(sizeof(__itt_domain)); \ + if (h != NULL) { \ + h->flags = 1; /* domain is enabled by default */ \ + h->nameA = name ? __itt_fstrdup(name) : NULL; \ + h->nameW = NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->domain_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_STRING_HANDLE_W(gptr,h,h_tail,name) { \ + h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \ + if (h != NULL) { \ + h->strA = NULL; \ + h->strW = name ? _wcsdup(name) : NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->string_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_STRING_HANDLE_A(gptr,h,h_tail,name) { \ + h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \ + if (h != NULL) { \ + h->strA = name ? __itt_fstrdup(name) : NULL; \ + h->strW = NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->string_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_COUNTER_W(gptr,h,h_tail,name,domain,type) { \ + h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \ + if (h != NULL) { \ + h->nameA = NULL; \ + h->nameW = name ? _wcsdup(name) : NULL; \ + h->domainA = NULL; \ + h->domainW = name ? _wcsdup(domain) : NULL; \ + h->type = type; \ + h->index = 0; \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->counter_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_COUNTER_A(gptr,h,h_tail,name,domain,type) { \ + h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \ + if (h != NULL) { \ + h->nameA = name ? __itt_fstrdup(name) : NULL; \ + h->nameW = NULL; \ + h->domainA = domain ? __itt_fstrdup(domain) : NULL; \ + h->domainW = NULL; \ + h->type = type; \ + h->index = 0; \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->counter_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#endif /* _ITTNOTIFY_CONFIG_H_ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h new file mode 100644 index 0000000000..99fbc24054 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h @@ -0,0 +1,94 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _ITTNOTIFY_TYPES_H_ +#define _ITTNOTIFY_TYPES_H_ + +typedef enum ___itt_group_id +{ + __itt_group_none = 0, + __itt_group_legacy = 1<<0, + __itt_group_control = 1<<1, + __itt_group_thread = 1<<2, + __itt_group_mark = 1<<3, + __itt_group_sync = 1<<4, + __itt_group_fsync = 1<<5, + __itt_group_jit = 1<<6, + __itt_group_model = 1<<7, + __itt_group_splitter_min = 1<<7, + __itt_group_counter = 1<<8, + __itt_group_frame = 1<<9, + __itt_group_stitch = 1<<10, + __itt_group_heap = 1<<11, + __itt_group_splitter_max = 1<<12, + __itt_group_structure = 1<<12, + __itt_group_suppress = 1<<13, + __itt_group_arrays = 1<<14, + __itt_group_all = -1 +} __itt_group_id; + +#pragma pack(push, 8) + +typedef struct ___itt_group_list +{ + __itt_group_id id; + const char* name; +} __itt_group_list; + +#pragma pack(pop) + +#define ITT_GROUP_LIST(varname) \ + static __itt_group_list varname[] = { \ + { __itt_group_all, "all" }, \ + { __itt_group_control, "control" }, \ + { __itt_group_thread, "thread" }, \ + { __itt_group_mark, "mark" }, \ + { __itt_group_sync, "sync" }, \ + { __itt_group_fsync, "fsync" }, \ + { __itt_group_jit, "jit" }, \ + { __itt_group_model, "model" }, \ + { __itt_group_counter, "counter" }, \ + { __itt_group_frame, "frame" }, \ + { __itt_group_stitch, "stitch" }, \ + { __itt_group_heap, "heap" }, \ + { __itt_group_structure, "structure" }, \ + { __itt_group_suppress, "suppress" }, \ + { __itt_group_arrays, "arrays" }, \ + { __itt_group_none, NULL } \ + } + +#endif /* _ITTNOTIFY_TYPES_H_ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c new file mode 100644 index 0000000000..15f4b9929b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c @@ -0,0 +1,293 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include "ittnotify_config.h" + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM != ITT_PLATFORM_MAC && ITT_PLATFORM != ITT_PLATFORM_FREEBSD +#include +#endif +#include + +#include "jitprofiling.h" + +static const char rcsid[] = "\n@(#) $Revision: 471937 $\n"; + +#define DLL_ENVIRONMENT_VAR "VS_PROFILER" + +#ifndef NEW_DLL_ENVIRONMENT_VAR +#if ITT_ARCH==ITT_ARCH_IA32 +#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER32" +#else +#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER64" +#endif +#endif /* NEW_DLL_ENVIRONMENT_VAR */ + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define DEFAULT_DLLNAME "JitPI.dll" +HINSTANCE m_libHandle = NULL; +#elif ITT_PLATFORM==ITT_PLATFORM_MAC +#define DEFAULT_DLLNAME "libJitPI.dylib" +void* m_libHandle = NULL; +#else +#define DEFAULT_DLLNAME "libJitPI.so" +void* m_libHandle = NULL; +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/* default location of JIT profiling agent on Android */ +#define ANDROID_JIT_AGENT_PATH "/data/intel/libittnotify.so" + +/* the function pointers */ +typedef unsigned int(JITAPI *TPInitialize)(void); +static TPInitialize FUNC_Initialize=NULL; + +typedef unsigned int(JITAPI *TPNotify)(unsigned int, void*); +static TPNotify FUNC_NotifyEvent=NULL; + +static iJIT_IsProfilingActiveFlags executionMode = iJIT_NOTHING_RUNNING; + +/* end collector dll part. */ + +/* loadiJIT_Funcs() : this function is called just in the beginning + * and is responsible to load the functions from BistroJavaCollector.dll + * result: + * on success: the functions loads, iJIT_DLL_is_missing=0, return value = 1 + * on failure: the functions are NULL, iJIT_DLL_is_missing=1, return value = 0 + */ +static int loadiJIT_Funcs(void); + +/* global representing whether the collector can't be loaded */ +static int iJIT_DLL_is_missing = 0; + +ITT_EXTERN_C int JITAPI +iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData) +{ + int ReturnValue = 0; + + /* initialization part - the collector has not been loaded yet. */ + if (!FUNC_NotifyEvent) + { + if (iJIT_DLL_is_missing) + return 0; + + if (!loadiJIT_Funcs()) + return 0; + } + + if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED || + event_type == iJVM_EVENT_TYPE_METHOD_UPDATE) + { + if (((piJIT_Method_Load)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2) + { + if (((piJIT_Method_Load_V2)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3) + { + if (((piJIT_Method_Load_V3)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED) + { + if (((piJIT_Method_Inline_Load)EventSpecificData)->method_id == 0 || + ((piJIT_Method_Inline_Load)EventSpecificData)->parent_method_id == 0) + return 0; + } + + ReturnValue = (int)FUNC_NotifyEvent(event_type, EventSpecificData); + + return ReturnValue; +} + +ITT_EXTERN_C iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive() +{ + if (!iJIT_DLL_is_missing) + { + loadiJIT_Funcs(); + } + + return executionMode; +} + +/* This function loads the collector dll and the relevant functions. + * on success: all functions load, iJIT_DLL_is_missing = 0, return value = 1 + * on failure: all functions are NULL, iJIT_DLL_is_missing = 1, return value = 0 + */ +static int loadiJIT_Funcs() +{ + static int bDllWasLoaded = 0; + char *dllName = (char*)rcsid; /* !! Just to avoid unused code elimination */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN + DWORD dNameLength = 0; +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + + if(bDllWasLoaded) + { + /* dll was already loaded, no need to do it for the second time */ + return 1; + } + + /* Assumes that the DLL will not be found */ + iJIT_DLL_is_missing = 1; + FUNC_NotifyEvent = NULL; + + if (m_libHandle) + { +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FreeLibrary(m_libHandle); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + dlclose(m_libHandle); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + m_libHandle = NULL; + } + + /* Try to get the dll name from the environment */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN + dNameLength = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, NULL, 0); + if (dNameLength) + { + DWORD envret = 0; + dllName = (char*)malloc(sizeof(char) * (dNameLength + 1)); + if(dllName != NULL) + { + envret = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, + dllName, dNameLength); + if (envret) + { + /* Try to load the dll from the PATH... */ + m_libHandle = LoadLibraryExA(dllName, + NULL, LOAD_WITH_ALTERED_SEARCH_PATH); + } + free(dllName); + } + } else { + /* Try to use old VS_PROFILER variable */ + dNameLength = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, NULL, 0); + if (dNameLength) + { + DWORD envret = 0; + dllName = (char*)malloc(sizeof(char) * (dNameLength + 1)); + if(dllName != NULL) + { + envret = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, + dllName, dNameLength); + if (envret) + { + /* Try to load the dll from the PATH... */ + m_libHandle = LoadLibraryA(dllName); + } + free(dllName); + } + } + } +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + dllName = getenv(NEW_DLL_ENVIRONMENT_VAR); + if (!dllName) + dllName = getenv(DLL_ENVIRONMENT_VAR); +#if defined(__ANDROID__) || defined(ANDROID) + if (!dllName) + dllName = ANDROID_JIT_AGENT_PATH; +#endif + if (dllName) + { + /* Try to load the dll from the PATH... */ + m_libHandle = dlopen(dllName, RTLD_LAZY); + } +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + + if (!m_libHandle) + { +#if ITT_PLATFORM==ITT_PLATFORM_WIN + m_libHandle = LoadLibraryA(DEFAULT_DLLNAME); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + m_libHandle = dlopen(DEFAULT_DLLNAME, RTLD_LAZY); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + } + + /* if the dll wasn't loaded - exit. */ + if (!m_libHandle) + { + iJIT_DLL_is_missing = 1; /* don't try to initialize + * JIT agent the second time + */ + return 0; + } + +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FUNC_NotifyEvent = (TPNotify)GetProcAddress(m_libHandle, "NotifyEvent"); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + FUNC_NotifyEvent = (TPNotify)dlsym(m_libHandle, "NotifyEvent"); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + if (!FUNC_NotifyEvent) + { + FUNC_Initialize = NULL; + return 0; + } + +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FUNC_Initialize = (TPInitialize)GetProcAddress(m_libHandle, "Initialize"); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + FUNC_Initialize = (TPInitialize)dlsym(m_libHandle, "Initialize"); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + if (!FUNC_Initialize) + { + FUNC_NotifyEvent = NULL; + return 0; + } + + executionMode = (iJIT_IsProfilingActiveFlags)FUNC_Initialize(); + + bDllWasLoaded = 1; + iJIT_DLL_is_missing = 0; /* DLL is ok. */ + + return 1; +} + +ITT_EXTERN_C unsigned int JITAPI iJIT_GetNewMethodID() +{ + static unsigned int methodID = 1; + + if (methodID == 0) + return 0; /* ERROR : this is not a valid value */ + + return methodID++; +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h new file mode 100644 index 0000000000..bf0489b1a1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h @@ -0,0 +1,673 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef __JITPROFILING_H__ +#define __JITPROFILING_H__ + +/** + * @brief JIT Profiling APIs + * + * The JIT Profiling API is used to report information about just-in-time + * generated code that can be used by performance tools. The user inserts + * calls in the code generator to report information before JIT-compiled + * code goes to execution. This information is collected at runtime and used + * by tools like Intel(R) VTune(TM) Amplifier to display performance metrics + * associated with JIT-compiled code. + * + * These APIs can be used to\n + * - **Profile trace-based and method-based JIT-compiled + * code**. Some examples of environments that you can profile with these APIs: + * dynamic JIT compilation of JavaScript code traces, JIT execution in OpenCL(TM) + * software technology, Java/.NET managed execution environments, and custom + * ISV JIT engines. + * @code + * #include + * + * if (iJIT_IsProfilingActive != iJIT_SAMPLING_ON) { + * return; + * } + * + * iJIT_Method_Load jmethod = {0}; + * jmethod.method_id = iJIT_GetNewMethodID(); + * jmethod.method_name = "method_name"; + * jmethod.class_file_name = "class_name"; + * jmethod.source_file_name = "source_file_name"; + * jmethod.method_load_address = code_addr; + * jmethod.method_size = code_size; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&jmethod); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_SHUTDOWN, NULL); + * @endcode + * + * * Expected behavior: + * * If any iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an + * already reported method, then such a method becomes invalid and its + * memory region is treated as unloaded. VTune Amplifier displays the metrics + * collected by the method until it is overwritten. + * * If supplied line number information contains multiple source lines for + * the same assembly instruction (code location), then VTune Amplifier picks up + * the first line number. + * * Dynamically generated code can be associated with a module name. + * Use the iJIT_Method_Load_V2 structure.\n + * Clarification of some cases: + * * If you register a function with the same method ID multiple times, + * specifying different module names, then the VTune Amplifier picks up + * the module name registered first. If you want to distinguish the same + * function between different JIT engines, supply different method IDs for + * each function. Other symbolic information (for example, source file) + * can be identical. + * + * - **Analyze split functions** (multiple joint or disjoint code regions + * belonging to the same function) **including re-JIT** + * with potential overlapping of code regions in time, which is common in + * resource-limited environments. + * @code + * #include + * + * unsigned int method_id = iJIT_GetNewMethodID(); + * + * iJIT_Method_Load a = {0}; + * a.method_id = method_id; + * a.method_load_address = 0x100; + * a.method_size = 0x20; + * + * iJIT_Method_Load b = {0}; + * b.method_id = method_id; + * b.method_load_address = 0x200; + * b.method_size = 0x30; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&b); + * @endcode + * + * * Expected behaviour: + * * If a iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an + * already reported method, then such a method becomes invalid and + * its memory region is treated as unloaded. + * * All code regions reported with the same method ID are considered as + * belonging to the same method. Symbolic information (method name, + * source file name) will be taken from the first notification, and all + * subsequent notifications with the same method ID will be processed + * only for line number table information. So, the VTune Amplifier will map + * samples to a source line using the line number table from the current + * notification while taking the source file name from the very first one.\n + * Clarification of some cases:\n + * * If you register a second code region with a different source file + * name and the same method ID, then this information will be saved and + * will not be considered as an extension of the first code region, but + * VTune Amplifier will use the source file of the first code region and map + * performance metrics incorrectly. + * * If you register a second code region with the same source file as + * for the first region and the same method ID, then the source file will be + * discarded but VTune Amplifier will map metrics to the source file correctly. + * * If you register a second code region with a null source file and + * the same method ID, then provided line number info will be associated + * with the source file of the first code region. + * + * - **Explore inline functions** including multi-level hierarchy of + * nested inline methods which shows how performance metrics are distributed through them. + * @code + * #include + * + * // method_id parent_id + * // [-- c --] 3000 2000 + * // [---- d -----] 2001 1000 + * // [---- b ----] 2000 1000 + * // [------------ a ----------------] 1000 n/a + * + * iJIT_Method_Load a = {0}; + * a.method_id = 1000; + * + * iJIT_Method_Inline_Load b = {0}; + * b.method_id = 2000; + * b.parent_method_id = 1000; + * + * iJIT_Method_Inline_Load c = {0}; + * c.method_id = 3000; + * c.parent_method_id = 2000; + * + * iJIT_Method_Inline_Load d = {0}; + * d.method_id = 2001; + * d.parent_method_id = 1000; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&b); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&c); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&d); + * @endcode + * + * * Requirements: + * * Each inline (iJIT_Method_Inline_Load) method should be associated + * with two method IDs: one for itself; one for its immediate parent. + * * Address regions of inline methods of the same parent method cannot + * overlap each other. + * * Execution of the parent method must not be started until it and all + * its inline methods are reported. + * * Expected behaviour: + * * In case of nested inline methods an order of + * iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED events is not important. + * * If any event overwrites either inline method or top parent method, + * then the parent, including inline methods, becomes invalid and its memory + * region is treated as unloaded. + * + * **Life time of allocated data**\n + * The client sends an event notification to the agent with event-specific + * data, which is a structure. The pointers in the structure refer to memory + * allocated by the client, which responsible for releasing it. The pointers are + * used by the iJIT_NotifyEvent method to copy client's data in a trace file, + * and they are not used after the iJIT_NotifyEvent method returns. + */ + +/** + * @defgroup jitapi JIT Profiling + * @ingroup internal + * @{ + */ + +/** + * @brief Enumerator for the types of notifications + */ +typedef enum iJIT_jvm_event +{ + iJVM_EVENT_TYPE_SHUTDOWN = 2, /**<\brief Send this to shutdown the agent. + * Use NULL for event data. */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED = 13, /**<\brief Send when dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load as event + * data. */ +/** @cond exclude_from_documentation */ + iJVM_EVENT_TYPE_METHOD_UNLOAD_START, /**<\brief Send when compiled dynamic + * code is being unloaded from memory. + * Use iJIT_Method_Load as event data.*/ +/** @endcond */ + + iJVM_EVENT_TYPE_METHOD_UPDATE, /**<\brief Send to provide new content for + * a previously reported dynamic code. + * The previous content will be invalidated + * starting from the time of the notification. + * Use iJIT_Method_Load as event data but + * required fields are following: + * - method_id identify the code to update. + * - method_load_address specify start address + * within identified code range + * where update should be started. + * - method_size specify length of updated code + * range. */ + + + iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, /**<\brief Send when an inline dynamic + * code is JIT compiled and loaded + * into memory by the JIT engine, + * but before the parent code region + * starts executing. + * Use iJIT_Method_Inline_Load as event data.*/ + +/** @cond exclude_from_documentation */ + iJVM_EVENT_TYPE_METHOD_UPDATE_V2, +/** @endcond */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 = 21, /**<\brief Send when a dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load_V2 as event data. */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 /**<\brief Send when a dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load_V3 as event data. */ +} iJIT_JVM_EVENT; + +/** + * @brief Enumerator for the agent's mode + */ +typedef enum _iJIT_IsProfilingActiveFlags +{ + iJIT_NOTHING_RUNNING = 0x0000, /**<\brief The agent is not running; + * iJIT_NotifyEvent calls will + * not be processed. */ + iJIT_SAMPLING_ON = 0x0001, /**<\brief The agent is running and + * ready to process notifications. */ +} iJIT_IsProfilingActiveFlags; + +/** + * @brief Description of a single entry in the line number information of a code region. + * @details A table of line number entries gives information about how the reported code region + * is mapped to source file. + * Intel(R) VTune(TM) Amplifier uses line number information to attribute + * the samples (virtual address) to a line number. \n + * It is acceptable to report different code addresses for the same source line: + * @code + * Offset LineNumber + * 1 2 + * 12 4 + * 15 2 + * 18 1 + * 21 30 + * + * VTune Amplifier constructs the following table using the client data + * + * Code subrange Line number + * 0-1 2 + * 1-12 4 + * 12-15 2 + * 15-18 1 + * 18-21 30 + * @endcode + */ +typedef struct _LineNumberInfo +{ + unsigned int Offset; /**<\brief Offset from the begining of the code region. */ + unsigned int LineNumber; /**<\brief Matching source line number offset (from beginning of source file). */ + +} *pLineNumberInfo, LineNumberInfo; + +/** + * @brief Enumerator for the code architecture. + */ +typedef enum _iJIT_CodeArchitecture +{ + iJIT_CA_NATIVE = 0, /**<\brief Native to the process architecture that is calling it. */ + + iJIT_CA_32, /**<\brief 32-bit machine code. */ + + iJIT_CA_64 /**<\brief 64-bit machine code. */ + +} iJIT_CodeArchitecture; + +#pragma pack(push, 8) + +/** + * @brief Description of a JIT-compiled method + * @details When you use the iJIT_Method_Load structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise different + * method IDs specify different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, data provided with + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table.0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array */ + + unsigned int class_id; /**<\brief This field is obsolete. */ + + char* class_file_name; /**<\brief Class name. Can be NULL.*/ + + char* source_file_name; /**<\brief Source file name. Can be NULL.*/ + +} *piJIT_Method_Load, iJIT_Method_Load; + +/** + * @brief Description of a JIT-compiled method + * @details When you use the iJIT_Method_Load_V2 structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load_V2 +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise different + * method IDs specify different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, then data provided with the + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array. */ + + char* class_file_name; /**<\brief Class name. Can be NULL. */ + + char* source_file_name; /**<\brief Source file name. Can be NULL. */ + + char* module_name; /**<\brief Module name. Can be NULL. + The module name can be useful for distinguishing among + different JIT engines. VTune Amplifier will display + reported methods grouped by specific module. */ + +} *piJIT_Method_Load_V2, iJIT_Method_Load_V2; + +/** + * @brief Description of a JIT-compiled method + * @details The iJIT_Method_Load_V3 structure is the same as iJIT_Method_Load_V2 + * with a newly introduced 'arch' field that specifies architecture of the code region. + * When you use the iJIT_Method_Load_V3 structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load_V3 +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise they are + * treated as regions of different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Cannot be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, then data provided with the + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array. */ + + char* class_file_name; /**<\brief Class name. Can be NULL. */ + + char* source_file_name; /**<\brief Source file name. Can be NULL. */ + + char* module_name; /**<\brief Module name. Can be NULL. + * The module name can be useful for distinguishing among + * different JIT engines. VTune Amplifier will display + * reported methods grouped by specific module. */ + + iJIT_CodeArchitecture module_arch; /**<\brief Architecture of the method's code region. + * By default, it is the same as the process + * architecture that is calling it. + * For example, you can use it if your 32-bit JIT + * engine generates 64-bit code. + * + * If JIT engine reports both 32-bit and 64-bit types + * of methods then VTune Amplifier splits the methods + * with the same module name but with different + * architectures in two different modules. VTune Amplifier + * modifies the original name provided with a 64-bit method + * version by ending it with '(64)' */ + +} *piJIT_Method_Load_V3, iJIT_Method_Load_V3; + +/** + * @brief Description of an inline JIT-compiled method + * @details When you use the_iJIT_Method_Inline_Load structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED + * as an event type to report it. + */ +typedef struct _iJIT_Method_Inline_Load +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself. */ + + unsigned int parent_method_id; /**<\brief Unique immediate parent's method ID. + * Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /** <\brief The virtual address on which the method + * is inlined. If NULL, then data provided with + * the event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array */ + + char* class_file_name; /**<\brief Class name. Can be NULL.*/ + + char* source_file_name; /**<\brief Source file name. Can be NULL.*/ + +} *piJIT_Method_Inline_Load, iJIT_Method_Inline_Load; + +/** @cond exclude_from_documentation */ +/** + * @brief Description of a segment type + * @details Use the segment type to specify a type of data supplied + * with the iJVM_EVENT_TYPE_METHOD_UPDATE_V2 event to be applied to + * a certain code trace. + */ +typedef enum _iJIT_SegmentType +{ + iJIT_CT_UNKNOWN = 0, + + iJIT_CT_CODE, /**<\brief Executable code. */ + + iJIT_CT_DATA, /**<\brief Data (not executable code). + * VTune Amplifier uses the format string + * (see iJIT_Method_Update) to represent + * this data in the VTune Amplifier GUI */ + + iJIT_CT_KEEP, /**<\brief Use the previous markup for the trace. + * Can be used for the following + * iJVM_EVENT_TYPE_METHOD_UPDATE_V2 events, + * if the type of the previously reported segment + * type is the same. */ + iJIT_CT_EOF +} iJIT_SegmentType; + +/** + * @brief Description of a dynamic update of the content within JIT-compiled method + * @details The JIT engine may generate the methods that are updated at runtime + * partially by mixed (data + executable code) content. When you use the iJIT_Method_Update + * structure to describe the update of the content within a JIT-compiled method, + * use iJVM_EVENT_TYPE_METHOD_UPDATE_V2 as an event type to report it. + * + * On the first Update event, VTune Amplifier copies the original code range reported by + * the iJVM_EVENT_TYPE_METHOD_LOAD event, then modifies it with the supplied bytes and + * adds the modified range to the original method. For next update events, VTune Amplifier + * does the same but it uses the latest modified version of a code region for update. + * Eventually, VTune Amplifier GUI displays multiple code ranges for the method reported by + * the iJVM_EVENT_TYPE_METHOD_LOAD event. + * Notes: + * - Multiple update events with different types for the same trace are allowed + * but they must be reported for the same code ranges. + * Example, + * @code + * [-- data---] Allowed + * [-- code --] Allowed + * [code] Ignored + * [-- data---] Allowed + * [-- code --] Allowed + * [------------ trace ---------] + * @endcode + * - The types of previously reported events can be changed but they must be reported + * for the same code ranges. + * Example, + * @code + * [-- data---] Allowed + * [-- code --] Allowed + * [-- data---] Allowed + * [-- code --] Allowed + * [------------ trace ---------] + * @endcode + */ + +typedef struct _iJIT_Method_Update +{ + void* load_address; /**<\brief Start address of the update within a method */ + + unsigned int size; /**<\brief The update size */ + + iJIT_SegmentType type; /**<\brief Type of the update */ + + const char* data_format; /**<\brief C string that contains a format string + * that follows the same specifications as format in printf. + * The format string is used for iJIT_CT_CODE only + * and cannot be NULL. + * Format can be changed on the fly. */ +} *piJIT_Method_Update, iJIT_Method_Update; + +/** @endcond */ + +#pragma pack(pop) + +/** @cond exclude_from_documentation */ +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#ifndef JITAPI_CDECL +# if defined WIN32 || defined _WIN32 +# define JITAPI_CDECL __cdecl +# else /* defined WIN32 || defined _WIN32 */ +# if defined _M_IX86 || defined __i386__ +# define JITAPI_CDECL __attribute__ ((cdecl)) +# else /* _M_IX86 || __i386__ */ +# define JITAPI_CDECL /* actual only on x86_64 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* defined WIN32 || defined _WIN32 */ +#endif /* JITAPI_CDECL */ + +#define JITAPI JITAPI_CDECL +/** @endcond */ + +/** + * @brief Generates a new unique method ID. + * + * You must use this API to obtain unique and valid method IDs for methods or + * traces reported to the agent if you don't have your own mechanism to generate + * unique method IDs. + * + * @return a new unique method ID. When out of unique method IDs, this API + * returns 0, which is not an accepted value. + */ +unsigned int JITAPI iJIT_GetNewMethodID(void); + +/** + * @brief Returns the current mode of the agent. + * + * @return iJIT_SAMPLING_ON, indicating that agent is running, or + * iJIT_NOTHING_RUNNING if no agent is running. + */ +iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive(void); + +/** + * @brief Reports infomation about JIT-compiled code to the agent. + * + * The reported information is used to attribute samples obtained from any + * Intel(R) VTune(TM) Amplifier collector. This API needs to be called + * after JIT compilation and before the first entry into the JIT-compiled + * code. + * + * @param[in] event_type - type of the data sent to the agent + * @param[in] EventSpecificData - pointer to event-specific data + * + * @returns 1 on success, otherwise 0. + */ +int JITAPI iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData); + +#ifdef __cplusplus +} +#endif /* __cplusplus */ +/** @endcond */ + +/** @} jitapi group */ + +#endif /* __JITPROFILING_H__ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp new file mode 100644 index 0000000000..ef4c42bacf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp @@ -0,0 +1,317 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" + +#include "nchw_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void nchw_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper ws_d(pd()->workspace_md()); + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) { + if (ws) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + size_t ws_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + if (ws_dt == data_type::u8) { + assert(0 <= value && value <= 255); + ws[ws_offset] = value; + } else + reinterpret_cast(ws)[ws_offset] = value; + } + }; + + auto ker_max = [=](data_t *d, int mb, int c, int od, int oh, int ow) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto src_offset + = (size_t)IW * IH * ID * C * mb + + (size_t)IW * IH * ID * c + + (size_t)IW * IH * id + + (size_t)IW * ih + + (size_t)iw; + auto s = src[src_offset]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw); + } + } + } + } + }; + + auto ker_avg = [=](data_t *d, int mb, int c, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KD*KW*KH + : (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start); + + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + auto src_offset + = (size_t)IW * IH * ID * C * mb + + (size_t)IW * IH * ID * c + + (size_t)IW * IH * id + + (size_t)IW * ih + + (size_t)iw; + d[0] += src[src_offset]; + } + } + } + + d[0] = math::out_round((float)d[0] / num_summands); + }; + + + if (pd()->desc()->alg_kind == pooling_max) { + parallel_nd(MB, C, OD, OH, OW, + [&](int mb, int c, int od, int oh, int ow) { + size_t dst_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = nstl::numeric_limits::lowest(); + set_ws(mb, c, od, oh, ow, 0); + ker_max(d, mb, c, od, oh, ow); + }); + } else { + parallel_nd(MB, C, OD, OH, OW, + [&](int mb, int c, int od, int oh, int ow) { + size_t dst_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = 0; + ker_avg(d, mb, c, od, oh, ow); + }); + } +} + +template +void nchw_pooling_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper ws_d(pd()->workspace_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto ker_zero = [=](int mb, int c) { + size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW; + for (int id = 0; id < ID; ++id) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_offset++] = 0; + } + } + } + }; + + auto ker_max = [=](const data_t *d, int mb, int c, int od, int oh, int ow) { + auto b_c = ws_d.blocking_desc().inner_nblks == 0 + ? 1 : ws_d.blocking_desc().inner_blks[0]; + auto ws_offset = is_3d + ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c + : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c; + + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset]; + const int kw = index % KW; + const int kh = (index / KW) % KH; + const int kd = (index / KW) / KH; + + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. + if (id < 0 || id >= ID) + return; + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + size_t diff_src_offset = + (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW + (size_t)id*IH*IW + + (size_t)ih*IW + (size_t)iw; + diff_src[diff_src_offset] += d[0]; + }; + + auto ker_avg = [=](const data_t *d, int mb, int c, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + size_t num_summands = (alg == pooling_avg_include_padding) + ? (size_t)KW*KH*KD + : (size_t)(id_end - id_start)*(ih_end - ih_start) + *(iw_end - iw_start); + + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + + (size_t)c*ID*IH*IW + (size_t)id*IH*IW + + (size_t)ih*IW + (size_t)iw; + diff_src[diff_src_offset] += d[0] / num_summands; + } + } + } + }; + + if (pd()->desc()->alg_kind == pooling_max) { + parallel_nd(MB, C, [&](int mb, int c) { + size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + + (size_t)c*OD*OH*OW; + ker_zero(mb, c); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = &diff_dst[diff_dst_offset++]; + ker_max(d, mb, c, od, oh, ow); + } + } + } + }); + } else { + parallel_nd(MB, C, [&](int mb, int c) { + size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + + (size_t)c*OD*OH*OW; + ker_zero(mb, c); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = &diff_dst[diff_dst_offset++]; + ker_avg(d, mb, c, od, oh, ow); + } + } + } + }); + } +} + +template struct nchw_pooling_fwd_t; +template struct nchw_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp new file mode 100644 index 0000000000..bbdd04f6b9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp @@ -0,0 +1,147 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NCHW_POOLING_HPP +#define CPU_NCHW_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct nchw_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("nchw_pooling:any", nchw_pooling_fwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && !has_zero_dim_memory() + && utils::everyone_is(data_type, src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + nchw_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct nchw_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("nchw:any", nchw_pooling_bwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + bool ws_ok = true + && hint_fwd_pd_ + && hint_fwd_pd_->workspace_md(); + if (!ws_ok) + return status::unimplemented; + + const auto &ws_blk = + hint_fwd_pd_->workspace_md()->format_desc.blocking; + ws_ok = ws_ok + && ws_blk.inner_nblks < 1 + && IMPLICATION(ws_blk.inner_nblks == 1, + ws_blk.inner_idxs[0] == 1); + if (!ws_ok) + return status::unimplemented; + + ws_md_ = *hint_fwd_pd_->workspace_md(); + } + + return status::success; + } + }; + + nchw_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp new file mode 100644 index 0000000000..c0e93fefe4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp @@ -0,0 +1,382 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "ncsp_batch_normalization.hpp" + +// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases +#if (defined __clang_major__) && (__clang_major__ >= 6) +#define SAFE_TO_USE_OMP_SIMD 0 +#else +#define SAFE_TO_USE_OMP_SIMD 1 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void ncsp_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + const bool calculate_stats = !pd()->stats_is_src(); + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + data_t *mean, *variance; + if (!calculate_stats) { + mean = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)); + variance = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)); + } else { + if (save_stats) { + mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + } else { + mean = scratchpad.get(key_bnorm_tmp_mean); + variance = scratchpad.get(key_bnorm_tmp_var); + } + } + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool with_relu = pd()->with_relu_post_op(); + auto maybe_post_op + = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; }; + const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5); + dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1; + dim_t N = pd()->MB(); + dim_t C = pd()->C(); + + int nthr = mkldnn_get_max_threads(); + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + size_t data_size = N * C * SP * sizeof(data_t); + bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0); + + parallel(0, [&](const int ithr, const int nthr) { + int C_ithr = 0, C_nthr = 0; + int N_ithr = 0, N_nthr = 0; + int S_ithr = 0, S_nthr = 0; + + dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0; + dim_t N_s = 0, N_e = 0; + dim_t S_s = 0, S_e = 0; + + dim_t C_blks_per_iter = 1; + int64_t iters = 1; + + if (do_blocking) { + size_t working_set_size = N * SP * sizeof(data_t); + bnorm_utils::cache_balance( + working_set_size, C, C_blks_per_iter, iters); + } else + C_blks_per_iter = C; + int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter; + bool spatial_thr_allowed + = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N, + C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e, + N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e); + balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + for (int64_t it = 0; it < iters; ++it) { + if (it == iters - 1 && iters > 1) { + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). So sync the + // threads if they are not synced by the algorithm. + if (SP_N_nthr == 1 && mkldnn_thr_syncable()) + mkldnn_thr_barrier(); + + S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + SP_N_ithr = N_ithr * S_nthr + S_ithr; + SP_N_nthr = N_nthr * S_nthr; + } + size_t C_off = it * C_blks_per_iter; + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). Since sync is not always + // possible (in case of TBB) use different parts of ws for each + // iteration if threads are not synced by the algorithm. + size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * C_off; + + if (calculate_stats) { + data_t *mean_blk = mean + C_off; + data_t *variance_blk = variance + C_off; + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = (c + C_off) * SP; + data_t sum = 0; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : sum)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + sum += src[off + n * C * SP + sp]; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = sum; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + mean_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) + mean_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + mean_blk[c] /= (N * SP); + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t sum = 0.; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : sum)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + data_t m = src[off * SP + n * C * SP + sp] + - mean[off]; + sum += m * m; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = sum; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + variance_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) + variance_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + variance_blk[c] /= (N * SP); + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + } + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t sqrt_variance + = static_cast(sqrtf(variance[off] + eps)); + data_t sm = (use_scaleshift ? scaleshift[off] : 1.0f) / sqrt_variance; + data_t sv = use_scaleshift ? scaleshift[C + off] : 0; + for (dim_t n = N_s; n < N_e; ++n) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t sp = S_s; sp < S_e; ++sp) { + size_t d_off = off * SP + n * C * SP + sp; + data_t bn_res + = sm * (src[d_off] - mean[off]) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + dst[d_off] = maybe_post_op(bn_res); + } + } + } + }); +} + +void ncsp_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + if (diff_scaleshift == nullptr) + diff_scaleshift = scratchpad.get(key_bnorm_tmp_diff_ss); + + const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5); + dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1; + dim_t C = pd()->C(), N = pd()->MB(); + const bool use_scaleshift = pd()->use_scaleshift(); + const float eps = pd()->desc()->batch_norm_epsilon; + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + int nthr = mkldnn_get_max_threads(); + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + size_t data_size = N * C * SP * sizeof(data_t); + bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0); + + parallel(0, [&](const int ithr, const int nthr) { + int C_ithr = 0, C_nthr = 0; + int N_ithr = 0, N_nthr = 0; + int S_ithr = 0, S_nthr = 0; + + dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0; + dim_t N_s = 0, N_e = 0; + dim_t S_s = 0, S_e = 0; + + dim_t C_blks_per_iter = 1; + int64_t iters = 1; + + if (do_blocking) { + size_t working_set_size = 2 * N * SP * sizeof(data_t); + bnorm_utils::cache_balance( + working_set_size, C, C_blks_per_iter, iters); + } else + C_blks_per_iter = C; + int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter; + bool spatial_thr_allowed + = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N, + C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e, + N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e); + balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + + for (int64_t it = 0; it < iters; ++it) { + if (it == iters - 1 && iters > 1) { + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). So sync the + // threads if they are not synced by the algorithm. + if (SP_N_nthr == 1 && mkldnn_thr_syncable()) + mkldnn_thr_barrier(); + + C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + SP_N_ithr = N_ithr * S_nthr + S_ithr; + SP_N_nthr = N_nthr * S_nthr; + } + size_t C_off = it * C_blks_per_iter; + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). Since sync is not always + // possible (in case of TBB) use different parts of ws for each + // iteration if threads are not synced by the algorithm. + size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * 2 * C_off; + + data_t *diff_gamma_blk = diff_scaleshift + C_off; + data_t *diff_beta_blk = diff_scaleshift + C + C_off; + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t diff_gamma = 0.0, diff_beta = 0.0; + data_t v_mean = mean[off]; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : diff_gamma, diff_beta)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + const size_t d_off = off * SP + n * C * SP + sp; + data_t dd; + if (fuse_bn_relu) + dd = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + dd = diff_dst[d_off]; + diff_gamma += (src[d_off] - v_mean) * dd; + diff_beta += dd; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = diff_gamma; + ws_reduce[ws_iter_off + SP_N_nthr * C_blks_per_iter + + SP_N_ithr * C_blks_per_iter + c] = diff_beta; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + data_t sqrt_variance = static_cast( + 1.0f / sqrtf(variance[c + C_off] + eps)); + diff_gamma_blk[c] = 0.; + diff_beta_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) { + diff_gamma_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + diff_beta_blk[c] += ws_reduce[ws_iter_off + + SP_N_nthr * C_blks_per_iter + n * C_blks_per_iter + + c]; + } + diff_gamma_blk[c] *= sqrt_variance; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t gamma = use_scaleshift ? scaleshift[off] : 1; + data_t sqrt_variance + = static_cast(1.0f / sqrtf(variance[off] + eps)); + data_t v_mean = mean[off]; + for (dim_t n = N_s; n < N_e; ++n) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t sp = S_s; sp < S_e; ++sp) { + const size_t d_off = off * SP + n * C * SP + sp; + + data_t v_diff_src; + if (fuse_bn_relu) + v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + v_diff_src = diff_dst[d_off]; + if (calculate_diff_stats) { + v_diff_src -= diff_beta_blk[c] / (SP * N) + + (src[d_off] - v_mean) * diff_gamma_blk[c] + * sqrt_variance / (SP * N); + } + v_diff_src *= gamma * sqrt_variance; + diff_src[d_off] = v_diff_src; + } + } + } + }); +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp new file mode 100644 index 0000000000..97ca3b003f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp @@ -0,0 +1,160 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NCSP_BATCH_NORMALIZATION_HPP +#define CPU_NCSP_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ncsp_batch_normalization_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_fwd_pd_t { + using cpu_batch_normalization_fwd_pd_t::cpu_batch_normalization_fwd_pd_t; + + DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_fwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + using namespace format_tag; + + bool ok = true + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + if (!stats_is_src()) { + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * C() * mkldnn_get_max_threads()); + + if (!is_training()) { + scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * C()); + scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * C()); + } + } + } + }; + + typedef typename prec_traits::type data_t; + + ncsp_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~ncsp_batch_normalization_fwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct ncsp_batch_normalization_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_bwd_pd_t { + using cpu_batch_normalization_bwd_pd_t::cpu_batch_normalization_bwd_pd_t; + + DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_bwd_t); + + status_t init() { + using namespace data_type; + using namespace format_tag; + + bool ok = true + && is_bwd() + && !has_zero_dim_memory() + && utils::everyone_is(f32, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc) + && memory_desc_matches_one_of_tag(*diff_src_md(), ncdhw, nchw, nc) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * 2 * C() * mkldnn_get_max_threads()); + if (!(use_scaleshift() && desc()->prop_kind == prop_kind::backward)) + scratchpad.book(key_bnorm_tmp_diff_ss, + sizeof(data_t) * 2 * C()); + } + }; + + typedef typename prec_traits::type data_t; + + ncsp_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~ncsp_batch_normalization_bwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp new file mode 100644 index 0000000000..38cfb28dce --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp @@ -0,0 +1,392 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" + +#include "nhwc_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define MEM_D(name) name##_d + +#define DECLARE_READ_STRIDES(name) \ + const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0]; \ + const size_t name##_d_stride = (!is_3d) \ + ? 0 \ + : MEM_D(name).blocking_desc().strides[2]; \ + const size_t name##_h_stride = (!is_3d) \ + ? MEM_D(name).blocking_desc().strides[2] \ + : MEM_D(name).blocking_desc().strides[3]; \ + const size_t name##_w_stride = (!is_3d) \ + ? MEM_D(name).blocking_desc().strides[3] \ + : MEM_D(name).blocking_desc().strides[4]; + +namespace nhwc_pooling { + size_t strided_offset(const int _n, const size_t _sn, + const int _d, const size_t _sd, + const int _h, const size_t _sh, + const int _w, const size_t _sw) + { + return _n * _sn + + _d * _sd + + _h * _sh + + _w * _sw; + } +} + +template +void nhwc_pooling_fwd_t::array_div_by_const(const int n, + const data_t *src, const size_t num, data_t *dst) const +{ + for (int i = 0; i < n; ++i) + { + float ftmp = (float)src[i]; + ftmp = ftmp / num; + dst[i] = math::out_round(ftmp); + } +} + +template +void nhwc_pooling_fwd_t::array_add(const int n, const data_t *src, + data_t *dst) const +{ + for (int i = 0; i < n; ++i) + { + dst[i] += src[i]; + } +} + +template +void nhwc_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace prop_kind; + using namespace nhwc_pooling; + + auto alg = pd()->desc()->alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper MEM_D(src)(pd()->src_md()); + const memory_desc_wrapper MEM_D(dst)(pd()->dst_md()); + const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + const bool is_3d = pd()->desc()->src_desc.ndims == 5; + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + DECLARE_READ_STRIDES(src); + DECLARE_READ_STRIDES(dst); + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + parallel_nd(MB, OD, OH, OW, + [&](int mb, int od, int oh, int ow) { + size_t dst_offset_init = strided_offset(mb, dst_n_stride, + od, dst_d_stride, + oh, dst_h_stride, + ow, dst_w_stride); + if (alg == pooling_max) { + size_t ws_offset_init = 0; + if (ws) + { + DECLARE_READ_STRIDES(ws); + ws_offset_init = strided_offset(mb, ws_n_stride, + od, ws_d_stride, + oh, ws_h_stride, + ow, ws_w_stride); + } + // Note: GCC 4.8.5 won't vectorize below + // simple loops unless they are singled out + // into separate helper routines: + // array_nhwc_initialize, array_nhwc_max + if (!ws) + array_nhwc_initialize(OC, dst + dst_offset_init, + ws, ws_offset_init, ws_dt); + else + array_nhwc_initialize(OC, dst + dst_offset_init, + ws, ws_offset_init, ws_dt); + + + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) + continue; + if (ih < 0 || ih >= IH) + continue; + if (iw < 0 || iw >= IW) + continue; + + size_t src_offset_init = strided_offset(mb, src_n_stride, + id, src_d_stride, + ih, src_h_stride, + iw, src_w_stride); + + if (!ws) + array_nhwc_max(OC, + dst + dst_offset_init, + src + src_offset_init, + ws, ws_offset_init, + ws_dt, + kd * KH * KW + kh * KW + kw + ); + else + array_nhwc_max(OC, + dst + dst_offset_init, + src + src_offset_init, + ws, ws_offset_init, + ws_dt, + kd * KH * KW + kh * KW + kw + ); + } + } else { + // pooling_avg + auto d = dst + dst_offset_init; + + utils::array_set(d, 0, OC); + + auto id_start = apply_offset(od * SD, padF); + auto ih_start = apply_offset(oh * SH, padT); + auto iw_start = apply_offset(ow * SW, padL); + auto id_end = nstl::min(od * SD - padF + KD, ID); + auto ih_end = nstl::min(oh * SH - padT + KH, IH); + auto iw_end = nstl::min(ow * SW - padL + KW, IW); + + // it is cheaper to actually count this in a loop + // as the typical kernel is small + size_t num_summands = 0; + + for (int id = id_start; id < id_end; ++id) + for (int ih = ih_start; ih < ih_end; ++ih) + for (int iw = iw_start; iw < iw_end; ++iw) { + size_t src_offset_init = strided_offset(mb, src_n_stride, + id, src_d_stride, + ih, src_h_stride, + iw, src_w_stride); + auto s = src + src_offset_init; + + // need to move the loop to separate function + // for GCC 4.8.5 to vectorize + array_add(OC, s, d); + + num_summands++; + } + + num_summands = (alg == pooling_avg_include_padding) ? + KW * KH * KD : num_summands; + + // need to move the loop to separate function + // for GCC 4.8.5 to vectorize + array_div_by_const(OC, d, num_summands, d); + } + }); +} + +template +void nhwc_pooling_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace nhwc_pooling; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md()); + const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md()); + const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int OC = pd()->C(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + auto alg = pd()->desc()->alg_kind; + + DECLARE_READ_STRIDES(diff_src); + DECLARE_READ_STRIDES(diff_dst); + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + const int MB = pd()->MB(); + + parallel_nd(MB, ID, IH, IW, + [&](int mb, int id, int ih, int iw) { + size_t src_offset_init = strided_offset(mb, diff_src_n_stride, + id, diff_src_d_stride, + ih, diff_src_h_stride, + iw, diff_src_w_stride); + + // check if kernel windows are disjoint, in this case there's no + // update needed and we just write there once, no initialization + // required. + if (!(KD == SD && KH == SH && KW == SW)) + for (int oc = 0; oc < OC; ++oc) + diff_src[src_offset_init + oc] = data_type_t(0); + + // Find out which output cells may correspond to current + // input position. Current input postition divided by + // stride, with integer divide rounding down, is the + // right-most output. + // Left-most output may be computed if we decrement input + // by (kernel_size - 1) and then do the same division by + // stride. + int od_left = nstl::max((id + padF - KD + 1) / SD, 0); + int oh_left = nstl::max((ih + padT - KH + 1) / SH, 0); + int ow_left = nstl::max((iw + padL - KW + 1) / SW, 0); + // Notice +1 here to preserve the C loop "less than" + // condition for continuing the for loop. + int od_right = nstl::min((id + padF) / SD + 1 , OD); + int oh_right = nstl::min((ih + padT) / SH + 1 , OH); + int ow_right = nstl::min((iw + padL) / SW + 1 , OW); + + for (int od = od_left; od < od_right; ++od) + for (int oh = oh_left; oh < oh_right; ++oh) + for (int ow = ow_left; ow < ow_right; ++ow) { + const int kd = id - od*SD + padF; + const int kh = ih - oh*SH + padT; + const int kw = iw - ow*SW + padL; + + if (kd < 0 || kd >= KD) + continue; + if (kh < 0 || kh >= KH) + continue; + if (kw < 0 || kw >= KW) + continue; + + size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride, + od, diff_dst_d_stride, + oh, diff_dst_h_stride, + ow, diff_dst_w_stride); + + if (alg == pooling_max) { + DECLARE_READ_STRIDES(ws); + size_t ws_offset_init = strided_offset(mb, ws_n_stride, + od, ws_d_stride, + oh, ws_h_stride, + ow, ws_w_stride); + const int index = kd * KH * KW + kh * KW + kw; + + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < OC; ++oc) { + const int index_from_ws = + (MEM_D(ws).data_type() == data_type::u8) + ? (int)ws[ws_offset_init + oc] + : ((int *)ws)[ws_offset_init + oc]; + + const data_t d = diff_dst[dst_offset_init + oc]; + + // Check if kernel windows are disjoint, in this case + // there's no update needed and we just write there once + // otherwise we add value to the contents. + if (!(KD == SD && KH == SH && KW == SW)) + diff_src[src_offset_init + oc] += + (index_from_ws == index) + ? d + : data_type_t(0); + else + diff_src[src_offset_init + oc] = + (index_from_ws == index) + ? d + : data_type_t(0); + } + } else { + // pooling_avg + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) + ? KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < OC; ++oc) { + const data_t d = diff_dst[dst_offset_init + oc]; + // Check if kernel windows are disjoint, in this case + // there's no update needed and we just write there once + // otherwise we add value to the contents. + if (!(KD == SD && KH == SH && KW == SW)) + diff_src[src_offset_init + oc] += d / num_summands; + else + diff_src[src_offset_init + oc] = d / num_summands; + } + } + } + }); +} + +template struct nhwc_pooling_fwd_t; +template struct nhwc_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp new file mode 100644 index 0000000000..7e33b6869f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp @@ -0,0 +1,210 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NHWC_POOLING_HPP +#define CPU_NHWC_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace nhwc_pooling { +size_t strided_offset(const int _n, const size_t _sn, const int _d, + const size_t _sd, const int _h, const size_t _sh, const int _w, + const size_t _sw); +} + +template +struct nhwc_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("nhwc_pooling:any", nhwc_pooling_fwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nhwc : format_tag::ndhwc; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::everyone_is(data_type, + src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + nhwc_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + void array_div_by_const(const int n, const data_t *src, const size_t num, + data_t *dst) const; + void array_add(const int n, const data_t *src, data_t *dst) const; + + template + void array_nhwc_max(const int n, data_t *dst, const data_t *src, + unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt, + const int index) const { + assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < n; ++oc) { + auto s = src[oc]; + data_t mv = dst[oc]; + + // update index of maximum +#if defined __INTEL_COMPILER + if ((use_workspace) && (s > mv)) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + if (ws_dt == data_type::u8) { + assert(0 <= index && index <= 255); + ws[ws_offset + oc] = index; + } else + reinterpret_cast(ws)[ws_offset + oc] = index; + } +#else + // Need to add explicit predicates for GCC to vectorize this. + // And although the resulting code is ugly, it is still 4 times + // faster than scalar + if (use_workspace) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + + if (ws_dt == data_type::u8) { + assert(0 <= index && index <= 255); + unsigned char predicate = (s > mv) ? 0xff : 0; + unsigned char current_value = ws[ws_offset + oc]; + current_value = (predicate & (unsigned char)index) + | ((~predicate) & current_value); + ws[ws_offset + oc] = current_value; + } else { + auto wint = reinterpret_cast(ws); + unsigned int predicate = (s > mv) ? 0xffffffff : 0; + unsigned int current_value = wint[ws_offset + oc]; + current_value = (predicate & (unsigned int)index) + | ((~predicate) & current_value); + wint[ws_offset + oc] = current_value; + } + } +#endif + // update maximum + dst[oc] = nstl::max(s, mv); + } + } + + template + void array_nhwc_initialize(const int n, data_t *dst, unsigned char *ws, + const size_t ws_offset, const data_type_t ws_dt) const { + assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists + for (int oc = 0; oc < n; ++oc) { + if (use_workspace) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + if (ws_dt == data_type::u8) { + ws[ws_offset + oc] = 0; + } else + reinterpret_cast(ws)[ws_offset + oc] = 0; + } + dst[oc] = nstl::numeric_limits::lowest(); + } + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct nhwc_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("nhwc:any", nhwc_pooling_bwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::everyone_is(data_type, + diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + nhwc_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +}// namespace cpu +}// namespace impl +}// namespace mkldnn + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp new file mode 100644 index 0000000000..e20333e66f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp @@ -0,0 +1,288 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "nspc_batch_normalization.hpp" + +// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases +#if (defined __clang_major__) && (__clang_major__ >= 6) +#define SAFE_TO_USE_OMP_SIMD 0 +#else +#define SAFE_TO_USE_OMP_SIMD 1 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void nspc_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + const bool calculate_stats = !pd()->stats_is_src(); + const bool with_relu = pd()->with_relu_post_op(); + + auto scratchpad = this->scratchpad(ctx); + auto tmp_mean = scratchpad.get(key_bnorm_tmp_mean); + auto tmp_var = scratchpad.get(key_bnorm_tmp_var); + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + data_t *mean, *variance; + if (!calculate_stats) { + mean = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)); + variance = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)); + } else { + if (save_stats) { + mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + } else { + mean = tmp_mean; + variance = tmp_var; + } + } + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + const dim_t SP = pd()->H() * pd()->W() * pd()->D(); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + auto maybe_post_op + = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; }; + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { + dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; + balance211(N, nthr, ithr, N_s, N_e); + balance211(C, nthr, ithr, C_s, C_e); + data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr; + data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr; + + if (calculate_stats) { + for (dim_t c = 0; c < C; c++) + ws_reduce[C * ithr + c] = 0.; + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) + PRAGMA_OMP_SIMD() + for (dim_t c = 0; c < C; c++) + ws_reduce[C * ithr + c] += src[(size_t)n * SP * C + + sp * C + c]; + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + mean[c] = 0; + for (dim_t n = 0; n < nthr; n++) + mean[c] += ws_reduce[C * n + c]; + mean[c] /= SP * N; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) { + mean_loc[c] = mean[c]; + ws_reduce[C * ithr + c] = 0.; + } + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) + PRAGMA_OMP_SIMD() + for (dim_t c = 0; c < C; c++) { + data_t m = src[(size_t)n * SP * C + sp * C + c] + - mean_loc[c]; + ws_reduce[C * ithr + c] += m * m; + } + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + variance[c] = 0; + for (dim_t n = 0; n < nthr; n++) + variance[c] += ws_reduce[C * n + c]; + variance[c] /= SP * N; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) + variance_loc[c] = variance[c]; + } else { + variance_loc = variance; + mean_loc = mean; + } + + for (dim_t n = N_s; n < N_e; n++) { + for (dim_t sp = 0; sp < SP; sp++) { +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + data_t sqrt_variance = static_cast( + sqrtf(variance_loc[c] + eps)); + data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance; + data_t sv = use_scaleshift ? scaleshift[C + c] : 0; + size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + dst[d_off] = maybe_post_op(bn_res); + } + } + } + }); +} + +void nspc_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto tmp_diff_ss = scratchpad.get(key_bnorm_tmp_diff_ss); + + if (diff_scaleshift == nullptr) + diff_scaleshift = tmp_diff_ss; + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + const dim_t SP = pd()->D() * pd()->H() * pd()->W(); + data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C; + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { + dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; + balance211(N, nthr, ithr, N_s, N_e); + balance211(C, nthr, ithr, C_s, C_e); + + data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr; + data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr); + + for (dim_t c = 0; c < C; c++) { + ws_reduce[C * ithr + c] = 0.; + ws_reduce[C * nthr + C * ithr + c] = 0.; + } + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + const size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t dd; + if (fuse_bn_relu) + dd = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + dd = diff_dst[d_off]; + ws_reduce[C * ithr + c] += (src[d_off] - mean[c]) * dd; + ws_reduce[C * nthr + C * ithr + c] += dd; + } + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + data_t sqrt_variance + = static_cast(1.0f / sqrtf(variance[c] + eps)); + diff_gamma[c] = 0; + diff_beta[c] = 0; + for (dim_t n = 0; n < nthr; n++) { + diff_gamma[c] += ws_reduce[C * n + c]; + diff_beta[c] += ws_reduce[C * nthr + C * n + c]; + } + diff_gamma[c] *= sqrt_variance; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) { + diff_gamma_loc[c] = diff_gamma[c]; + diff_beta_loc[c] = diff_beta[c]; + } + + for (dim_t n = N_s; n < N_e; n++) { + for (dim_t sp = 0; sp < SP; sp++) { +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + const size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t gamma = use_scaleshift ? scaleshift[c] : 1; + data_t sqrt_variance + = static_cast(1.0f / sqrtf(variance[c] + eps)); + data_t v_diff_src; + if (fuse_bn_relu) + v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + v_diff_src = diff_dst[d_off]; + if (calculate_diff_stats) { + v_diff_src -= diff_beta_loc[c] / (SP * N) + + (src[d_off] - mean[c]) * diff_gamma_loc[c] + * sqrt_variance / (SP * N); + } + v_diff_src *= gamma * sqrt_variance; + diff_src[d_off] = v_diff_src; + } + } + } + }); +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp new file mode 100644 index 0000000000..aad86b05a7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp @@ -0,0 +1,169 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NSPC_BATCH_NORMALIZATION_HPP +#define CPU_NSPC_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct nspc_batch_normalization_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_fwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + + bool ok = true + /* the algorithm requires barriers while switching + * between parallelization over N and C dimensions */ + && mkldnn_thr_syncable() + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + if (!stats_is_src()) { + dim_t sz = nstl::max(C(), 16) * mkldnn_get_max_threads(); + scratchpad.book(key_bnorm_reduction, sizeof(data_t) * sz); + scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * sz); + scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * sz); + } + } + }; + + typedef typename prec_traits::type data_t; + + nspc_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~nspc_batch_normalization_fwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct nspc_batch_normalization_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_bwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + + bool ok = true + /* the algorithm requires barriers while switching + * between parallelization over N and C dimensions */ + && mkldnn_thr_syncable() + && is_bwd() + && !has_zero_dim_memory() + && utils::everyone_is(f32, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && memory_desc_matches_tag(*diff_src_md(), format_tag::nhwc) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * 2 * C() * mkldnn_get_max_threads()); + scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * 2 * C() + * (mkldnn_get_max_threads() + 1)); + } + }; + + typedef typename prec_traits::type data_t; + + nspc_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~nspc_batch_normalization_bwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp new file mode 100644 index 0000000000..d79b1a034b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp @@ -0,0 +1,265 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "simple_q10n.hpp" + +#include "ref_batch_normalization.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void ref_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + /* fast return */ + if (this->pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const float *, MKLDNN_ARG_SCALE_SHIFT); + + auto mean = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const float *, MKLDNN_ARG_MEAN)) + : CTX_OUT_MEM(float *, MKLDNN_ARG_MEAN); + auto variance = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const float *, MKLDNN_ARG_VARIANCE)) + : CTX_OUT_MEM(float *, MKLDNN_ARG_VARIANCE); + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper scaleshift_d(pd()->weights_md()); + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + dim_t H = 1, W = 1, D = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5); + if (has_spatial) { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + } + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift();; + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + const bool calculate_stats = !pd()->stats_is_src(); + + const bool with_relu = pd()->with_relu_post_op(); + auto maybe_post_op = [&](float res) { + return (with_relu && res < 0.0f) ? 0.0f : res; + }; + const bool is_3d = data_d.ndims() == 5; + + auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c, + dim_t d, dim_t h, dim_t w) { + if (has_spatial) { + if (is_3d) + return data_d.off(n, c, d, h, w); + else + return data_d.off(n, c, h, w); + } else + return data_d.off(n, c); + }; + + parallel_nd(C, [&](dim_t c) { + float v_mean = calculate_stats ? 0 : mean[c]; + float v_variance = calculate_stats ? 0 : variance[c]; + + if (calculate_stats) { + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) + v_mean += src[data_offset(data_d, n, c, d, h, w)]; + v_mean /= W*N*H*D; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + float m = src[data_offset(data_d, n, c, d, h, w)] - v_mean; + v_variance += m*m; + } + v_variance /= W*H*N*D; + } + + float sqrt_variance = sqrtf(v_variance + eps); + float sm = (use_scaleshift + ? scaleshift[scaleshift_d.off(0, c)] + : 1.0f) / sqrt_variance; + float sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + auto d_off = data_offset(data_d,n,c,d,h,w); + float bn_res = sm * ((float)src[d_off] - v_mean) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + if (data_type == data_type::s8) { + dst[d_off] = qz_a1b0()(maybe_post_op(bn_res)); + } else { + dst[d_off] = static_cast(maybe_post_op(bn_res)); + } + } + + if (calculate_stats) { + if (save_stats) { + mean[c] = v_mean; + variance[c] = v_variance; + } + } + }); +} + +template struct ref_batch_normalization_fwd_t; +template struct ref_batch_normalization_fwd_t; + +template +void ref_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + const memory_desc_wrapper scaleshift_d(pd()->weights_md()); + const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_md()); + + const dim_t C = pd()->C(); + + /* fast return */ + if (this->pd()->has_zero_dim_memory()) { + if (diff_scaleshift) { + for (dim_t c = 0; c < C; ++c) { + diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0; + diff_scaleshift[diff_scaleshift_d.off(1, c)] = 0; + } + } + return; + } + + const dim_t N = pd()->MB(); + dim_t H = 1, W = 1, D = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5); + if (has_spatial) { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + } + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + const bool is_3d = data_d.ndims() == 5; + + auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c, + dim_t d, dim_t h, dim_t w) { + if (has_spatial) { + if (is_3d) + return data_d.off(n, c, d, h, w); + else + return data_d.off(n, c, h, w); + } else + return data_d.off(n, c); + }; + + parallel_nd(C, [&](dim_t c) { + data_t v_mean = mean[c]; + data_t v_variance = variance[c]; + data_t sqrt_variance = static_cast(1.0f / sqrtf(v_variance + eps)); + data_t gamma = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1; + data_t diff_gamma = data_t(0); + data_t diff_beta = data_t(0); + diff_gamma = 0.0; + diff_beta = 0.0; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + const size_t s_off = data_offset(data_d, n, c, d, h, w); + data_t dd = diff_dst[data_offset(diff_data_d, n, c, d, h, w)]; + if (fuse_bn_relu && !ws[s_off]) + dd = 0; + + diff_gamma += (src[s_off] - v_mean) * dd; + diff_beta += dd; + } + diff_gamma *= sqrt_variance; + + if (diff_scaleshift) { + diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma; + diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta; + } + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + const size_t s_off = data_offset(data_d, n, c, d, h, w); + const size_t dd_off = data_offset(diff_data_d, n, c, d, h, w); + data_t dd = diff_dst[dd_off]; + if (fuse_bn_relu && !ws[s_off]) + dd = 0; + + data_t v_diff_src = dd; + if (calculate_diff_stats) { + v_diff_src -= diff_beta/(D*W*H*N) + + (src[s_off] - v_mean) * + diff_gamma*sqrt_variance/(D*W*H*N); + } + v_diff_src *= gamma*sqrt_variance; + diff_src[dd_off] = v_diff_src; + } + }); +} + +template struct ref_batch_normalization_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp new file mode 100644 index 0000000000..aa9f74125a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp @@ -0,0 +1,127 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_BATCH_NORMALIZATION_HPP +#define CPU_REF_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_batch_normalization_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && IMPLICATION(use_scaleshift(), + weights_md()->data_type == data_type::f32) + && (attr()->has_default_values() || with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (src_md()->data_type == data_type::s8 && !stats_is_src()) + return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + return status::success; + } + }; + + ref_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_batch_normalization_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_bwd_t); + + status_t init() { + bool ok = true + && is_bwd() + && utils::everyone_is(data_type, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), utils::everyone_is(data_type, + weights_md()->data_type, + diff_weights_md()->data_type)) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + ref_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp new file mode 100644 index 0000000000..4c534b5508 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp @@ -0,0 +1,97 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_CONCAT_HPP +#define REF_CONCAT_HPP + +#include "reorder_pd.hpp" + +#include "cpu_concat_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ref_concat_t: public cpu_primitive_t { + struct pd_t: public cpu_concat_pd_t { + using cpu_concat_pd_t::cpu_concat_pd_t; + + pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) { + for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) + reorder_pds_.push_back( + (const reorder_pd_t *)rhs.reorder_pds_[i]->clone()); + } + ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; } + + DECLARE_CONCAT_PD_T("ref:any", ref_concat_t); + + status_t init() { + bool ok = cpu_concat_pd_t::init() == status::success; + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + auto r_impls = engine_->get_reorder_implementation_list(); + for (auto r = r_impls; *r; ++r) { + const primitive_attr_t attr; /* alpha == 1. */ + reorder_pd_t *r_pd = nullptr; + if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i), + engine_, src_image_md(i)) == status::success) { + r_pd->init_info(); + reorder_pds_.push_back(r_pd); + break; + } + } + } + + ok = reorder_pds_.size() == (size_t)n_; + return ok ? status::success : status::unimplemented; + } + + nstl::vector reorder_pds_; + }; + + ref_concat_t(const pd_t *apd): cpu_primitive_t(apd) { + const int n = pd()->n_inputs(); + reorders_.resize(n); + for (int i = 0; i < n; ++i) + pd()->reorder_pds_[i]->create_primitive(&reorders_[i]); + } + + ~ref_concat_t() { for (auto &r: reorders_) delete r; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto n = pd()->n_inputs(); + for (int i = 0; i < n; ++i) { + exec_args_t r_args; + r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i); + r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST); + exec_ctx_t r_ctx(ctx.stream(), std::move(r_args)); + reorders_[i]->execute(r_ctx); + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + nstl::vector reorders_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp new file mode 100644 index 0000000000..c0a979c4cf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp @@ -0,0 +1,395 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "type_helpers.hpp" + +#include "ref_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using math::saturate; +using math::get_bias; + +template +void ref_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool with_relu = 0; // TODO: change if support post_ops + const float nslope = 0.f; + + const int ndims = pd()->desc()->src_desc.ndims; + + auto ker = [=](int g, int mb, int oc, int od, int oh, + int ow) { + acc_data_t d = 0; + for (int ic = 0; ic < IC; ++ic) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + const int id = od * KSD - padFront + kd * (1 + KDD); + const int ih = oh * KSH - padT + kh * (1 + KDH); + const int iw = ow * KSW - padL + kw * (1 + KDW); + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + if (ndims == 5) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, id, ih, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kd, kh, kw)] + : weights[weights_d.off(oc, ic, kd, kh, kw)]); + else if (ndims == 4) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, ih, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kh, kw)] + : weights[weights_d.off(oc, ic, kh, kw)]); + else if (ndims == 3) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kw)] + : weights[weights_d.off(oc, ic, kw)]); + else + assert(false); + + } + return d; + }; + + parallel_nd(G, MB, OC, OD, OH, OW, + [&](int g, int mb, int oc, int od, int oh, int ow) { + float a = bias + ? get_bias(bias, bias_d.off(g * OC + oc), + pd()->desc()->bias_desc.data_type) + : 0; + a += ker(g, mb, oc, od, oh, ow); + if (with_relu && a < 0) + a = a * nslope; + if (ndims == 5) + dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = saturate(a); + else if (ndims == 4) + dst[dst_d.off(mb, g*OC + oc, oh, ow)] = saturate(a); + else if (ndims == 3) + dst[dst_d.off(mb, g*OC + oc, ow)] = saturate(a); + else + assert(false); + }); +} + +template +void ref_convolution_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const int ndims = pd()->desc()->diff_src_desc.ndims; + + auto ker = [=](int g, int mb, int ic, int id, int ih, + int iw) { + acc_data_t d = 0; + for (int oc = 0; oc < OC; ++oc) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + if (iw + padL < kw * (1 + KDW) + || ih + padT < kh * (1 + KDH) + || id + padFront < kd * (1 + KDD)) + continue; + int ow = iw - kw * (1 + KDW) + padL; + int oh = ih - kh * (1 + KDH) + padT; + int od = id - kd * (1 + KDD) + padFront; + if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0) + continue; + + ow /= KSW; + oh /= KSH; + od /= KSD; + + if (od < OD && oh < OH && ow < OW) { + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, od, oh, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kd, kh, kw)] + : weights[weights_d.off(oc, ic, kd, kh, kw)]); + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, oh, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kh, kw)] + : weights[weights_d.off(oc, ic, kh, kw)]); + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kw)] + : weights[weights_d.off(oc, ic, kw)]); + else + assert(false); + } + } + return d; + }; + + parallel_nd(G, MB, IC, ID, IH, IW, + [&](int g, int mb, int ic, int id, int ih, int iw) { + auto ds_idx = (ndims == 5) + ? diff_src_d.off(mb, g*IC + ic, id, ih, iw) + : (ndims == 4) + ? diff_src_d.off(mb, g*IC + ic, ih, iw) + : diff_src_d.off(mb, g*IC + ic, iw); + float a = bias + ? get_bias(bias, bias_d.off(g * IC + ic), + pd()->desc()->bias_desc.data_type) + : 0; + a += ker(g, mb, ic, id, ih, iw); + diff_src[ds_idx] = saturate(a); + }); +} + +template +void ref_convolution_bwd_weights_t::execute_backward_weights(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const int ndims = pd()->desc()->src_desc.ndims; + +auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) { + for (int mb = 0; mb < MB; ++mb) + for (int od = 0; od < OD; ++od) + for (int oh = 0; oh < OH; ++oh) + for (int ow = 0; ow < OW; ++ow) { + if (ow*KSW + kw * (1 + KDW) < padL + || oh*KSH + kh * (1 + KDH) < padT + || od*KSD + kd * (1 + KDD) < padFront + || ow*KSW + kw * (1 + KDW) >= IW + padL + || oh*KSH + kh * (1 + KDH) >= IH + padT + || od*KSD + kd * (1 + KDD) >= ID + padFront) + continue; + + int id = od*KSD - padFront + kd * (1 + KDD); + int ih = oh*KSH - padT + kh * (1 + KDH); + int iw = ow*KSW - padL + kw * (1 + KDW); + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, + oh, ow)] * src[src_d.off(mb, g*IC + ic, id, ih, iw)]; + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, ow)] + * src[src_d.off(mb, g*IC + ic, ih, iw)]; + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)] + * src[src_d.off(mb, g*IC + ic, iw)]; + else + assert(false); + } + }; + + auto ker_bias = [=](acc_data_t &d, int g, int oc) { + for (int mb = 0; mb < MB; ++mb) + for (int od = 0; od < OD; ++od) + for (int oh = 0; oh < OH; ++oh) + for (int ow = 0; ow < OW; ++ow) { + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, oh, + ow)]; + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, + ow)]; + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)]; + else + assert(false); + } + }; + + parallel_nd(G, OC, [&](int g, int oc) { + if (diff_bias) { + // XXX: loss of precision when bias is a float... + acc_data_t db = 0; + ker_bias(db, g, oc); + diff_bias[diff_bias_d.off(g*OC+oc)] + = saturate(db); + } + + for (int ic = 0; ic < IC; ++ic) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + acc_data_t dw = 0; + ker(dw, g, oc, ic, kd, kh, kw); + + if (ndims == 5) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kd, kh, kw) + : diff_weights_d.off(oc, ic, kd, kh, kw); + diff_weights[idx] = saturate(dw); + } else if (ndims == 4) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kh, kw) + : diff_weights_d.off(oc, ic, kh, kw); + diff_weights[idx] = saturate(dw); + } else if (ndims == 3) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kw) + : diff_weights_d.off(oc, ic, kw); + diff_weights[idx] = saturate(dw); + } else { + assert(false); + } + } + }); +} + +using namespace data_type; + +template struct ref_convolution_fwd_t; + +template struct ref_convolution_fwd_t; +template struct ref_convolution_fwd_t; +template struct ref_convolution_fwd_t; +template struct ref_convolution_fwd_t; + +template struct ref_convolution_bwd_data_t; + +template struct ref_convolution_bwd_data_t; +template struct ref_convolution_bwd_data_t; +template struct ref_convolution_bwd_data_t; +template struct ref_convolution_bwd_data_t; + +template struct ref_convolution_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp new file mode 100644 index 0000000000..7c83d0c6d4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp @@ -0,0 +1,194 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_CONVOLUTION_HPP +#define CPU_REF_CONVOLUTION_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, data_type::undef, + dst_type, acc_type) + && IMPLICATION(with_bias(), true + && IMPLICATION(src_type == u8, + utils::one_of(bias_md_.data_type, f32, s32, s8, u8)) + && IMPLICATION(src_type == f32, + bias_md_.data_type == f32)) + && set_default_formats() + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, data_type::undef, + diff_dst_type, acc_type) + && set_default_formats() + && attr()->has_default_values(); + + return ok ? status::success : status::unimplemented; + } + + virtual bool support_bias() const override { return true; } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type diff_src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + using cpu_convolution_bwd_weights_pd_t::cpu_convolution_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, diff_wei_type, diff_wei_type, + diff_dst_type, acc_type) + && set_default_formats() + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type diff_wei_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp new file mode 100644 index 0000000000..541a303aab --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp @@ -0,0 +1,199 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "math_utils.hpp" + +#include "ref_deconvolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +void ref_deconvolution_fwd_t::compute_fwd_bias(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int OD = pd()->OD(); + const int OC = pd()->OC() / G; + const int ndims = pd()->desc()->src_desc.ndims; + + parallel_nd(MB, G, OC, OD, OH, OW, + [&](int mb, int g, int oc, int od, int oh, int ow) { + auto b = bias[g * OC + oc]; + switch (ndims) { + case 5: dst[dst_d.off(mb, g * OC + oc, od, oh, ow)] += b; break; + case 4: dst[dst_d.off(mb, g * OC + oc, oh, ow)] += b; break; + case 3: dst[dst_d.off(mb, g * OC + oc, ow)] += b; break; + default: assert(!"invalid dimension size"); + } + }); +} + +void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int SP = pd()->OW()*pd()->OH()*pd()->OD(); + + parallel_nd(MB, OC, [&](int mb, int oc) { + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + auto offset = (size_t)(mb * OC + oc) * SP + sp; + dst[offset] += bias[oc]; + } + }); +} + +template +void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int SP = pd()->OW() * pd()->OH() * pd()->OD(); + + const ptrdiff_t stride_mb = dst_d.blocking_desc().strides[0]; + + parallel_nd(MB, utils::div_up(OC, blksize), SP, + [&](int mb, int oc_blk, int sp) { + int oc = oc_blk * blksize; + auto offset = mb * stride_mb + oc * SP + sp * blksize; + const int blk = nstl::min(blksize, OC - oc); + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blk; ++i) + dst[offset + i] += bias[oc + i]; + }); +} + +void ref_deconvolution_bwd_weights_t::compute_bwd_bias(const data_t *diff_dst, + data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int OC = pd()->OC() / G; + const int OD = pd()->OD(); + const int ndims = pd()->desc()->src_desc.ndims; + + parallel_nd(G, OC, [&](int g, int oc) { + data_t db = 0; + for (int mb = 0; mb < MB; ++mb) { + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + switch (ndims) { + case 5: + db += diff_dst[diff_dst_d.off( + mb, g * OC + oc, od, oh, ow)]; + break; + case 4: + db += diff_dst[diff_dst_d.off( + mb, g * OC + oc, oh, ow)]; + break; + case 3: + db += diff_dst[diff_dst_d.off(mb, g * OC + oc, ow)]; + break; + default: assert(!"invalid dimension size"); + } + } + } + } + } + diff_bias[g * OC + oc] = db; + }); +} + +void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw( + const data_t *diff_dst, data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int OC = pd()->OC(); + const int MB = pd()->MB(); + const int SP = pd()->OH()*pd()->OW()*pd()->OD(); + + parallel_nd(OC, [&](int oc) { + data_t db = 0; + for (int mb = 0; mb < MB; ++mb) { + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + auto offset = (size_t)(mb * OC + oc) * SP + sp; + db += diff_dst[offset]; + } + } + diff_bias[oc] = db; + }); +} + +template +void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc( + const data_t *diff_dst, data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int OC = pd()->OC(); + const int MB = pd()->MB(); + const int SP = pd()->OH() * pd()->OW() * pd()->OD(); + + const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0]; + + parallel_nd(utils::div_up(OC, blksize), [&](int ocb) { + data_t db[blksize] = {0}; + + for (int mb = 0; mb < MB; ++mb) { + for (int sp = 0; sp < SP; ++sp) { + auto offset = mb * stride_mb + (ocb * SP + sp) * blksize; + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blksize; ++i) + db[i] += diff_dst[offset+i]; + } + } + + const int blk = nstl::min(blksize, OC - ocb * blksize); + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blk; ++i) + diff_bias[ocb * blksize + i] = db[i]; + }); +} + +template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<8>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<16>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<8>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<16>( + const data_t *diff_dst, data_t *diff_bias) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp new file mode 100644 index 0000000000..d61903c32d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp @@ -0,0 +1,502 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_DECONVOLUTION_HPP +#define CPU_REF_DECONVOLUTION_HPP + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "primitive_iterator.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_deconvolution_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +static status_t compute_blocked_format(bool with_groups, + const memory_desc_t *oi_md, memory_desc_t *io_md) +{ + /* Computes blocking for *i*o* format from *o*i* format */ + + bool sanity_check_ok = true + && oi_md->ndims == io_md->ndims + && oi_md->format_kind == format_kind::blocked; + if (!sanity_check_ok) return status::invalid_arguments; + + const blocking_desc_t &oi_blk = oi_md->format_desc.blocking; + blocking_desc_t io_blk = io_md->format_desc.blocking; + + io_md->format_kind = format_kind::blocked; + io_blk = oi_blk; + + const int ID_OC = 0 + with_groups; + const int ID_IC = 1 + with_groups; + + nstl::swap(io_blk.strides[ID_OC], io_blk.strides[ID_IC]); + for (int i_blk = 0; i_blk < io_blk.inner_nblks; ++i_blk) { + if (utils::one_of(io_blk.inner_idxs[i_blk], ID_OC, ID_IC)) { + io_blk.inner_idxs[i_blk] = + (io_blk.inner_idxs[i_blk] == ID_OC ? ID_IC : ID_OC); + } + } + + return memory_desc_init_by_blocking_desc(*io_md, io_blk); +} + +static status_t conv_descr_create(const deconvolution_desc_t *dd, + convolution_desc_t *cd) +{ + using namespace prop_kind; + alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct + ? alg_kind::convolution_direct : alg_kind::convolution_winograd; + + const memory_desc_t *src_md, *dst_md, *d_weights_d; + prop_kind_t prop_kind; + memory_desc_t c_weights_d; + if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) { + prop_kind = backward_data; + src_md = &dd->dst_desc; + dst_md = &dd->src_desc; + d_weights_d = &dd->weights_desc; + } else if (dd->prop_kind == backward_data) { + prop_kind = forward_training; + src_md = &dd->diff_dst_desc; + dst_md = &dd->diff_src_desc; + d_weights_d = &dd->weights_desc; + } else { + prop_kind = dd->prop_kind; + src_md = &dd->diff_dst_desc; + dst_md = &dd->src_desc; + d_weights_d = &dd->diff_weights_desc; + } + + const bool with_groups = d_weights_d->ndims == src_md->ndims + 1; + + /* create weights desc for convolution */ + c_weights_d = *d_weights_d; + + const int ID_OC = 0 + with_groups; + const int ID_IC = 1 + with_groups; + + nstl::swap(c_weights_d.dims[ID_OC], c_weights_d.dims[ID_IC]); + nstl::swap(c_weights_d.padded_dims[ID_OC], c_weights_d.padded_dims[ID_IC]); + nstl::swap(c_weights_d.padded_offsets[ID_OC], c_weights_d.padded_offsets[ID_IC]); + + if (c_weights_d.format_kind != format_kind::any) + CHECK(compute_blocked_format(with_groups, d_weights_d, &c_weights_d)); + + return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d, + prop_kind != backward_weights ? &dd->bias_desc : nullptr, + dst_md, dd->strides, dd->dilates, + dd->padding[0], dd->padding[1], dd->padding_kind); +} + +struct ref_deconvolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_fwd_pd_t { + pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_fwd_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + , conv_supports_bias_(other.conv_supports_bias_) + , dst_tag_(other.dst_tag_) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + CHECK(conv_descr_create(desc(), &cd)); + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + conv_supports_bias_ = + static_cast(conv_pd_) + ->support_bias(); + bool output_f32 = utils::everyone_is(data_type::f32, + desc()->accum_data_type, desc()->dst_desc.data_type); + + bool ok = true + && conv_pd_->weights_md()->extra.flags == 0 + /* deconv reference code can process only f32 bias */ + && IMPLICATION(with_bias(), + conv_supports_bias_ || output_f32); + if (ok) return status::success; + + delete conv_pd_; + } + conv_pd_ = nullptr; + return status::unimplemented; + } + + status_t init() { + using namespace format_tag; + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd) + && attr()->post_ops_.has_default_values(); + + if (ok) { + CHECK(init_convolution()); + if (weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->weights_md(), &desc_.weights_desc)); + weights_md_ = desc_.weights_desc; + } + if (src_md_.format_kind == format_kind::any) + src_md_ = *conv_pd_->diff_dst_md(); + if (dst_md_.format_kind == format_kind::any) + dst_md_ = *conv_pd_->diff_src_md(); + if (bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, x)); + + dst_tag_ = memory_desc_matches_one_of_tag(dst_md_, + utils::pick(ndims() - 3, ncw, nchw, ncdhw), + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), + utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + bool conv_supports_bias_; + format_tag_t dst_tag_; + }; + + typedef typename prec_traits::type data_t; + + ref_deconvolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_fwd_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); + conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); + if (pd()->with_bias() && pd()->conv_supports_bias_) + conv_args[MKLDNN_ARG_BIAS] = args.at(MKLDNN_ARG_BIAS); + conv_args[MKLDNN_ARG_DIFF_SRC] = args.at(MKLDNN_ARG_DST); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + conv_p_->execute(conv_ctx); + + if (pd()->with_bias() && !pd()->conv_supports_bias_) { + using namespace format_tag; + + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + switch (pd()->dst_tag_) { + case ncdhw: case nchw: case ncw: + compute_fwd_bias_ncdhw(bias, dst); + break; + case nCdhw8c: case nChw8c: case nCw8c: + compute_fwd_bias_nCdhwXc<8>(bias, dst); + break; + case nCdhw16c: case nChw16c: case nCw16c: + compute_fwd_bias_nCdhwXc<16>(bias, dst); + break; + default: + compute_fwd_bias(bias, dst); + break; + } + } + return status::success; + } + +private: + void compute_fwd_bias(const data_t *bias, data_t *dst) const; + void compute_fwd_bias_ncdhw(const data_t *bias, data_t *dst) const; + template void compute_fwd_bias_nCdhwXc(const data_t *bias, + data_t *dst) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +struct ref_deconvolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_bwd_data_pd_t { + pd_t(engine_t *engine, const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_bwd_data_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + status_t status = conv_descr_create(desc(), &cd); + if (status != status::success) return status; + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + if (conv_pd_->weights_md()->extra.flags == 0) + return status::success; + delete conv_pd_; + } + + return status::unimplemented; + } + + status_t init() { + using namespace data_type; + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && utils::everyone_is(data_type::f32, + desc()->diff_src_desc.data_type, + desc()->weights_desc.data_type, + desc()->diff_dst_desc.data_type) + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd); + + if (ok) { + CHECK(init_convolution()); + if (weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->weights_md(), &desc_.weights_desc)); + weights_md_ = desc_.weights_desc; + } + if (diff_src_md_.format_kind == format_kind::any) + diff_src_md_ = *conv_pd_->dst_md(); + if (diff_dst_md_.format_kind == format_kind::any) + diff_dst_md_ = *conv_pd_->src_md(); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + }; + + typedef typename prec_traits::type data_t; + + ref_deconvolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_bwd_data_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); + conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); + conv_args[MKLDNN_ARG_DST] = args.at(MKLDNN_ARG_DIFF_SRC); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + conv_p_->execute(conv_ctx); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_bwd_weights_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + , dst_tag_(other.dst_tag_) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + status_t status = conv_descr_create(desc(), &cd); + if (status != status::success) return status; + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + if (conv_pd_->diff_weights_md()->extra.flags == 0) + return status::success; + delete conv_pd_; + } + return status::unimplemented; + } + + status_t init() { + using namespace format_tag; + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && utils::everyone_is(data_type::f32, + desc()->src_desc.data_type, + desc()->diff_weights_desc.data_type, + desc()->diff_dst_desc.data_type) + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd) + && attr()->has_default_values(); + if (ok) { + CHECK(init_convolution()); + if (diff_weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->diff_weights_md(), + &desc_.diff_weights_desc)); + diff_weights_md_ = desc_.diff_weights_desc; + } + if (src_md_.format_kind == format_kind::any) + src_md_ = *conv_pd_->diff_dst_md(); + if (diff_dst_md_.format_kind == format_kind::any) + diff_dst_md_ = *conv_pd_->src_md(); + if (diff_bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md_, x)); + + dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_, + utils::pick(ndims() - 3, ncw, nchw, ncdhw), + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), + utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + format_tag_t dst_tag_; + }; + + typedef typename prec_traits::type data_t; + + ref_deconvolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_bwd_weights_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); + conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); + conv_args[MKLDNN_ARG_DIFF_WEIGHTS] = args.at(MKLDNN_ARG_DIFF_WEIGHTS); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + status_t status = conv_p_->execute(conv_ctx); + if (status != status::success) return status; + + if (pd()->with_bias()) { + using namespace format_tag; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + switch (pd()->dst_tag_) { + case ncdhw: case nchw: case ncw: + compute_bwd_bias_ncdhw(diff_dst, diff_bias); + break; + case nCdhw8c: case nChw8c: case nCw8c: + compute_bwd_bias_nCdhwXc<8>(diff_dst, diff_bias); + break; + case nCdhw16c: case nChw16c: case nCw16c: + compute_bwd_bias_nCdhwXc<16>(diff_dst, diff_bias); + break; + default: + compute_bwd_bias(diff_dst, diff_bias); + break; + } + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + void compute_bwd_bias(const data_t *diff_dst, data_t *diff_bias) const; + void compute_bwd_bias_ncdhw(const data_t *diff_dst, + data_t *diff_bias) const; + template void compute_bwd_bias_nCdhwXc( + const data_t *diff_dst, data_t *diff_bias) const; + + primitive_t *conv_p_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp new file mode 100644 index 0000000000..7beee8d323 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp @@ -0,0 +1,297 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace alg_kind; +using namespace math; + +ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, + float beta): alg_(alg), alpha_(alpha), beta_(beta) { + assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); +} + +ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( + const post_ops_t::entry_t::eltwise_t &eltwise) + : ref_eltwise_scalar_fwd_t(eltwise.alg, eltwise.alpha, eltwise.beta) {} + +float ref_eltwise_scalar_fwd_t::compute_scalar(float s) { + switch (alg_) { + case eltwise_relu: return relu_fwd(s, alpha_); + case eltwise_tanh: return tanh_fwd(s); + case eltwise_elu: return elu_fwd(s, alpha_); + case eltwise_square: return square_fwd(s); + case eltwise_abs: return abs_fwd(s); + case eltwise_sqrt: return sqrt_fwd(s); + case eltwise_linear: return linear_fwd(s, alpha_, beta_); + case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha_); + case eltwise_soft_relu: return soft_relu_fwd(s); + case eltwise_logistic: return logistic_fwd(s); + default: assert(!"unknown eltwise alg_kind"); + } + + return 0.f; +} + +template +void ref_eltwise_fwd_t::execute_forward_nCspBc_padded( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + const blocking_desc_t &blk = data_d.blocking_desc(); + const int block = blk.inner_blks[0]; + + const int MB = pd()->MB(); + const int C = pd()->C() / block; + const int C_PADDED = data_d.padded_dims()[1] / block; + const int tail = pd()->C() % block; + const int SP = pd()->D() * pd()->H() * pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + auto ker = [=] (data_t &d, data_t s) { + switch (alg_kind) { + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: + d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }; + + // FIXME: integer overflow? + + parallel_nd(MB, C_PADDED, SP, + [&](int n, int c, int sp) { + auto d_off = (n*C_PADDED*SP + c*SP + sp) * block; + if (c < C) { + for (int v = 0; v < block; v++) + ker(dst[d_off + v], src[d_off + v]); + } else { + for (int v = 0; v < tail; v++) + ker(dst[d_off + v], src[d_off + v]); + } + }); +} + +template +void ref_eltwise_fwd_t::execute_forward_generic( + const exec_ctx_t &ctx) const { + /* fast return */ + if (pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int D = pd()->D(); + const int H = pd()->H(); + const int W = pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + const bool is_3d = pd()->desc()->data_desc.ndims == 5; + + parallel_nd(MB, C, D, H, W, + [&](int n, int c, int id, int h, int w) { + auto d_off = is_3d + ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w); + data_t s = src[d_off]; + data_t &d = dst[d_off]; + switch (alg_kind) { + case eltwise_relu: d = relu_fwd(s, alpha); break; + case eltwise_tanh: d = tanh_fwd(s); break; + case eltwise_elu: d = elu_fwd(s, alpha); break; + case eltwise_square: d = square_fwd(s); break; + case eltwise_abs: d = abs_fwd(s); break; + case eltwise_sqrt: d = sqrt_fwd(s); break; + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: + d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template +void ref_eltwise_fwd_t::execute_forward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const ptrdiff_t nelems = static_cast(data_d.nelems(true)); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + src += data_d.offset0(); + dst += data_d.offset0(); + + if (alg_kind == eltwise_relu) { + // a fast path for relu as the most popular activation + parallel_nd(nelems, [&](ptrdiff_t e) { + dst[e] = relu_fwd(src[e], alpha); + }); + return; + } + + parallel_nd(nelems, [&](ptrdiff_t e) { + const data_t s = src[e]; + data_t &d = dst[e]; + + switch (alg_kind) { + case eltwise_tanh: d = tanh_fwd(s); break; + case eltwise_elu: d = elu_fwd(s, alpha); break; + case eltwise_square: d = square_fwd(s); break; + case eltwise_abs: d = abs_fwd(s); break; + case eltwise_sqrt: d = sqrt_fwd(s); break; + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template +void ref_eltwise_bwd_t::execute_backward_generic( + const exec_ctx_t &ctx) const { + /* fast return */ + if (pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int D = pd()->D(); + const int H = pd()->H(); + const int W = pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + const bool is_3d = pd()->desc()->data_desc.ndims == 5; + + parallel_nd(MB, C, D, H, W, + [&](int n, int c, int d, int h, int w) { + auto data_off = is_3d + ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w); + auto diff_data_off = is_3d + ? diff_data_d.off(n, c, d, h, w) + : diff_data_d.off(n, c, h, w); + data_t s = src[data_off]; + data_t dd = diff_dst[diff_data_off]; + data_t &ds = diff_src[diff_data_off]; + switch (alg_kind) { + case eltwise_relu: ds = relu_bwd(dd, s, alpha); break; + case eltwise_tanh: ds = tanh_bwd(dd, s); break; + case eltwise_elu: ds = elu_bwd(dd, s, alpha); break; + case eltwise_square: ds = square_bwd(dd, s); break; + case eltwise_abs: ds = abs_bwd(dd, s); break; + case eltwise_sqrt: ds = sqrt_bwd(dd, s); break; + case eltwise_linear: + ds = linear_bwd(dd, s, alpha, beta); break; + case eltwise_bounded_relu: + ds = bounded_relu_bwd(dd, s, alpha); break; + case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break; + case eltwise_logistic: ds = logistic_bwd(dd, s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template +void ref_eltwise_bwd_t::execute_backward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + + const ptrdiff_t nelems = static_cast(data_d.nelems(true)); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + src += data_d.offset0(); + diff_dst += diff_data_d.offset0(); + diff_src += diff_data_d.offset0(); + + parallel_nd(nelems, [&](ptrdiff_t e) { + const data_t dd = diff_dst[e]; + const data_t s = src[e]; + data_t &ds = diff_src[e]; + + switch (alg_kind) { + case eltwise_relu: ds = relu_bwd(dd, s, alpha); break; + case eltwise_tanh: ds = tanh_bwd(dd, s); break; + case eltwise_elu: ds = elu_bwd(dd, s, alpha); break; + case eltwise_square: ds = square_bwd(dd, s); break; + case eltwise_abs: ds = abs_bwd(dd, s); break; + case eltwise_sqrt: ds = sqrt_bwd(dd, s); break; + case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break; + case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break; + case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break; + case eltwise_logistic: ds = logistic_bwd(dd, s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template struct ref_eltwise_fwd_t; +template struct ref_eltwise_fwd_t; +template struct ref_eltwise_fwd_t; +template struct ref_eltwise_fwd_t; + +template struct ref_eltwise_bwd_t; +template struct ref_eltwise_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp new file mode 100644 index 0000000000..8f4ab35413 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp @@ -0,0 +1,168 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_ELTWISE_HPP +#define CPU_REF_ELTWISE_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_eltwise_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ref_eltwise_scalar_fwd_t { +public: + ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, float beta); + + // note that eltwise.scale is ignored + ref_eltwise_scalar_fwd_t(const post_ops_t::entry_t::eltwise_t &eltwise); + + float compute_scalar(float s); + + const alg_kind_t alg_; + const float alpha_; + const float beta_; +}; + +template +struct ref_eltwise_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_eltwise_fwd_pd_t { + using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_eltwise_fwd_t); + + status_t init() { + using namespace utils; + + auto src_d = memory_desc_wrapper(src_md()); + + use_dense_ = false + || src_d.is_dense() + || (src_d.is_dense(true) && is_zero_preserved()); + + use_nCspBc_padded_ = !use_dense_ + && src_d.blocking_desc().inner_nblks == 1 + && one_of(src_d.blocking_desc().inner_blks[0], 8, 16) + && src_d.blocking_desc().inner_idxs[0] == 1 + && src_d.only_padded_dim(1) + && src_d.is_dense(true); + + if (has_zero_dim_memory()) + use_dense_ = use_nCspBc_padded_ = false; + + const bool use_generic = !use_dense_ && !use_nCspBc_padded_; + + bool ok = true + && is_fwd() + && everyone_is(data_type, desc()->data_desc.data_type) + && IMPLICATION(use_generic, one_of(src_d.ndims(), 4, 5)) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + return status::success; + } + + bool use_dense_, use_nCspBc_padded_; + }; + + ref_eltwise_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->use_dense_) + execute_forward_dense(ctx); + else if (pd()->use_nCspBc_padded_) + execute_forward_nCspBc_padded(ctx); + else + execute_forward_generic(ctx); + return status::success; + } + +private: + void execute_forward_nCspBc_padded(const exec_ctx_t &ctx) const; + void execute_forward_dense(const exec_ctx_t &ctx) const; + void execute_forward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_eltwise_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_eltwise_bwd_pd_t { + using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_eltwise_bwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && !is_fwd() + && everyone_is(data_type, + desc()->data_desc.data_type, + desc()->diff_data_desc.data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + auto diff_dst_d = memory_desc_wrapper(diff_dst_md()); + const bool same_fmt_ = diff_dst_d == memory_desc_wrapper(src_md()); + + use_dense_ = true + && same_fmt_ + && diff_dst_d.is_dense(true) + && is_zero_preserved() + && !has_zero_dim_memory(); + const bool use_generic = !use_dense_; + + if (use_generic && !one_of(diff_dst_d.ndims(), 4, 5)) + return status::unimplemented; + + return status::success; + } + + bool use_dense_; + }; + + ref_eltwise_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->use_dense_) + execute_backward_dense(ctx); + else + execute_backward_generic(ctx); + return status::success; + } + +private: + void execute_backward_dense(const exec_ctx_t &ctx) const; + void execute_backward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp new file mode 100644 index 0000000000..c807a9ffd0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp @@ -0,0 +1,285 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "math_utils.hpp" + +#include "ref_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using math::saturate; +using math::get_bias; + +template +void ref_inner_product_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4, 5); + const int ndims = src_d.ndims() - 2; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f; + + auto ker_has_spatial = [=](int mb, int oc) { + acc_data_t d = 0; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int ic = 0; ic < IC; ++ic) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + switch (ndims) { + case 3: + d += (acc_data_t)src[src_d.off(mb, ic, kd, kh, kw)] + * weights[weights_d.off( + oc, ic, kd, kh, kw)]; + break; + case 2: + d += (acc_data_t)src[src_d.off(mb, ic, kh, kw)] + * weights[weights_d.off(oc, ic, kh, kw)]; + break; + case 1: + d += (acc_data_t)src[src_d.off(mb, ic, kw)] + * weights[weights_d.off(oc, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + } + } + } + } + return d; + }; + + auto ker_no_spatial = [=](int mb, int oc) { + acc_data_t d = 0; + for (int ic = 0; ic < IC; ++ic) { + d += (acc_data_t)src[src_d.off(mb, ic)] + * weights[weights_d.off(oc, ic)]; + } + return d; + }; + + parallel_nd(MB, OC, [&](int mb, int oc) { + float a = bias + ? get_bias(bias, bias_d.off(oc), pd()->desc()->bias_desc.data_type) + : 0; + if (src_has_spatial) + a += ker_has_spatial(mb, oc); + else + a += ker_no_spatial(mb, oc); + if (do_relu && a < (acc_data_t)0) + a *= nslope; + dst[dst_d.off(mb, oc)] = saturate(a); + }); +} + +using namespace data_type; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; + +template +void ref_inner_product_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool diff_src_has_spatial + = utils::one_of(diff_src_d.ndims(), 3, 4, 5); + const int ndims = diff_src_d.ndims() - 2; + + parallel_nd(MB, IC, [&](int mb, int ic) { + if (diff_src_has_spatial) { + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + acc_data_t ds = acc_data_t(0); + for (int oc = 0; oc < OC; ++oc) { + switch (ndims) { + case 3: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kd, kh, kw)]); + break; + case 2: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kh, kw)]); + break; + case 1: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kw)]); + break; + default: assert(!"unsupported ndims size"); + } + } + switch (ndims) { + case 3: + diff_src[diff_src_d.off(mb, ic, kd, kh, kw)] + = (diff_src_data_t)ds; + break; + case 2: + diff_src[diff_src_d.off(mb, ic, kh, kw)] + = (diff_src_data_t)ds; + break; + case 1: + diff_src[diff_src_d.off(mb, ic, kw)] = (diff_src_data_t)ds; + break; + default: assert(!"unsupported ndims size"); + } + } + } else { + acc_data_t ds = acc_data_t(0); + for (int oc = 0; oc < OC; ++oc) { + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] * + weights[weights_d.off(oc, ic)]); + } + diff_src[diff_src_d.off(mb, ic)] = (diff_src_data_t)ds; + } + }); +} + +template struct ref_inner_product_bwd_data_t; + +template +void ref_inner_product_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4 ,5); + const int ndims = src_d.ndims() - 2; + + parallel_nd(OC, IC, [&](int oc, int ic) { + if (src_has_spatial) { + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + data_t *dw(nullptr); + switch (ndims) { + case 3: + dw = &diff_weights[diff_weights_d.off( + oc, ic, kd, kh, kw)]; + break; + case 2: + dw = &diff_weights[diff_weights_d.off( + oc, ic, kh, kw)]; + break; + case 1: + dw = &diff_weights[diff_weights_d.off(oc, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + *dw = data_t(0); + for (int mb = 0; mb < MB; ++mb) { + switch (ndims) { + case 3: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kd, kh, kw)]; + break; + case 2: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kh, kw)]; + break; + case 1: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + } + } + } + } + } else { + data_t *dw = &diff_weights[diff_weights_d.off(oc, ic)]; + *dw = data_t(0); + for (int mb = 0; mb < MB; ++mb) { + *dw += diff_dst[diff_dst_d.off(mb, oc)] * + src[src_d.off(mb, ic)]; + } + } + }); + + if (diff_bias) { + diff_bias += diff_bias_d.offset0(); + + parallel_nd(OC, [&](int oc) { + data_t *db = &diff_bias[oc]; + *db = data_t(0); + for (int mb = 0; mb < MB; ++mb) + *db += diff_dst[diff_dst_d.off(mb, oc)]; + }); + } +} + +template struct ref_inner_product_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp new file mode 100644 index 0000000000..bf87dbd514 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp @@ -0,0 +1,159 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_INNER_PRODUCT_HPP +#define CPU_REF_INNER_PRODUCT_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && src_md()->data_type == src_type + && weights_md()->data_type == wei_type + && desc()->accum_data_type == acc_type + && dst_md()->data_type == dst_type + && IMPLICATION(with_bias(), utils::one_of( + weights_md(1)->data_type, f32, s32, s8, u8)) + && attr()->output_scales_.has_default_values() + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_ == 1, + attr()->post_ops_.entry_[0].is_relu(true, false)); + return ok ? status::success : status::unimplemented; + } + }; + + ref_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_inner_product_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_data_pd_t { + using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_data_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_data + && diff_src_md()->data_type == diff_src_type + && weights_md()->data_type == wei_type + && desc()->accum_data_type == acc_type + && diff_dst_md()->data_type == diff_dst_type + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + }; + + ref_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type diff_src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_inner_product_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_weights_pd_t { + using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_weights_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_weights + && utils::everyone_is(data_type, + src_md()->data_type, + diff_dst_md()->data_type, + diff_weights_md()->data_type) + && IMPLICATION(with_bias(), + data_type == diff_weights_md(1)->data_type) + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + }; + + ref_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp new file mode 100644 index 0000000000..325e97963b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp @@ -0,0 +1,252 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_lrn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +static inline float fast_negative_powf(float omega, float beta) { + float Y; +/* + * Y = omega^(-3/4) = + * = 1.0f / sqrtf(omega) * sqrtf(1.0f / sqrtf(omega)) + * = sqrtf(1.0f / sqrtf(omega)) * 1.0f / sqrtf(omega) + * = sqrtf(1.0f / sqrtf(omega)) / sqrtf(omega) + * = sqrtf(1.0f / sqrtf(omega) / omega) + * = sqrtf(1.0f / (sqrtf(omega) * omega)) + */ + if (beta == 0.75f) { + Y = sqrtf(1.0f / (sqrtf(omega) * omega)); + } else { + Y = 1.0f / powf(omega, beta); + } + return Y; +}; + +template +template +void ref_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace format_tag; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const size_t stride_mb = data_d.blocking_desc().strides[0]; + const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels; + constexpr int blksize = tag == nChw16c ? 16 : 8; + + auto data_off = [&](int mb, int c, int h, int w) -> size_t { + switch (tag) { + case nChw16c: + case nChw8c: return mb * stride_mb + c / blksize * H * W * blksize + + h * W * blksize + w * blksize + c % blksize; + case nchw: return mb * stride_mb + c * H * W + h * W + w; + case nhwc: return mb * stride_mb + h * W * C + w * C + c; + default: return data_d.off(mb, c, h, w); + } + }; + + auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) { + const float alpha = static_cast(pd()->desc()->lrn_alpha); + const float beta = static_cast(pd()->desc()->lrn_beta); + const float k = static_cast(pd()->desc()->lrn_k); + + const int size = pd()->desc()->local_size; + const int half_size = (size - 1) / 2; + + float sum = 0; + if (across_channels) { + const int c_st = nstl::max(oc - half_size + 0, 0); + const int c_en = nstl::min(oc + half_size + 1, C); + + for (int c = c_st; c < c_en; ++c) { + const float s = src[data_off(mb, c, oh, ow)]; + sum += s * s; + } + } else { + int h_st = nstl::max(oh - half_size + 0, 0); + int h_en = nstl::min(oh + half_size + 1, H); + int w_st = nstl::max(ow - half_size + 0, 0); + int w_en = nstl::min(ow + half_size + 1, W); + for (int h = h_st; h < h_en; ++h) { + for (int w = w_st; w < w_en; ++w) { + const float s = src[data_off(mb, oc, h, w)]; + sum += s * s; + } + } + } + const int summands = across_channels ? size : size * size; + sum = k + alpha * sum / summands; + size_t off = data_off(mb, oc, oh, ow); + d[0] = static_cast(src[off] * fast_negative_powf(sum, beta)); + }; + + const int MB = pd()->MB(); + if (tag == nChw16c || tag == nChw8c) { + parallel_nd(MB, utils::div_up(C, blksize), H, W, + [&](int mb, int c_blk, int h, int w) { + int c = c_blk * blksize; + const size_t off = mb * stride_mb + c * H * W + + (h * W + w) * blksize; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc) + ker(&dst[off + cc], mb, c + cc, h, w); + }); + } else if (tag == nhwc) { + parallel_nd(MB, H, W, C, + [&](int mb, int h, int w, int c) { + const size_t off = mb * stride_mb + h * W * C + w * C + c; + ker(&dst[off], mb, c, h, w); + }); + } else { + parallel_nd(MB, C, H, W, + [&](int mb, int c, int h, int w) { + const size_t off = data_off(mb, c, h, w); + ker(&dst[off], mb, c, h, w); + }); + } +} + +template +template +void ref_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace format_tag; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const size_t stride_mb = data_d.blocking_desc().strides[0]; + constexpr int blksize = tag == nChw16c ? 16 : 8; + + const float alpha = static_cast(pd()->desc()->lrn_alpha); + const float beta = static_cast(pd()->desc()->lrn_beta); + const float k = static_cast(pd()->desc()->lrn_k); + const int kernel_size = pd()->desc()->local_size; + const int half_ksize = (kernel_size - 1) / 2; + + auto data_off = [&](int mb, int c, int h, int w) -> size_t { + switch (tag) { + case nChw16c: + case nChw8c: return mb * stride_mb + c/blksize * H * W * blksize + + h * W * blksize + w * blksize + c%blksize; + case nchw: return mb * stride_mb + c * H * W + h * W + w; + case nhwc: return mb * stride_mb + h * W * C + w * C + c; + default: return data_d.off(mb, c, h, w); + } + }; + + auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) { + const int c_st = nstl::max(oc - half_ksize + 0, 0); + const int c_en = nstl::min(oc + half_ksize + 1, C); + + float A = 0, B = 0, omega_mid = 0; + for (int c = c_st; c < c_en; c++) { + float sum = 0.0; + const int i_st = nstl::max(c - half_ksize, 0); + const int i_en = nstl::min(c + kernel_size - half_ksize, C); + + for (int i = i_st; i < i_en; ++i) { + const float value = src[data_off(mb, i, oh, ow)]; + sum += value * value; + } + const float omega = static_cast(k + sum * alpha / kernel_size); + if (c == oc) omega_mid = omega; + float t = src[data_off(mb, c, oh, ow)] + * fast_negative_powf(omega, beta); + B += 1.0f / omega * t * diff_dst[data_off(mb, c, oh, ow)]; + } + + const size_t off = data_off(mb, oc, oh, ow); + A = fast_negative_powf(omega_mid, beta) * diff_dst[off]; + B *= src[off]; + B *= (2.0f * alpha * beta) / kernel_size; + *d = static_cast(A - B); // final cast down to data_t + }; + + if (tag == nChw16c || tag == nChw8c) { + parallel_nd(MB, utils::div_up(C, blksize), H, W, + [&](int mb, int c_blk, int h, int w) { + int c = c_blk * blksize; + const size_t off = mb * stride_mb + c * H * W + + (h * W + w) * blksize; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc) + ker(&diff_src[off + cc], mb, c + cc, h, w); + }); + } else if (tag == nhwc) { + parallel_nd(MB, H, W, C, + [&](int mb, int h, int w, int c) { + const size_t off = mb * stride_mb + h * W * C + w * C + c; + ker(&diff_src[off], mb, c, h, w); + }); + } else { + parallel_nd(MB, C, H, W, + [&](int mb, int c, int h, int w) { + const size_t off = data_off(mb, c, h, w); + ker(&diff_src[off], mb, c, h, w); + }); + } +} + +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp new file mode 100644 index 0000000000..f25cfb7fae --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp @@ -0,0 +1,136 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_LRN_HPP +#define CPU_REF_LRN_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_lrn_fwd_t); + + status_t init() { + using namespace format_tag; + + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag( + *src_md(), nChw16c, nChw8c, nchw, nhwc); + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_lrn_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nChw16c: execute_forward(ctx); break; + case nChw8c: execute_forward(ctx); break; + case nchw: execute_forward(ctx); break; + case nhwc: execute_forward(ctx); break; + default: execute_forward(ctx); + } + return status::success; + } + +private: + template + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_lrn_bwd_t); + + status_t init() { + using namespace format_tag; + using namespace alg_kind; + + bool ok = true + && !is_fwd() + && utils::one_of(desc()->alg_kind, lrn_across_channels + /*, lrn_within_channel */) // not supported yet + && utils::everyone_is(data_type, + src_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag( + *src_md(), nChw16c, nChw8c, nchw, nhwc); + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_lrn_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nChw16c: execute_backward(ctx); break; + case nChw8c: execute_backward(ctx); break; + case nchw: execute_backward(ctx); break; + case nhwc: execute_backward(ctx); break; + default: execute_backward(ctx); + } + return status::success; + } + +private: + template + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp new file mode 100644 index 0000000000..65b934e123 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp @@ -0,0 +1,381 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "ref_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void ref_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace prop_kind; + + auto alg = pd()->desc()->alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper ws_d(pd()->workspace_md()); + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->src_desc.ndims == 5; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto set_ws = [=](int mb, int oc, int od, int oh, int ow, int value) { + if (ws) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + size_t offset = is_3d + ? ws_d.off(mb, oc, od, oh, ow) : ws_d.off(mb, oc, oh, ow);; + if (ws_dt == data_type::u8) { + assert(0 <= value && value <= 255); + ws[offset] = value; + } else + reinterpret_cast(ws)[offset] = value; + } + }; + + auto ker_max = [=](data_t *d, int mb, int oc, int oh, int ow) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto s = src[src_d.off(mb, oc, ih, iw)]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, oc, 1, oh, ow, kh*KW + kw); + } + } + } + }; + + auto ker_avg = [=](data_t *d, int mb, int oc, int oh, int ow) { + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH + : (ih_end - ih_start)*(iw_end - iw_start); + + acc_data_t dst = 0; + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + dst += src[src_d.off(mb, oc, ih, iw)]; + } + } + + d[0] = math::out_round((float)dst / num_summands); + }; + + auto ker_max_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto s = src[src_d.off(mb, oc, id, ih, iw)]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh*KW + kw); + } + } + } + } + }; + + auto ker_avg_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + acc_data_t dst = 0; + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + dst += src[src_d.off(mb, oc, id, ih, iw)]; + } + } + } + + d[0] = math::out_round((float)dst / num_summands); + }; + + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + if (alg == pooling_max) { + parallel_nd(MB, OC, OD, OH, OW, + [&](int mb, int oc, int od, int oh, int ow) { + data_t *d = is_3d + ? &dst[dst_d.off(mb, oc, od, oh, ow)] + : &dst[dst_d.off(mb, oc, oh, ow)]; + d[0] = nstl::numeric_limits::lowest(); + set_ws(mb, oc, od, oh, ow, 0); + if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow); + else ker_max(d, mb, oc, oh, ow); + }); + } else { + parallel_nd(MB, OC, OD, OH, OW, + [&](int mb, int oc, int od, int oh, int ow) { + data_t *d = is_3d + ? &dst[dst_d.off(mb, oc, od, oh, ow)] + : &dst[dst_d.off(mb, oc, oh, ow)]; + d[0] = 0; + if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow); + else ker_avg(d, mb, oc, oh, ow); + }); + } +} + +template +void ref_pooling_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper ws_d(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto ker_zero = [=](int _mb, int _oc) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_d.off(_mb, _oc, ih, iw)] = data_type_t(0); + } + } + }; + + auto ker_max = [=](const data_t *d, int mb, int oc, int oh, int ow) { + const size_t ws_off = ws_d.off(mb, oc, oh, ow); + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_off] : ((int *)ws)[ws_off]; + const int kw = index % KW; + const int kh = index / KW; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0]; + }; + + auto ker_avg = [=](const data_t *d, int mb, int oc, int oh, int ow) { + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH + : (ih_end - ih_start)*(iw_end - iw_start); + + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0] / num_summands; + } + } + }; + + auto ker_zero_3d = [=](int _mb, int _oc) { + for (int id = 0; id < ID; ++id) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_d.off(_mb, _oc, id, ih, iw)] = + data_type_t(0); + } + } + } + }; + + auto ker_max_3d = [=](const data_t *d, int mb, int oc, int od, int oh, + int ow) { + const size_t ws_off = ws_d.off(mb, oc, od, oh, ow); + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_off] : ((int *)ws)[ws_off]; + const int kw = index % KW; + const int kh = (index / KW) % KH; + const int kd = (index / KW) / KH; + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. + if (id < 0 || id >= ID) + return; + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0]; + }; + + auto ker_avg_3d = [=](const data_t *d, int mb, int oc, int od, int oh, + int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + for (int id = id_start; id < id_end; ++id) + for (int ih = ih_start; ih < ih_end; ++ih) + for (int iw = iw_start; iw < iw_end; ++iw) { + diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0] / num_summands; + } + }; + + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + if (pd()->desc()->alg_kind == alg_kind::pooling_max) { + parallel_nd(MB, OC, [&](int mb, int oc) { + if (is_3d) ker_zero_3d(mb, oc); + else ker_zero(mb, oc); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = is_3d + ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)] + : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)]; + if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow); + else ker_max(d, mb, oc, oh, ow); + } + } + } + }); + } else { + parallel_nd(MB, OC, [&](int mb, int oc) { + if (is_3d) ker_zero_3d(mb, oc); + else ker_zero(mb, oc); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = is_3d + ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)] + : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)]; + if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow); + else ker_avg(d, mb, oc, oh, ow); + } + } + } + }); + } +} + +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; + +template struct ref_pooling_bwd_t; +template struct ref_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp new file mode 100644 index 0000000000..e43ceaa82b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp @@ -0,0 +1,119 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_POOLING_HPP +#define CPU_REF_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_pooling_fwd_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::everyone_is(data_type, src_md()->data_type, + dst_md()->data_type) + && desc()->accum_data_type == acc_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + ref_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_pooling_bwd_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::everyone_is(data_type, diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + ref_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp new file mode 100644 index 0000000000..af27743110 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp @@ -0,0 +1,153 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_shuffle.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace format_tag; + +template +template +void ref_shuffle_t::execute_(const exec_ctx_t &ctx) const { + using namespace prop_kind; + using namespace utils; + + const memory_desc_wrapper data_d(pd()->data_md()); + + auto i_arg = pd()->is_fwd() ? MKLDNN_ARG_SRC : MKLDNN_ARG_DIFF_DST; + auto o_arg = pd()->is_fwd() ? MKLDNN_ARG_DST : MKLDNN_ARG_DIFF_SRC; + auto input = CTX_IN_MEM(const data_t *, i_arg); + auto output = CTX_OUT_MEM(data_t *, o_arg); + + const int axis = pd()->axis(); + const int axis_size = pd()->axis_size(); + + const int MB = pd()->MB(); + const int C = pd()->C(); + int H = 1, W = 1, D = 1, HW = 1, SP = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 3, 4 ,5); + if (has_spatial) + { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + HW = H * W; + SP = D * HW; + } + const size_t stride_mb = data_d.blocking_desc().strides[0]; + constexpr int blksize = one_of(tag, nChw16c, nCdhw16c) ? 16 : 8; + + if (axis == 1 && one_of(tag, nChw16c, nChw8c, nCdhw16c, nCdhw16c)) { +#if MKLDNN_THR == MKLDNN_THR_OMP +# pragma omp parallel for collapse(3) schedule(static) + for (int mb = 0; mb < MB; ++mb) + for (int cb = 0; cb < C; cb += blksize) + for (int sp = 0; sp < SP; ++sp) { + const size_t off = mb * stride_mb + sp * blksize; + const size_t output_off = off + cb * SP; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc) + { + int input_c = rev_transposed_[cb + cc]; + const size_t input_off = off + input_c / blksize * SP * blksize + + input_c % blksize; + output[output_off + cc] = input[input_off]; + } + } +#else + parallel_nd(MB, utils::div_up(C, blksize), SP, [&](int mb, int c, + int sp) { + const size_t off = mb * stride_mb + sp * blksize; + const int cb = c * blksize; + const size_t output_off = off + cb * SP; + for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc) + { + int input_c = rev_transposed_[cb + cc]; + const size_t input_off = off + input_c / blksize * SP * blksize + + input_c % blksize; + output[output_off + cc] = input[input_off]; + } + }); +#endif + } else if (axis == 1 && one_of(tag, nhwc, ndhwc)) { + parallel_nd(MB, SP, [&](int mb, int sp) { + const size_t off = mb * stride_mb + sp * C; + PRAGMA_OMP_SIMD() + for (int c = 0; c < C; ++c) + output[off + c] = input[off + rev_transposed_[c]]; + }); + } else if (axis == 1 && one_of(tag, nchw, ncdhw)) { + parallel_nd(MB, C, [&](int mb, int c) { + const size_t output_off = mb * stride_mb + c * SP; + const size_t input_off = mb * stride_mb + rev_transposed_[c] * SP; + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + output[output_off + sp] = input[input_off + sp]; + } + }); + } else { + auto dims = pd()->desc()->data_desc.dims; + auto ndims = pd()->desc()->data_desc.ndims; + const size_t outer_size = utils::array_product(dims, axis); + const size_t inner_size = utils::array_product(dims + axis + 1, + ndims - axis - 1); + const size_t dim = axis_size * inner_size; + + parallel_nd(outer_size, axis_size, inner_size, [&](size_t ou, int a, + size_t in) + { + const size_t off = ou * dim + in; + auto &o = output[data_d.off_l(off + a * inner_size)]; + o = input[data_d.off_l(off + rev_transposed_[a] * inner_size)]; + }); + } +} + +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; + +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp new file mode 100644 index 0000000000..5e09a1a69b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp @@ -0,0 +1,111 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_SHUFFLE_HPP +#define CPU_REF_SHUFFLE_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_shuffle_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_shuffle_t : public cpu_primitive_t { + using shuffle_class = ref_shuffle_t; + + struct pd_t: public cpu_shuffle_pd_t { + using cpu_shuffle_pd_t::cpu_shuffle_pd_t; + + DECLARE_COMMON_PD_T("ref:any", shuffle_class); + + status_t init() { + using namespace format_tag; + + bool ok = true + && data_type_size + == types::data_type_size(data_md()->data_type); + if (!ok) return status::unimplemented; + + if (ndims() == 5) { + dat_tag_ = memory_desc_matches_one_of_tag( + *data_md(), nCdhw16c, nCdhw8c, ncdhw, ndhwc); + } else if (ndims() == 4) { + dat_tag_ = memory_desc_matches_one_of_tag( + *data_md(), nChw16c, nChw8c, nchw, nhwc); + } else + dat_tag_ = any; + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_shuffle_t(const pd_t *apd): cpu_primitive_t(apd) { + const int axis_size = pd()->axis_size(); + const int group_size = pd()->group_size(); + const int transpose_row = pd()->is_fwd() ? group_size + : axis_size / group_size; + const int transpose_col = pd()->is_fwd() ? axis_size / group_size + : group_size; + rev_transposed_ = (int *)malloc(axis_size * sizeof(int), 64); + parallel_nd(transpose_col, transpose_row, [&](int i, int j) { + rev_transposed_[j * transpose_col + i] = i * transpose_row + j; + }); + } + + ~ref_shuffle_t() { free(rev_transposed_); } + + typedef typename typesize_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nCdhw16c: execute_(ctx); break; + case nChw16c: execute_(ctx); break; + case nCdhw8c: execute_(ctx); break; + case nChw8c: execute_(ctx); break; + case ncdhw: execute_(ctx); break; + case nchw: execute_(ctx); break; + case ndhwc: execute_(ctx); break; + case nhwc: execute_(ctx); break; + default: execute_(ctx); break; + } + return status::success; + } + +private: + template + void execute_(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + int *rev_transposed_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp new file mode 100644 index 0000000000..36d5237f56 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp @@ -0,0 +1,264 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_softmax.hpp" +#include "gemm/os_blas.hpp" + +#ifdef USE_MKL +#include "mkl_vml_functions.h" +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void ref_softmax_fwd_t::execute_forward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + parallel_nd(outer_size_, [&](int ou) { + const data_t *src_data = src + ou * channels_; + data_t *dst_data = dst + ou * channels_; + data_t scalar = 0; + + _max(channels_, src_data, &scalar); + _sub(channels_, scalar, src_data, dst_data); + _exp(channels_, dst_data, dst_data); + _sum(channels_, dst_data, &scalar); + _scal(channels_, data_t(1)/scalar, dst_data); + }); +} + +template +void ref_softmax_fwd_t::execute_forward_generic( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + data_t space_max_val = 0, space_denom_val = 0; + data_t *space_max = &space_max_val, *space_denom = &space_denom_val; + if (inner_size_ > 1) { + using namespace memory_tracking::names; + space_max = scratchpad(ctx).template get(key_softmax_reduction); + space_denom = space_max + inner_size_; + } + + const memory_desc_wrapper data_d(pd()->src_md()); + const size_t dim = channels_ * inner_size_; + + for (int ou = 0; ou < outer_size_; ou++) { + utils::array_set(space_max, -FLT_MAX, inner_size_); + utils::array_set(space_denom, 0, inner_size_); + + for (int c = 0; c < channels_; c++) { + for(int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + space_max[in] = nstl::max(space_max[in], src[off]); + } + } + + for (int c = 0; c < channels_; c++) { + for(int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + space_denom[in] += dst[off] = exp(src[off] - space_max[in]); + } + } + + for (int c = 0; c < channels_; c++) { + for (int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + dst[off] /= space_denom[in]; + } + } + } +} + +template +void ref_softmax_fwd_t::_max(int n, const data_t *x, + data_t *max_data) const { +// Intel(R) C++ Compiler generates the maxps + shuffle pattern +// for the max search which works faster +#if !defined(__INTEL_COMPILER) + // The code below makes a compiler to generate maxps instruction + // rather than maxss, which is generated for the 'else' code path + auto max_wrapper = [](data_t a, data_t b) { return nstl::max(a, b); }; + auto min_wrapper = [](int a, int b) { return nstl::min(a, b); }; + + constexpr int unroll_factor = 32; + data_t max_values[unroll_factor]; + + if (n < unroll_factor) { + data_t max_val = x[0]; + for (int i = 1; i < n; i++) { + max_val = max_wrapper(max_val, x[i]); + } + max_data[0] = max_val; + return; + } + for (int i = 0; i < unroll_factor; i++) { + max_values[i] = x[i]; + } + for (int i = unroll_factor; i < n; i += unroll_factor) { + int offset = min_wrapper(i, n - unroll_factor); + for (int j = 0; j < unroll_factor; j++) { + max_values[j] = max_wrapper(max_values[j], x[offset + j]); + } + } + data_t max_val = max_values[0]; + for (int i = 1; i < unroll_factor; i++) { + max_val = max_wrapper(max_val, max_values[i]); + } + max_data[0] = max_val; +#else + max_data[0] = x[0]; + for (int c = 1; c < n; ++c) + max_data[0] = nstl::max(max_data[0], x[c]); +#endif +} + +template +void ref_softmax_fwd_t::_sub(int n, data_t alpha, const data_t *x, + data_t *y) const { + constexpr int unroll_factor = 32; + int tail = n % unroll_factor; + for (int i = 0; i < n - tail; i += unroll_factor) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < unroll_factor; j++) { + y[i + j] = x[i + j] - alpha; + } + } + PRAGMA_OMP_SIMD() + for (int i = n - tail; i < n; i++) { + y[i] = x[i] - alpha; + } +} + +template +void ref_softmax_fwd_t::_exp(int n, const data_t *a, + data_t *r) const { +#ifdef USE_MKL + if (data_type == data_type::f32) { + vsExp(n, a, r); + return; + } +#endif + parallel_nd(n, [&](int c) { r[c] = expf(a[c]); }); +} + +template +void ref_softmax_fwd_t::_sum(int n, const data_t *x, + data_t *sum_data) const { +#ifdef USE_CBLAS + // Here we are summing x's eg. e^z , which are positives + // so we can use BLAS ASUM + if (data_type == data_type::f32) { + sum_data[0] = cblas_sasum(n, x, 1); + return; + } +#endif + data_t tsum = static_cast(0); + PRAGMA_OMP_SIMD(reduction(+ : tsum)) + for (int c = 0; c < n; ++c) + tsum += x[c]; + sum_data[0] = tsum; +} + +template +void ref_softmax_fwd_t::_scal(int n, data_t alpha, data_t *x) const { +#ifdef USE_CBLAS + if (data_type == data_type::f32) { + cblas_sscal(n, alpha, x, 1); + return; + } +#endif + parallel_nd(n, [&](int c) { x[c] *= alpha; }); +} + +template struct ref_softmax_fwd_t; + + +// NC/NCHW softmax for along final axe (1 for NC, 3 for NCHW) +template +void ref_softmax_bwd_t::execute_backward_dense( + const exec_ctx_t &ctx) const { + auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + parallel_nd(outer_size_, [&](int ou) { + data_t sbr = 0; + size_t off = channels_*ou; + for (int c = 0; c < channels_; c++) { + size_t loff = off + c; + data_t ldata = dst[loff]; + sbr += diff_dst[loff]*ldata; + diff_src[loff] = ldata; + } + + for(int c=0; c < channels_ ; ++c) { + size_t loff = off + c; + diff_src[loff] *= (diff_dst[loff] - sbr); + } + }); +} + +template +void ref_softmax_bwd_t::execute_backward_generic( + const exec_ctx_t &ctx) const { + auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_d(pd()->diff_src_md()); + const memory_desc_wrapper data_d(pd()->dst_md()); + + const size_t dim = channels_ * inner_size_; + + parallel_nd(outer_size_, [&](int ou) { + for (int in = 0; in < inner_size_; in++) { + data_t sbr = 0; + for (int c = 0; c < channels_; c++) { + size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in); + size_t off_data = diff_d.off_l(ou * dim + c * inner_size_ + in); + sbr += diff_dst[off_diff] * dst[off_data]; + } + + for(int c=0; c < channels_ ; ++c) { + size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in); + size_t off_data = data_d.off_l(ou * dim + c * inner_size_ + in); + diff_src[off_diff] = dst[off_data] * (diff_dst[off_diff] - sbr); + } + } + }); +} + +template struct ref_softmax_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp new file mode 100644 index 0000000000..5cb74d8007 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp @@ -0,0 +1,186 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_SOFTMAX_HPP +#define CPU_REF_SOFTMAX_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_softmax_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_softmax_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_softmax_fwd_pd_t { + using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_softmax_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + const int inner_size = utils::array_product( + desc()->data_desc.dims + desc()->softmax_axis + 1, + desc()->data_desc.ndims - desc()->softmax_axis - 1); + + if (inner_size > 1) { + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(memory_tracking::names::key_softmax_reduction, + sizeof(data_t) * 2 * inner_size); + } + } + }; + + ref_softmax_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { + auto ndims = pd()->desc()->data_desc.ndims; + auto dims = pd()->desc()->data_desc.dims; + auto axis = pd()->desc()->softmax_axis; + + outer_size_ = utils::array_product(dims, axis); + channels_ = dims[axis]; + inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1); + + const memory_desc_wrapper data_d(pd()->src_md()); + + bool no_axis_blocking = true; + for (int iblk = 0; iblk < data_d.blocking_desc().inner_nblks; ++iblk) + if (data_d.blocking_desc().inner_idxs[iblk] == axis) + no_axis_blocking = false; + + use_dense_ = inner_size_ == 1 && data_d.is_dense() + && no_axis_blocking + && data_d.blocking_desc().strides[axis] == 1; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (use_dense_) + execute_forward_dense(ctx); + else + execute_forward_generic(ctx); + return status::success; + } + +private: + void execute_forward_dense(const exec_ctx_t &ctx) const; + void execute_forward_generic(const exec_ctx_t &ctx) const; + + void _max(int n, const data_t *x, data_t *max_data) const; + void _sub(int n, data_t alpha, const data_t *x, data_t *y) const; + void _exp(int n, const data_t *a, data_t *r) const; + void _sum(int n, const data_t *x, data_t *sum_data) const; + void _scal(int n, data_t alpha, data_t *x) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + bool use_dense_; + int outer_size_, channels_, inner_size_; +}; + +template +struct ref_softmax_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_softmax_bwd_pd_t { + using cpu_softmax_bwd_pd_t::cpu_softmax_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_softmax_bwd_t); + + status_t init() { + bool ok = true + && !is_fwd() + && utils::everyone_is(data_type, + dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + return status::success; + } + }; + + ref_softmax_bwd_t(const pd_t *apd): cpu_primitive_t(apd) { + auto dims = pd()->desc()->diff_desc.dims; + auto axis = pd()->desc()->softmax_axis; + auto ndims = pd()->desc()->diff_desc.ndims; + + outer_size_ = utils::array_product(dims, axis); + channels_ = dims[axis]; + inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1); + + const memory_desc_wrapper data_d(pd()->dst_md()); + const memory_desc_wrapper diff_d(pd()->diff_dst_md()); + + bool no_axis_blocking = true; + for (int iblk = 0; iblk < diff_d.blocking_desc().inner_nblks; ++iblk) + if (diff_d.blocking_desc().inner_idxs[iblk] == axis) + no_axis_blocking = false; + + use_dense_ = true + && inner_size_ == 1 + && diff_d == data_d + && diff_d.is_dense() + && no_axis_blocking + && diff_d.blocking_desc().strides[axis] == 1; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (use_dense_) + execute_backward_dense(ctx); + else + execute_backward_generic(ctx); + return status::success; + } + +private: + void execute_backward_dense(const exec_ctx_t &ctx) const; + void execute_backward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + bool use_dense_; + int outer_size_, channels_, inner_size_; +}; + + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp new file mode 100644 index 0000000000..3b2a75d99b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp @@ -0,0 +1,101 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_SUM_HPP +#define REF_SUM_HPP + +#include "reorder_pd.hpp" + +#include "cpu_sum_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ref_sum_t: public cpu_primitive_t { + struct pd_t: public cpu_sum_pd_t { + using cpu_sum_pd_t::cpu_sum_pd_t; + + pd_t(const pd_t &rhs): cpu_sum_pd_t(rhs) { + for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) + reorder_pds_.push_back( + (const reorder_pd_t *)rhs.reorder_pds_[i]->clone()); + } + + ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; } + + DECLARE_SUM_PD_T("ref:any", ref_sum_t); + + status_t init() { + bool ok = cpu_sum_pd_t::init() == status::success; + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + auto r_impls = engine_->get_reorder_implementation_list(); + for (auto r = r_impls; *r; ++r) { + primitive_attr_t attr; + attr.output_scales_.set(scales_[i]); + if (i != 0) attr.post_ops_.append_sum(1.0); + + reorder_pd_t *r_pd; + if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i), + engine_, dst_md()) == status::success) { + r_pd->init_info(); + reorder_pds_.push_back(r_pd); + break; + } + } + } + + ok = reorder_pds_.size() == (size_t)n_; + return ok ? status::success : status::unimplemented; + } + + nstl::vector reorder_pds_; + }; + + ref_sum_t(const pd_t *apd): cpu_primitive_t(apd) { + const int n = pd()->n_inputs(); + reorders_.resize(n); + for (int i = 0; i < n; ++i) + pd()->reorder_pds_[i]->create_primitive(&reorders_[i]); + } + + ~ref_sum_t() { for (auto &r: reorders_) delete r; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto n = pd()->n_inputs(); + for (int i = 0; i < n; ++i) { + exec_args_t r_args; + r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i); + r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST); + exec_ctx_t r_ctx(ctx.stream(), std::move(r_args)); + reorders_[i]->execute(r_ctx); + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + nstl::vector reorders_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp new file mode 100644 index 0000000000..537084db91 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Common for RNN and LSTM cell execution + */ +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +using namespace rnn_utils; + +template +rnn_cell_execution_sig( + (_ref_rnn_common_t::cell_execution)) { + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic, + 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld); + + if (rnn_postgemm_ != nullptr) + rnn_postgemm_->execute(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + else + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); +} +template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution); +template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution); + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution) { + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + + /// bwd by data on the cell + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic, + 1.0, w_iter_[0], rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, + 0.0, diff_states_t_l_, rnn.states_ws_ld); + + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld); + + /// bwd by weights on the cell + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + } + + if (!rnn.merge_gemm_iter) + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, + diff_w_iter_, rnn.diff_weights_iter_ld); + + /// bwd by bias we just accumulate diffs from the gates + gates_reduction(rnn, ws_gates_, diff_bias_); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp new file mode 100644 index 0000000000..e1a61d4c62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp @@ -0,0 +1,180 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution GRU + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +#define AOC array_offset_calculator +template <> +rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_[0]); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + + // 1. gemm Wx[0-2],x + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + + // 2. gemm Wh[0-1],h + (this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dic, rnn.mb, + rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld); + + // 3. activation zt and rt + elemwise multiplication rt,ht-1 + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j); + } + }); + + // 4. gemm Wh[2],h~t + (this->*gemm_iter_func)('N', 'N', rnn.dic, rnn.mb, rnn.sic, 1.0, w_iter_[1], + rnn.weights_iter_ld, states_t_l_, rnn.states_ws_ld, 1.0, + &(ws_gates(0, 2, 0)), rnn.gates_ws_ld); + + // 5. activation h~t + calculate ht + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) + + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); + } + }); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru) { + assert(!"GRU int8 is not supported"); +} + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + // use state memory for intermediate computations + // TODO: use cell ws for that + float *dhG1_ = &(diff_states_t_l(rnn.n_states, 0, 0)); + float *hG1_ = dhG1_; + AOC dhG1(dhG1_, rnn.states_nld, rnn.states_ws_ld); + AOC hG1(hG1_, rnn.states_nld, rnn.states_ws_ld); + + // 1. calculate dG2, dG1, and part of dht-1 + // dG2^ = dh * (1 - G0) * (1 - G2^2) + // dG0^ = dh * (ht-1 - G2) * u * (1 - G0) + // dht-1 (part) = dh * G0 + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt + * one_m_square(ws_gates(i, 2, j)); + float dG0 = (h - ws_gates(i, 2, j)) * dHt + * x_m_square(ws_gates(i, 0, j)); + + diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); + ws_gates(i, 0, j) = dG0; + ws_gates(i, 2, j) = dG2; + } + }); + + // 2. calculate intermediate d(hG1) + // d(hG1) = dG2 * W2h^t + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.dic, 1.0, w_iter_[1], + rnn.weights_iter_ld, &(ws_gates(0, 2, 0)), rnn.gates_ws_ld, 0.0, + dhG1_, rnn.states_ws_ld); + + // 3. calculate dG1^ and part of dht-1 + // dG1^ = d(hG1) * h * G1 * (1 - G1) + // dht-1 (part) += d(hG1) * G1 + // h * G1 (required for dWh) + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float G1 = ws_gates(i, 1, j); + diff_states_t_l(0, i, j) += dhG1(i, j) * G1; + ws_gates(i, 1, j) = dhG1(i, j) * h * x_m_square(G1); + hG1(i, j) = G1 * h; + } + }); + + // 4. calculate diff weights + // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h) + gemm('N', 'T', (rnn.n_gates - 1) * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, + rnn.diff_weights_iter_ld); + gemm('N', 'T', rnn.dic, rnn.sic, rnn.mb, 1.0, &(ws_gates(0, 2, 0)), + rnn.gates_ws_ld, hG1_, rnn.states_ws_ld, 1.0, + &(diff_w_iter(0, 2, 0)), rnn.diff_weights_iter_ld); + + // 5. calculate diff states + // dht-1 += dG1 * W1h + dG0 * W0h + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, + (rnn.n_gates - 1) * rnn.dic, 1.0, w_iter_[0], + rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, 1.0, + diff_states_t_l_, rnn.states_ws_ld); + + if (!rnn.merge_gemm_layer) { + // dWx += [dG0 dG1 dG2] * [x] + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &(diff_states_t_l(rnn.n_states, 0, 0)), rnn.states_ws_ld); + } + + // 6. calculate diff bias + gates_reduction(rnn, ws_gates_, diff_bias_); +} +#undef AOC + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp new file mode 100644 index 0000000000..8dea8c90a4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution GRU with linear before reset + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; +#define AOC array_offset_calculator + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_gates_aoc_t ws_gemm_state(rnn, ws_cell_); + AOC ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j); + ws_gates(i, 0, j) = logistic_fwd( + ws_gates(i, 0, j) + ws_gemm_state(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd( + ws_gates(i, 1, j) + ws_gemm_state(i, 1, j) + bias(1, j)); + ws_gates(i, 2, j) = tanh_fwd( + ws_gates(i, 2, j) + ws_gates(i, 1, j) * Wh_b + bias(2, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) + + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); + if (rnn.is_training) + ws_Wh_b(i, j) = Wh_b; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise) { + assert(!"GRU LBR int8 is not supported"); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr) { + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic, + 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 0.0, ws_cell_, rnn.gates_ws_ld); + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) { + assert(!"GRU LBR int8 is not supported"); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); + AOC ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); + + // 1. calculate dG1 dG2 dG3 + // dG0 = (dht - G2) * dht * (1 - G0) * G0 + // dG1 = (W*h + b) * dG2 * (1 - G1) * G1 + // dG2 = (1 - G0) * dht * (1 - G2*G2) + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dG0 = (h - ws_gates(i, 2, j)) * dHt + * x_m_square(ws_gates(i, 0, j)); + float dG2 = (1.0f - ws_gates(i, 0, j)) + * one_m_square(ws_gates(i, 2, j)) * dHt; + float dG1 = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j)); + + diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); + ws_gates(i, 2, j) = dG2; + ws_gates_r(i, 2, j) = dG2 * ws_gates(i, 1, j); + ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0; + ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1; + } + }); +} + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) { + ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + + if (!rnn.merge_gemm_layer) { + // dx = dG * Wx^t + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld); + // dWx += dG^t * x + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + } + // dh += dGr * Wh^t + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic, + 1.0, w_iter_[0], rnn.weights_iter_ld, ws_cell_, rnn.gates_ws_ld, + 1.0, diff_states_t_l_, rnn.states_ws_ld); + + // dWh += dGr^t * h + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_cell_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, + rnn.diff_weights_layer_ld); + + // db1-3 += e * dG + // db4 += e * (r * dG2) + gates_reduction(rnn, ws_gates_, diff_bias_); + + parallel_nd(rnn.dic, [&](int j) { + for (int i = 0; i < rnn.mb; i++) { + diff_bias_[3 * rnn.dic + j] += ws_gates_r(i, 2, j); + } + }); +} + +#undef AOC + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp new file mode 100644 index 0000000000..a15ba00d4c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp @@ -0,0 +1,143 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution LSTM + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "../simple_q10n.hpp" +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); + ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); + ws_gates(i, 3, j) = logistic_fwd(ws_gates(i, 3, j) + bias(3, j)); + + float tmp = ws_gates(i, 1, j) * c_states_tm1_l(i, j) + + ws_gates(i, 0, j) * ws_gates(i, 2, j); + states_t_l(i, j) = ws_gates(i, 3, j) * tanh_fwd(tmp); + c_states_t_l(i, j) = tmp; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise) { + ws_gates_aoc_s32_t ws_gates_s32(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_u8_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + auto q_d = [&](float f) { + float qf = f * data_scale + data_shift; + return qz_a1b0()(qf); + }; + + auto deq_w = [&](acc_data_t s, int gate, int j) { + return pd()->attr()->rnn_weights_qparams_.mask_ == 0 ? + saturate(s) * (1.f / (weights_scales[0] * data_scale)) : + saturate(s) * (1.f / (weights_scales[gate * rnn.dic + j] + * data_scale)); + }; + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float G0 = logistic_fwd( + deq_w(ws_gates_s32(i, 0, j), 0, j) + bias(0, j)); + float G1 = logistic_fwd( + deq_w(ws_gates_s32(i, 1, j), 1, j) + bias(1, j)); + float G2 = tanh_fwd( + deq_w(ws_gates_s32(i, 2, j), 2, j) + bias(2, j)); + float G3 = logistic_fwd( + deq_w(ws_gates_s32(i, 3, j), 3, j) + bias(3, j)); + float tmp = G1 * c_states_tm1_l(i, j) + G0 * G2; + states_t_l(i, j) = q_d(G3 * tanh_fwd(tmp)); + c_states_t_l(i, j) = tmp; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float Ct = c_states_t_l(i, j); + /// @todo save it in the workspace in fwd pass or recompute it to + /// save bw + float tanhCt = tanh_fwd(Ct); + // we have 2 incoming diffs on Ht + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dCt = diff_states_tp1_l(1, i, j) + + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt; + + float dG1 = c_states_tm1_l(i, j) * dCt + * x_m_square(ws_gates(i, 1, j)); + float dG0 = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j)); + float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j)); + float dG2 + = ws_gates(i, 0, j) * dCt * one_m_square(ws_gates(i, 2, j)); + + diff_states_t_l(1, i, j) = dCt * ws_gates(i, 1, j); + + ws_gates(i, 0, j) = dG0; + ws_gates(i, 1, j) = dG1; + ws_gates(i, 2, j) = dG2; + ws_gates(i, 3, j) = dG3; + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp new file mode 100644 index 0000000000..4536e8dfad --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp @@ -0,0 +1,113 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution of Vanilla RNN + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return relu_fwd(s, alpha); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return relu_bwd(dd, s, alpha); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return tanh_fwd(s); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return dd * one_m_square(s); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return logistic_fwd(s); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return dd * x_m_square(s); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + + parallel_nd(rnn.mb, [&](int i) { + for (int j = 0; j < rnn.dic; j++) { + const float h + = activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0); + ws_gates(i, 0, j) = states_t_l(i, j) = h; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise) { + assert(!"VANILLA RNN int8 is not supported"); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + parallel_nd(rnn.mb, [&](int i) { + for (int j = 0; j < rnn.dic; ++j) { + const float dH = diff_states_t_lp1(rnn.n_states, i, j) + + diff_states_tp1_l(0, i, j); + auto g = ws_gates(i, 0, j); + ws_gates(i, 0, j) = activation_func(dH, g, 0, 0); + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp new file mode 100644 index 0000000000..b39427caf9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp @@ -0,0 +1,191 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_RNN_PD_HPP +#define CPU_RNN_PD_HPP + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "rnn_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "rnn_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t { + using rnn_fwd_pd_t::rnn_fwd_pd_t; + +protected: + status_t set_default_params() { + using namespace format_tag; + if (src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); + if (dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); + + // Optional parameters + if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); + if (with_bias() && bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); + if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); + + return status::success; + } + + status_t check_layout_consistency() { + using namespace format_tag; + using namespace data_type; + using namespace types; + + auto is_blocked = [&](memory_desc_t md, int ndims) { + return md.format_kind == format_kind::blocked && md.ndims == ndims; + }; + + bool ok = true; + ok = ok && is_blocked(src_layer_md_, 3) + && is_blocked(dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), + is_blocked(src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&dst_iter_md_), + is_blocked(dst_iter_md_, 5)); + + if (weights_layer_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldigo_p); + else + ok = ok && rnn_utils::is_ldigo(&weights_layer_md_); + + if (weights_iter_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldigo_p); + else + ok = ok && rnn_utils::is_ldigo(&weights_iter_md_); + + ok = ok && IMPLICATION(!is_zero_md(&bias_md_), + memory_desc_matches_tag(bias_md_, ldgo)); + + /* Int8 is supported only for packed weights */ + data_type_t weights_iter_dt = weights_iter_md_.data_type; + data_type_t weights_layer_dt = weights_layer_md_.data_type; + ok = ok && IMPLICATION( + weights_iter_dt == s8, weights_iter_md_.format_kind + == format_kind::rnn_packed); + ok = ok && IMPLICATION( + weights_layer_dt == s8, weights_layer_md_.format_kind + == format_kind::rnn_packed); + + return ok ? status::success : status::unimplemented; + } +}; + +struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t { + using rnn_bwd_pd_t::rnn_bwd_pd_t; + +protected: + status_t set_default_params() { + using namespace format_tag; + if (src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); + if (dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); + + if (diff_src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_src_layer_md_, tnc)); + if (diff_weights_layer_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_layer_md_, ldigo)); + CHECK(rnn_utils::set_good_strides(diff_weights_layer_md_, ldigo)); + } + if (diff_weights_iter_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_iter_md_, ldigo)); + CHECK(rnn_utils::set_good_strides(diff_weights_iter_md_, ldigo)); + } + if (diff_dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_dst_layer_md_, tnc)); + + // Optional parameters + if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); + if (with_bias() && bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); + if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); + + if (with_src_iter() && diff_src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_src_iter_md_, ldsnc)); + if (with_bias() && diff_bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md_, ldgo)); + if (with_dst_iter() && diff_dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_dst_iter_md_, ldsnc)); + + return status::success; + } + + status_t check_layout_consistency() { + using namespace format_tag; + using namespace types; + + auto is_blocked = [&](memory_desc_t md, int ndims) { + return md.format_kind == format_kind::blocked && md.ndims == ndims; + }; + + bool ok = true; + ok = ok && is_blocked(src_layer_md_, 3) + && is_blocked(dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), + is_blocked(src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&dst_iter_md_), + is_blocked(dst_iter_md_, 5)); + + if (weights_layer_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldgoi_p); + else + ok = ok && rnn_utils::is_ldgoi(&weights_layer_md_); + + if (weights_iter_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldgoi_p); + else + ok = ok && rnn_utils::is_ldgoi(&weights_iter_md_); + + ok = ok && IMPLICATION(!is_zero_md(&bias_md_), + memory_desc_matches_tag(bias_md_, ldgo)); + + ok = ok && is_blocked(diff_src_layer_md_, 3) + && is_blocked(diff_dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&diff_src_iter_md_), + is_blocked(diff_src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&diff_dst_iter_md_), + is_blocked(diff_dst_iter_md_, 5)); + + ok = ok && rnn_utils::is_ldigo(&diff_weights_layer_md_) + && rnn_utils::is_ldigo(&diff_weights_iter_md_); + ok = ok && IMPLICATION(!is_zero_md(&diff_bias_md_), + memory_desc_matches_tag(diff_bias_md_, ldgo)); + + return ok ? status::success : status::unimplemented; + } +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp new file mode 100644 index 0000000000..09445648aa --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp @@ -0,0 +1,401 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution LSTM + */ + +#include "rnn_utils.hpp" +#include "../jit_generator.hpp" +#include "../jit_uni_eltwise.hpp" +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "mkldnn_thread.hpp" + + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_uni_rnn_postgemm_kernel : public jit_generator { + + typedef void (*kernel_t)(void *gates_, const void *bias, void *states_t_l_, + void *c_states_t_l_, void *c_states_tm1_l_); + + jit_uni_rnn_postgemm_kernel(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr): rnn_(rnn), attr_(attr){} + + virtual void init() = 0; + +template + rnn_elemwise_sig(execute) { + rnn_utils::ws_gates_aoc ws_gates(rnn, ws_gates_); + rnn_utils::bias_aoc_t bias(rnn, bias_); + rnn_utils::ws_states_aoc states_t_l(rnn, states_t_l_); + rnn_utils::ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + rnn_utils::ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + // Todo: add parallelization on dic for the batch 1 case + // Assumption: the kernel runs a loop on dic elements + parallel_nd(rnn.mb, [&](int i) { + auto b_ = &bias(0, 0); + auto g_ = &ws_gates(i, 0, 0); + auto s_tl_ = &states_t_l(i, 0); + auto c_tl_ = &c_states_t_l(i, 0); + auto c_tm1l_ = &c_states_tm1_l(i, 0); + kernel_(g_, b_, s_tl_, c_tm1l_, c_tl_); + }); + } + +protected: + kernel_t kernel_; + const rnn_utils::rnn_conf_t &rnn_; + const primitive_attr_t *attr_; +}; + +template +struct jit_uni_lstm_postgemm_kernel_fwd: public jit_uni_rnn_postgemm_kernel +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_postgemm_kernel_fwd) + + typedef typename utils::conditional::type acc_data_t; + typedef typename utils::conditional, + jit_uni_eltwise_injector_f32>::type injector_t; + + jit_uni_lstm_postgemm_kernel_fwd(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr) + : jit_uni_rnn_postgemm_kernel(rnn, attr){} + + void init() override { + // we use rax for both constant tables as they use the same table + sigmoid_injector_ = new injector_t(this, + alg_kind::eltwise_logistic, 0.0f, 0.0f, true, rax); + tanh_injector_ = new injector_t(this, + alg_kind::eltwise_tanh, 0.0f, 0.0f, true, rax); + generate(); + kernel_ = (kernel_t) this->getCode(); + } + +protected: + injector_t *sigmoid_injector_; + injector_t *tanh_injector_; + + // register size in bytes + using Vmm = typename jit_uni_eltwise_injector_f32::Vmm; + size_t vlen = cpu_isa_traits::vlen; + size_t vlen_dst = (src_data_t == data_type::u8) ? vlen/4 : vlen; + size_t cstate_dt_size = sizeof(float); + size_t hstate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint8_t) : sizeof(float); + size_t gate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint32_t) : sizeof(float); + size_t qscale_dt_size = sizeof(float); + size_t bias_dt_size = sizeof(float); + + void generate() { + using namespace Xbyak; + + int mask = attr_->rnn_weights_qparams_.mask_; + float *weights_scales = attr_->rnn_weights_qparams_.scales_; + float data_scale = attr_->rnn_data_qparams_.scale_; + float data_shift = attr_->rnn_data_qparams_.shift_; + + // Labels declaration + Label vector_loop_start_label, vector_loop_end_label; + Label rem_loop_start_label, rem_loop_end_label; + Label table_label; + + // Register map + Reg64 loop_cnt(r11); // loop counter + Reg64 table_reg(rbx); // table is used for data scale and shifts + Reg64 weights_scales_reg(r13); + // We skip vmm0 as it can be used by the injector for masks on sse4.2 + Vmm G0(1), G1(2), G2(3), G3(4), tmp1_vmm(5), tmp2_vmm(6), zero_vmm(7); + + // constant table map + Address dscale_off_addr = ptr[table_reg]; + Address dshift_off_addr = ptr[table_reg + vlen]; + Address ymm_perm_mask_addr = ptr[table_reg + 2*vlen]; + Address zmm_perm_mask_addr = ptr[table_reg + 2*vlen + cpu_isa_traits::vlen]; + + // quantize from float to u8 + auto q_d = [&](Vmm f, Vmm tmp_vmm) { + uni_vpxor(tmp_vmm, tmp_vmm, tmp_vmm); + uni_vmulps(f, f, dscale_off_addr); // apply scale + uni_vaddps(f, f, dshift_off_addr); // apply shift + uni_vcvtps2dq(f, f); // convert to int32 + uni_vpackssdw(f, f, tmp_vmm); // convert from s32 to s16 + uni_vpackuswb(f, f, tmp_vmm); // convert from s16 to u8 with saturation + // Note that the results are interleaved by 128 bit chunks, so we need to merge them together + switch (vlen) { + case 64: { //avx512 + Zmm fz(f.getIdx()), tmpz(tmp_vmm.getIdx()); + uni_vmovups(tmpz, zmm_perm_mask_addr); + vpermd(fz, tmpz, fz); + break; } + case 32: { //avx + Ymm fy(f.getIdx()), tmpy(tmp_vmm.getIdx()); + uni_vmovups(tmpy, ymm_perm_mask_addr); + vpermd(fy, tmpy, fy); + break; } + case 16: // sse: nothing to do + break; + default: assert(!"Unsupported case"); + }; + }; + + auto fast_recip =[&](Vmm s, Vmm tmp, bool packed) { + if (packed) + uni_vrcpps(tmp, s); + else + uni_vrcpss(tmp, s); // prevent divide by zero + // we add one Newton iteration + uni_vmulps(s, s, tmp); + uni_vmulps(s, s, tmp); // s <- s * tmp^2 + uni_vaddps(tmp, tmp, tmp); + uni_vsubps(tmp, tmp, s); + uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2 + }; + + // dequantize from s32 to float + auto deq_w = [&](Vmm s, Vmm tmp1, Vmm tmp2, int gate, bool packed) { + // TODO: if mask is 0 precompute mul and inverse + if (mask == 0) + uni_vbroadcastss(tmp1, ptr[weights_scales_reg]); + else + uni_vmovups(tmp1, ptr[weights_scales_reg + gate * rnn_.dic * qscale_dt_size]); + uni_vcvtdq2ps(s, s); + uni_vmulps(tmp1, tmp1, dscale_off_addr); + fast_recip(tmp1, tmp2, packed); + uni_vmulps(s, s, tmp1); + }; + + // We start code generations here + preamble(); + + // extract addresses passed as parameter +#ifdef _WIN32 + auto addr_ws_gates_reg = abi_param1; + auto addr_bias_reg = abi_param2; + auto addr_states_t_l_reg = abi_param3; + auto addr_c_states_tm1_l_reg = abi_param4; + auto addr_c_states_t_l_reg = r10; + // Here we cannot use rbp to have initial stack pointer so we + // use rsp and offset it with the size of pushed registers in + // preamble + mov(addr_c_states_t_l_reg, ptr[rsp + get_size_of_abi_save_regs() + 40]); +#else + auto addr_ws_gates_reg = abi_param1; + auto addr_bias_reg = abi_param2; + auto addr_states_t_l_reg = abi_param3; + auto addr_c_states_tm1_l_reg = abi_param4; + auto addr_c_states_t_l_reg = abi_param5; +#endif + + // initialize registers with addresses and constants + mov(table_reg, table_label); + mov(weights_scales_reg, size_t(weights_scales)); + // both sigmoid and tanh use the same table so load address just once in rax + sigmoid_injector_->load_table_addr(); + + mov(loop_cnt, rnn_.dic * gate_dt_size); + cmp(loop_cnt, vlen); + jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR); + + L(vector_loop_start_label); + { + // load G0 G1 G2 G3 + uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); + uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); + uni_vmovups(G3, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); + + // dequantize the gates from s32 to f32 if needed + if (src_data_t == data_type::u8){ + deq_w(G0, tmp1_vmm, tmp2_vmm, 0, true); + deq_w(G1, tmp1_vmm, tmp2_vmm, 1, true); + deq_w(G2, tmp1_vmm, tmp2_vmm, 2, true); + deq_w(G3, tmp1_vmm, tmp2_vmm, 3, true); + } + + // add biases + uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + + // inject eltwise code + sigmoid_injector_->compute_vector(G0.getIdx()); + sigmoid_injector_->compute_vector(G1.getIdx()); + tanh_injector_->compute_vector(G2.getIdx()); + sigmoid_injector_->compute_vector(G3.getIdx()); + + // compute c_states_t_l = G1 * c_tm1_l + G0 * G2 + uni_vmovups(tmp1_vmm, ptr[addr_c_states_tm1_l_reg]); + uni_vmulps(tmp1_vmm, tmp1_vmm, G1); + uni_vfmadd231ps(tmp1_vmm, G0, G2); + uni_vmovups(ptr[addr_c_states_t_l_reg], tmp1_vmm); + + // states_t_l = G3 * tanh(c_states_t_l) + tanh_injector_->compute_vector(tmp1_vmm.getIdx()); + uni_vmulps(tmp1_vmm, tmp1_vmm, G3); + + // if int8, we quantize the resulting state + if (src_data_t == data_type::u8) + q_d(tmp1_vmm, tmp2_vmm); + + // write back the result + if(vlen_dst == vlen) + uni_vmovups(ptr[addr_states_t_l_reg], tmp1_vmm); + else + // we write only 1/4 of the register + switch(vlen_dst){ + case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + default: + assert(!"Unsuported vector length for quantization"); + } + + // increment address pointers + add(addr_ws_gates_reg, vlen); + add(addr_bias_reg, vlen); + add(addr_states_t_l_reg, vlen_dst); + add(addr_c_states_tm1_l_reg, vlen); + add(addr_c_states_t_l_reg, vlen); + if (mask != 0) + add(weights_scales_reg, vlen); + + // increment loop counter + sub(loop_cnt, vlen); + cmp(loop_cnt, vlen); + jge(vector_loop_start_label); + } + L(vector_loop_end_label); + + cmp(loop_cnt, 0); + je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR); + // Same code as above, we just use movuss for accessing inputs + // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar + L(rem_loop_start_label); + { + // remaping registers to Xmms + Xmm G0s(G0.getIdx()), G1s(G1.getIdx()), G2s(G2.getIdx()), G3s(G3.getIdx()); + Xmm tmp1s_vmm(tmp1_vmm.getIdx()); + + // load G0 G1 G2 G3 + uni_vmovss(G0s, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); + uni_vmovss(G1s, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vmovss(G2s, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); + uni_vmovss(G3s, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); + + // dequantize the gates from s32 to f32 if needed + if (src_data_t == data_type::u8){ + deq_w(G0, tmp1_vmm, tmp2_vmm, 0, false); + deq_w(G1, tmp1_vmm, tmp2_vmm, 1, false); + deq_w(G2, tmp1_vmm, tmp2_vmm, 2, false); + deq_w(G3, tmp1_vmm, tmp2_vmm, 3, false); + } + + // add biases + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G0s, G0s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1s, G1s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2s, G2s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + uni_vaddps(G3s, G3s, tmp1s_vmm); + + // inject eltwise code + sigmoid_injector_->compute_vector(G0s.getIdx()); + sigmoid_injector_->compute_vector(G1s.getIdx()); + tanh_injector_->compute_vector(G2s.getIdx()); + sigmoid_injector_->compute_vector(G3s.getIdx()); + + // compute c_states_t_l = G1 * c_tm1_l + G0s * G2 + uni_vmovups(tmp1s_vmm, ptr[addr_c_states_tm1_l_reg]); + uni_vmulps(tmp1s_vmm, tmp1s_vmm, G1s); + uni_vfmadd231ps(tmp1s_vmm, G0s, G2s); + uni_vmovss(ptr[addr_c_states_t_l_reg], tmp1s_vmm); + + // states_t_l = G3 * tanh(c_states_t_l) + tanh_injector_->compute_vector(tmp1s_vmm.getIdx()); + uni_vmulps(tmp1s_vmm, tmp1s_vmm, G3s); + + // if int8, we quantize the resulting state + if (src_data_t == data_type::u8) + q_d(tmp1_vmm, tmp2_vmm); + + // write back the result + if(vlen_dst == vlen) + uni_vmovups(ptr[addr_states_t_l_reg], tmp1s_vmm); + else + // we write only 1/4 of the register + switch(vlen_dst){ + case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + default: + assert(!"Unsuported vector length for quantization"); + } + + // increment address pointers + add(addr_ws_gates_reg, gate_dt_size); + add(addr_bias_reg, bias_dt_size); + add(addr_states_t_l_reg, hstate_dt_size); + add(addr_c_states_tm1_l_reg, cstate_dt_size); + add(addr_c_states_t_l_reg, cstate_dt_size); + if (mask != 0) + add(weights_scales_reg, qscale_dt_size); + + // increment loop counter + sub(loop_cnt, gate_dt_size); + cmp(loop_cnt, 0); + jg(rem_loop_start_label); + + } + L(rem_loop_end_label); + + postamble(); + + // Again, only one table is needed and shared between sigmoid and tanh + sigmoid_injector_->prepare_table(false); + tanh_injector_->prepare_table(true); + + L(table_label); + { + for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_scale)); + for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_shift)); + // perm mask for ymm + dd(0); dd(4); dd(2); dd(3); dd(1); dd(5); dd(6); dd(7); + // perm mask for zmm + dd(0); dd(4); dd(8); dd(12); dd(1); dd(5); dd(6); dd(7); + dd(2); dd(9); dd(10); dd(11); dd(3); dd(12); dd(13); dd(14); + } + } + +}; + +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; + +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp new file mode 100644 index 0000000000..ead536816c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp @@ -0,0 +1,788 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + General architecture + + for diff states, we have n_states + 1 as we have n_states diff + to propagate to the previous iteration and 1 states to propagate + to the previous layer + index 0 is dh for cell(t-1, l) to consume + index 1 is dc for cell(t-1, l) to consume + index 2 is dh for cell(t, l-1) to consume + this indexing enables to have the same indexing for states in elemwise + function + only the cell execution function should be impacted + + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" +#include "../gemm/gemm.hpp" +#include "../simple_q10n.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::memory_tracking::names; +using namespace rnn_utils; +#define AOC array_offset_calculator + +template +void _ref_rnn_common_t::gates_reduction( + const rnn_conf_t &rnn, const acc_data_t *ws_gates_, + float *diff_bias_) const { + auto body = [&](int i, int k) { + for (int j = 0; j < rnn.mb; j++) + diff_bias_[i * rnn.dic + k] + += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k]; + }; + + // @todo block k on simd-width +#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \ + /* icc 17.0 has a problem with simd collapse */ \ + && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700)) +#pragma omp parallel for simd collapse(2) + for (int i = 0; i < rnn.n_gates; i++) + for (int k = 0; k < rnn.dic; k++) + body(i, k); +#else + parallel_nd(rnn.n_gates, rnn.dic, body); +#endif +} + +template +rnn_gemm_sig((_ref_rnn_common_t::gemm)) { + assert(ldA * ldB * ldC != 0); + extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB, + &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm); +} + +template <> +rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) { + assert(!"non packed gemm is disabled for int8"); +} + +template +rnn_gemm_sig((_ref_rnn_common_t::packed_gemm)) { +#if (USE_MKL_PACKED_GEMM) + assert(transA == 'N'); + cblas_sgemm_compute(CblasColMajor, CblasPacked, + (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_, + ldB, beta, c_, ldC); +#else + UNUSED(transA); + UNUSED(transB); + UNUSED(m); + UNUSED(n); + UNUSED(k); + UNUSED(alpha); + UNUSED(ldA); + UNUSED(b_); + UNUSED(ldB); + UNUSED(beta); + UNUSED(c_); + UNUSED(ldC); + assert(!"packed gemm is disabled"); +#endif +} + +template <> +rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) { +#if (USE_MKL_PACKED_GEMM) + int8_t offseta = 0, offsetb = 0; + int32_t offsetc = 0; + cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked, + CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_, + ldB, offsetb, beta, c_, ldC, &offsetc); +#else + UNUSED(transA); + UNUSED(transB); + UNUSED(m); + UNUSED(n); + UNUSED(k); + UNUSED(alpha); + UNUSED(ldA); + UNUSED(b_); + UNUSED(ldB); + UNUSED(beta); + UNUSED(c_); + UNUSED(ldC); + assert(!"packed gemm is disabled"); +#endif +} + +//*************** Grid computations strategy: linear ***************// +template +rnn_grid_execution_sig( + (_ref_rnn_common_t::linear_execution)) { + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); + AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + (rnn.n_states + 1), rnn.n_iter + 1, + rnn.states_nld * rnn.states_ws_ld); + AOC ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter, + rnn.gates_nld * rnn.gates_ws_ld); + AOC weights_input( + weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer); + AOC weights_states( + weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter); + AOC bias( + bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); + AOC diff_weights_layer(diff_weights_layer_, rnn.n_layer, + rnn.n_dir, + rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld); + AOC diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir, + rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld); + AOC diff_bias( + diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + AOC ws_grid( + ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell); + + // We run the grid of computation + for (int dir = 0; dir < rnn.n_dir; dir++) { + for (int j = 0; j < rnn.n_layer; j++) { + int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1; + + if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, + rnn.mb * rnn.n_iter, rnn.slc, 1.0, + weights_input(lay, dir, 0), rnn.weights_iter_ld, + &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0, + &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld); + } + + for (int i = 0; i < rnn.n_iter; i++) { + int iter = (aprop == prop_kind::forward) ? i : rnn.n_iter - i - 1; + (this->*cell_func)(rnn, + &(ws_states(lay + 1, dir, iter + 1, 0)), + &(ws_c_states(lay + 1, dir, iter + 1, 0)), + &(ws_diff_states(lay, dir, 0, iter, 0)), + &(weights_input(lay, dir, 0)), + &(weights_states(lay, dir, 0)), + &(bias(lay, dir, 0)), + &(ws_states(lay, dir, iter + 1, 0)), + &(ws_states(lay + 1, dir, iter, 0)), + &(ws_c_states(lay + 1, dir, iter, 0)), + &(ws_diff_states(lay + 1, dir, 0, iter, 0)), + &(ws_diff_states(lay, dir, 0, iter + 1, 0)), + &(diff_weights_layer(lay, dir, 0)), + &(diff_weights_iter(lay, dir, 0)), + &(diff_bias(lay, dir, 0)), + &(ws_gates(lay, dir, iter, 0)), + &(ws_grid(lay, dir, iter, 0)), + ws_cell_); + } + + if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter, + rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0), + rnn.weights_layer_ld, + (src_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, 0.0, + (acc_data_t *)(&(ws_diff_states( + lay, dir, rnn.n_states, 0, 0))), + rnn.states_ws_ld); + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, + rnn.mb * rnn.n_iter, 1.0, + (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, + (src_data_t *)(&(ws_states(lay, dir, 1, 0))), + rnn.states_ws_ld, 1.0, + (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))), + rnn.diff_weights_layer_ld); + } + if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) { + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, + rnn.mb * rnn.n_iter, 1.0, + (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, + (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))), + rnn.states_ws_ld, 1.0, + (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))), + rnn.diff_weights_iter_ld); + } + } + } +} + +//********* GRID computations strategy: utility functions **********// + +template +void _ref_rnn_common_t::copy_init_layer( + const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, + float *__restrict ws_diff_states_, const src_data_t *__restrict xt_, + const float *__restrict diff_dst_layer_) const { + + AOC ws_states( + ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto xt_d = memory_desc_wrapper(pd()->src_md(0)); + + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto xxt = xt_ + xt_d.blk_off(it, b); + src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0)); + src_data_t *ws_r2l_ptr = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0)); + if (rnn.exec_dir != r2l) + for (int c = 0; c < rnn.slc; c++) + ws_l2r_ptr[c] = xxt[c]; + if (rnn.exec_dir != l2r) + for (int c = 0; c < rnn.slc; c++) + ws_r2l_ptr[c] = xxt[c]; + }); +} + +template <> +void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_, + const float *diff_dst_layer_) const { + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0)); + + switch (rnn.exec_dir) { + case bi_concat: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + ws_diff_states( + rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) + = diff_dst_layer_x[rnn.dic + s]; + } + }); + break; + case bi_sum: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + ws_diff_states( + rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + case l2r: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + case r2l: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x = diff_dst_layer_ + + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + default: assert(!"Unsupported direction"); break; + } +} + +/* For int8 configuration, input iteration states may be of types f32 or u8 + * Internally h_state is always stored in u8 and c_state is always stored in f32 + * If input states are of type u8 then h state is copied and c state is dequantized + * If input states are of type f32 then h state is quantized and c_state is copied + * */ +template +template +void _ref_rnn_common_t::copy_init_iter( + const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, + float *__restrict ws_c_states_, float *__restrict ws_diff_states_, + const input_data_t *__restrict firstit_states_, + const float *__restrict diff_dst_iter_) const { + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + const bool quantize = pd()->with_src_iter() + && pd()->src_md(1)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_q = [&](input_data_t f) { + if (quantize) { + float qf = f * data_scale + data_shift; + return qz_a1b0()(qf); + } else + return (src_data_t)f; + }; + + const bool dequantize = pd()->with_src_iter() + && pd()->src_md(1)->data_type == data_type::u8; + auto maybe_deq = [&](input_data_t s) { + if (dequantize) + return (((float)s - data_shift) / data_scale); + else + return (float)s; + }; + auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1)); + if (firstit_states_) { + parallel_nd( + rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { + for (int s = 0; s < rnn.sic; s++) + ws_states(lay + 1, dir, 0, b, s) = maybe_q( + firstit_states_[firstit_states_d.blk_off( + lay, dir, 0, b, s)]); + if (pd()->cell_kind() == alg_kind::vanilla_lstm) + for (int s = 0; s < rnn.sic; s++) + ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq( + firstit_states_[firstit_states_d.blk_off( + lay, dir, 1, b, s)]); + }); + } else { + parallel_nd( + rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { + for (int j = 0; j < rnn.sic; j++) { + ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0; + ws_c_states(lay + 1, dir, 0, b, j) = 0.0f; + } + }); + } +} + +template <> +template +void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_, + const input_data_t *firstit_states_, + const float *diff_dst_iter_) const { + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1)); + if (diff_dst_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int b) { + array_copy(&(ws_diff_states( + lay, dir, state, rnn.n_iter, b, 0)), + diff_dst_iter_ + + diff_dst_iter_d.blk_off( + lay, dir, state, b), + rnn.dic); + }); + } else { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int i) { + for (int j = 0; j < rnn.dic; j++) + ws_diff_states(lay, dir, state, rnn.n_iter, i, j) + = 0.0f; + }); + } +} + +template +template +void _ref_rnn_common_t::copy_res_layer( + const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer, + const src_data_t *ws_states_, const float *ws_diff_states_) const { + + auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0)); + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float shift = (pd()->attr()->rnn_data_qparams_.shift_); + float scale = (pd()->attr()->rnn_data_qparams_.scale_); + + const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_deq = [&](src_data_t s) { + if (dequantize) + return (dst_data_t)(((float)s - shift) / scale); + else + return (dst_data_t)s; + }; + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + int dir = 0; + if (rnn.exec_dir != r2l) { + for (int s = 0; s < rnn.dic; s++) { + dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] + = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s)); + } + dir = 1; + } + if (rnn.exec_dir != l2r) { + for (int s = 0; s < rnn.dic; s++) + switch (rnn.exec_dir) { + case bi_sum: + dst_layer_[dst_layer_d.blk_off(it, b, s)] + += maybe_deq(ws_states( + rnn.n_layer, dir, rnn.n_iter - it, b, s)); + break; + default: + dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] + = maybe_deq(ws_states( + rnn.n_layer, dir, rnn.n_iter - it, b, s)); + } + } + }); +} + +template <> +template +void ref_rnn_bwd_f32_t::copy_res_layer( + const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_, + const src_data_t *ws_states_, const float *ws_diff_states_) const { + auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0)); + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, + rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, + rnn.states_ws_ld); + + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + int dir = 0; + for (int s = 0; s < rnn.slc; s++) { + float *dst_addr = diff_src_layer_ + + diff_src_layer_d.blk_off( + (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it, + b, dir * rnn.slc + s); + float res = ws_diff_states(0, 0, rnn.n_states, it, b, s); + if (rnn.n_dir - 1) + res += ws_diff_states( + 0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s); + dst_addr[0] = res; + } + }); +} + +template +template +void _ref_rnn_common_t::copy_res_iter( + const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, + const src_data_t *ws_states_, float *ws_c_states_, + const float *ws_diff_states_) const { + auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1)); + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + const bool quantize = pd()->with_dst_iter() + && pd()->dst_md(1)->data_type == data_type::u8 + && rnn.dt_conf != all_f32; + auto maybe_q = [&](float f) { + if (quantize) { + float qf = f * data_scale + data_shift; + return qz_a1b0()(qf); + } else + return (output_data_t)f; + }; + + const bool dequantize = pd()->with_dst_iter() + && pd()->dst_md(1)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_deq = [&](src_data_t s) { + if (dequantize) + return (output_data_t)(((float)s - data_shift) / data_scale); + else + return (output_data_t)s; + }; + if (dst_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb, + [&](int lay, int dir, int b) { + for (int s = 0; s < rnn.dic; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)] + = maybe_deq(ws_states(lay + 1, dir, rnn.n_iter, b, s)); + } + if (pd()->cell_kind() == alg_kind::vanilla_lstm) + for (int s = 0; s < rnn.dic; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)] + = maybe_q(ws_c_states( + lay + 1, dir, rnn.n_iter, b, s)); + } + }); + } +} + +template <> +template +void ref_rnn_bwd_f32_t::copy_res_iter( + const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, + const src_data_t *ws_states_, float *ws_c_states_, + const float *ws_diff_states_) const { + auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1)); + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, + rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, + rnn.states_ws_ld); + if (diff_src_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int b) { + for (int s = 0; s < rnn.sic; s++) { + diff_src_iter_[diff_src_iter_d.blk_off( + lay, dir, state, b, s)] + = ws_diff_states(lay, dir, state, 0, b, s); + } + }); + } +} + +template +rnn_bias_prepare_sig((_ref_rnn_common_t::bias_prepare)) { + /* Original set of bias provided by the user */ + AOC b( + b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + /* Array of pointers initialized in packing */ + AOC bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); + AOC scratch_bias( + scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + + if (rnn.copy_bias) { + parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic, + [&](size_t i) { scratch_bias_[i] = b_[i]; }); + } + + for (int i = 0; i < rnn.n_layer; i++) { + for (int d = 0; d < rnn.n_dir; d++) { + int offset_bias = 0; + for (int p = 0; p < rnn.n_parts_bias; p++) { + bias(i, d, p) = rnn.copy_bias + ? (float *) &scratch_bias(i, d, offset_bias) + : (float *) &b(i, d, offset_bias); + offset_bias += rnn.parts_bias[p] * rnn.dic; + } + } + } + +} + +template +rnn_bias_finalize_sig( + (_ref_rnn_common_t::bias_finalize)) { + if (rnn.dt_conf != all_f32) { + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; + bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0; + for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++) + for (int j = 0; j < rnn.n_bias * rnn.dic; j++) { + size_t off = i * rnn.n_bias * rnn.dic + j; + float weights_scale + = scale_per_oc ? weights_scales[j] : weights_scales[0]; + scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off]) + * data_shift / (weights_scale * data_scale); + } + } +} + +template +rnn_weights_assign_sig((_ref_rnn_common_t::assign_packed_weights)) { + assert(md->format_kind == format_kind::rnn_packed); + const auto packed_desc = md->format_desc.rnn_packed_desc; + AOC weights(weights_, + rnn.n_layer, rnn.n_dir, packed_desc.n_parts); + + size_t offset_packed = 0; + for (int l = 0; l < rnn.n_layer; l++) + for (int d = 0; d < rnn.n_dir; d++) { + for (int p = 0; p < packed_desc.n_parts; p++) { + weights(l, d, p) = (weights_data_t *)&w_[offset_packed]; + offset_packed + += packed_desc.part_pack_size[p] / sizeof(weights_data_t); + } + } +} + +template +rnn_weights_assign_sig( + (_ref_rnn_common_t::assign_weights)) { + assert(md->format_kind == format_kind::blocked); + const auto &blk = md->format_desc.blocking; + /* Original set of weights provided by the user */ + AOC w(w_, + rnn.n_layer, rnn.n_dir, (int)blk.strides[1]); + /* Array of pointers for each part of weights */ + AOC weights(weights_, rnn.n_layer, rnn.n_dir, n_parts); + + for (int i = 0; i < rnn.n_layer; i++) + for (int d = 0; d < rnn.n_dir; d++) { + size_t offset_weights = 0; + for (int p = 0; p < n_parts; p++) { + weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights); + offset_weights += gates_per_part[p] * blk.strides[3]; + } + } +} + +//********************* Execution function *********************// +template +void _ref_rnn_common_t::execute_( + const exec_ctx_t &ctx) const { + const rnn_conf_t &rnn = this->pd()->rnn_; + auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER); + auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER); + auto layer_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER); + auto iter_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + + auto dst_last_layer = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER) + : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER)); + auto dst_last_iter = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER) + : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER)); + + auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER); + auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER); + + auto w_layer = reinterpret_cast(layer_weights_n_comp); + auto w_iter = reinterpret_cast(iter_weights_n_comp); + auto w_iter_comp = reinterpret_cast( + iter_weights_n_comp + rnn.weights_iter_comp_offset); + auto w_layer_comp = reinterpret_cast( + layer_weights_n_comp + rnn.weights_layer_comp_offset); + + auto scratchpad = this->scratchpad(ctx); + + auto ptr_wei_layer + = scratchpad.template get(key_rnn_ptrs_wei_layer); + auto ptr_wei_iter + = scratchpad.template get(key_rnn_ptrs_wei_iter); + auto ptr_bias = + scratchpad.template get(key_rnn_ptrs_bia); + + // fetchihg buffers from the workspace + // if no workspace was provided we use the scratchpad + char *scratch_ptr = scratchpad.template get(key_rnn_space); + char *ws_ptr = nullptr; + if (rnn.use_workspace) + ws_ptr = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE) + : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE)); + + char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr; + acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_); + src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_); + float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_); + float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_); + float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_); + float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_); + + auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER); + auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER); + + auto diff_weights_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER); + auto diff_weights_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER); + auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + // Fetching extra buffers from scratchpad + float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_); + + // initialize diff_states to 0 + if (aprop == prop_kind::backward) + array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float)); + + /* Pack(if using packed gemm API) or copy(if input arrays have bad leading + * dimension */ + (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias); + + (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1), + rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic, + rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter, + rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter, + ptr_bias, bias, ws_bias); + (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0), + rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc, + rnn.n_parts_weights_layer, rnn.parts_weights_layer, + rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias, + bias, ws_bias); + + (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp); + + // we first need to copy the initial states and input into ws + copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer); + if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, + (const float *)states, diff_dst_iter); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) + copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, + (const uint8_t *)states, diff_dst_iter); + else + assert(!"unimplemented"); + + // run the execution on the grid + (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias, + ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid, + diff_weights_layer, diff_weights_iter, diff_bias); + + // Finally we copy the results to the result buffers + if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states, + ws_diff_states); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8) + copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer, + ws_states, ws_diff_states); + else + assert(!"unimplemented"); + + if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states, + ws_c_states, ws_diff_states); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) + copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states, + ws_c_states, ws_diff_states); + else + assert(!"unimplemented"); +}; + +/* Fix for MSVS warning C4661 */ +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise); + +template struct _ref_rnn_common_t; +template struct _ref_rnn_common_t; +template struct _ref_rnn_common_t; + +#undef AOC +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp new file mode 100644 index 0000000000..6f449a9016 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp @@ -0,0 +1,328 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_RNN_HPP +#define CPU_REF_RNN_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "../cpu_isa_traits.hpp" +#include "../gemm/os_blas.hpp" + +#include "cpu_rnn_pd.hpp" +#include "../cpu_primitive.hpp" +#include "rnn_utils.hpp" +#include "jit_uni_rnn_postgemm.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +float activation(float s, float alpha, float cliping, float dd); + +template +struct _ref_rnn_common_t : public cpu_primitive_t { + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type weights_data_t; + typedef typename utils::conditional::type acc_data_t; + + using class_name = _ref_rnn_common_t; + + typedef rnn_elemwise_sig((class_name::*elemwise_f)); + typedef rnn_cell_execution_sig((class_name::*cell_execution_f)); + typedef rnn_grid_execution_sig((class_name::*grid_execution_f)); + + typedef rnn_gemm_sig((class_name::*gemm_t)); + typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t)); + typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t)); + typedef rnn_weights_assign_sig((class_name::*weights_assign_t)); + + using base_pd_t = + typename utils::conditional::type; + + struct pd_t : public base_pd_t { + using base_pd_t::base_pd_t; + + DECLARE_COMMON_PD_T("ref:any", class_name); + + status_t init() { + using namespace prop_kind; + using namespace utils; + using namespace format_tag; + using namespace rnn_utils; + const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind; + + data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type; + data_type_t weights_iter_dt + = this->desc()->weights_iter_desc.data_type; + data_type_t weights_layer_dt + = this->desc()->weights_layer_desc.data_type; + + bool ok = true + && one_of(cell_kind, alg_kind::vanilla_rnn, + alg_kind::vanilla_lstm, alg_kind::vanilla_gru, + alg_kind::gru_linear_before_reset) + && IMPLICATION(aprop == prop_kind::forward, + one_of(this->desc()->prop_kind, forward_training, + forward_inference)) + && IMPLICATION(aprop == backward, + one_of(this->desc()->prop_kind, backward)) + && src_layer_dt == src_type + && everyone_is( + weights_type, weights_iter_dt, weights_layer_dt) + && this->set_default_params() == status::success + && this->with_bias(); + if (!ok) + return status::unimplemented; + + init_conf(rnn_, *this->desc(), this->src_md(0), this->src_md(1), + this->weights_md(0), this->weights_md(1), this->dst_md(0)); + + if (rnn_.dt_conf == all_f32) + ok = ok && this->attr()->has_default_values(); + + // Set weights descriptors to desired format + memory_desc_t new_weights_layer_md = *this->weights_md(0); + CHECK(set_expected_desc(rnn_, new_weights_layer_md, false)); + if (this->weights_layer_md_.format_kind == format_kind::any) { + this->weights_layer_md_ = new_weights_layer_md; + } else if (this->weights_layer_md_.format_kind + == format_kind::rnn_packed) { + if (this->weights_layer_md_ != new_weights_layer_md) + return status::unimplemented; + } + + memory_desc_t new_weights_iter_md = *this->weights_md(1); + CHECK(set_expected_desc(rnn_, new_weights_iter_md, true)); + if (this->weights_iter_md_.format_kind == format_kind::any) { + this->weights_iter_md_ = new_weights_iter_md; + } else if (this->weights_iter_md_.format_kind + == format_kind::rnn_packed) { + if (this->weights_iter_md_ != new_weights_iter_md) + return status::unimplemented; + } + + CHECK(this->check_layout_consistency()); + + set_conf(rnn_, *this->desc(), this->weights_md(0), + this->weights_md(1), this->diff_weights_md(0), + this->diff_weights_md(1)); + + size_t scratchpad_sz{0}, ws_sz{0}; + get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz); + + // initialize the workspace if needed + if (rnn_.is_training) { + dims_t ws_dims = { (int)ws_sz }; + mkldnn_memory_desc_init_by_tag(&this->ws_md_, 1, ws_dims, + data_type::u8, format_tag::x); + } + + init_scratchpad(scratchpad_sz); + + return status::success; + } + + rnn_utils::rnn_conf_t rnn_; + + private: + void init_scratchpad(size_t scratchpad_sz) { + using namespace memory_tracking::names; + auto scratchpad = this->scratchpad_registry().registrar(); + scratchpad.book(key_rnn_space, sizeof(float) * scratchpad_sz, 4096); + + int max_nparts = this->cell_kind() == alg_kind::vanilla_gru ? 2 : 1; + int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts; + scratchpad.book(key_rnn_ptrs_wei_layer, + sizeof(float *) * ptr_wei_sz); + scratchpad.book(key_rnn_ptrs_wei_iter, + sizeof(float *) * ptr_wei_sz); + scratchpad.book(key_rnn_ptrs_bia, + sizeof(float *) * ptr_wei_sz); + } + }; + + _ref_rnn_common_t(const pd_t *apd) + : cpu_primitive_t(apd, true), rnn_postgemm_(nullptr) { + /// @todo set max_feature_size assuming that we limit the number of + /// iterations and layer to one if slc != dic and sic != dic + /// respectively + + bias_preparation_func = &class_name::bias_prepare; + bias_finalization_func = &class_name::bias_finalize; + + auto set_gemm_funcs + = [](bool packed_gemm, gemm_t &g, weights_assign_t &a) { + if (packed_gemm) { + g = &class_name::packed_gemm; + a = &class_name::assign_packed_weights; + } else { + g = &class_name::gemm; + a = &class_name::assign_weights; + } + }; + set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func, + weights_iter_assign_func); + + set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func, + weights_layer_assign_func); + + switch (pd()->cell_kind()) { + case alg_kind::vanilla_lstm: + cell_func = &class_name::cell_execution; + if (aprop == prop_kind::forward) { + if (mayiuse(avx512_core)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( + pd()->rnn_, pd()->attr()); + else if (mayiuse(avx2)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( + pd()->rnn_, pd()->attr()); + else if (mayiuse(sse42)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( + pd()->rnn_, pd()->attr()); + assert(rnn_postgemm_ != nullptr); + rnn_postgemm_->init(); + } + elemwise_func = &class_name::lstm_elemwise; + break; + case alg_kind::vanilla_rnn: // @todo switch on cell kind + cell_func = &class_name::cell_execution; + elemwise_func = &class_name::rnn_elemwise; + switch (pd()->activation_kind()) { + case alg_kind::eltwise_relu: + activation_func = &activation; + break; + case alg_kind::eltwise_tanh: + activation_func = &activation; + break; + case alg_kind::eltwise_logistic: + activation_func = &activation; + break; + default: break; + } + break; + case alg_kind::vanilla_gru: + cell_func = &class_name::cell_execution_gru; + break; + case alg_kind::gru_linear_before_reset: + cell_func = &class_name::cell_execution_gru_lbr; + elemwise_func = &class_name::gru_lbr_elemwise; + break; + default: break; + } + + grid_computation = &class_name::linear_execution; + + size_t scratchpad_size, workspace_size; + rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_states_offset_, + ws_c_states_offset_, ws_diff_states_offset_, + ws_grid_comp_offset_, ws_cell_comp_offset_, + ws_bias_offset_, scratchpad_size, workspace_size); + } + + ~_ref_rnn_common_t() {} + + // typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_(ctx); + return status::success; + } + +private: + void execute_(const exec_ctx_t &ctx) const; + rnn_grid_execution_sig(linear_execution); + rnn_cell_execution_sig(cell_execution); + rnn_cell_execution_sig(cell_execution_gru); + rnn_cell_execution_sig(cell_execution_gru_lbr); + rnn_elemwise_sig(rnn_elemwise); + rnn_elemwise_sig(lstm_elemwise); + rnn_elemwise_sig(gru_lbr_elemwise); + rnn_gemm_sig(gemm); + rnn_gemm_sig(packed_gemm); + rnn_bias_prepare_sig(bias_prepare); + rnn_bias_finalize_sig(bias_finalize); + rnn_weights_assign_sig(assign_weights); + rnn_weights_assign_sig(assign_packed_weights); + + float (*activation_func)(float dd, float s, float alpha, float cliping); + + void copy_init_layer(const rnn_utils::rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_diff_states_, + const src_data_t *xt_, const float *diff_dst_layer) const; + + template + void copy_init_iter(const rnn_utils::rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_c_states, float *ws_diff_states_, + const input_data_t *firstit_states_, + const float *diff_dst_iter) const; + + template + void copy_res_layer(const rnn_utils::rnn_conf_t &rnn, + dst_data_t *dst_layer_, float *diff_src_layer, + const src_data_t *ws_states_, const float *ws_diff_states_) const; + + template + void copy_res_iter(const rnn_utils::rnn_conf_t &rnn, + output_data_t *dst_iter_, float *diff_src_iter, + const src_data_t *ws_states_, float *ws_c_states, + const float *ws_diff_states_) const; + + void gates_reduction(const rnn_utils::rnn_conf_t &rnn, + const acc_data_t *ws_gates_, float *diff_bias_) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + size_t ws_gates_offset_; + size_t ws_states_offset_; + size_t ws_c_states_offset_; + size_t ws_bias_offset_; + size_t ws_diff_states_offset_; + size_t ws_grid_comp_offset_; + size_t ws_cell_comp_offset_; + jit_uni_rnn_postgemm_kernel *rnn_postgemm_; + + grid_execution_f grid_computation; + cell_execution_f cell_func; + + bias_prepare_t bias_preparation_func; + bias_finalize_t bias_finalization_func; + weights_assign_t weights_layer_assign_func; + weights_assign_t weights_iter_assign_func; + + gemm_t gemm_layer_func; + gemm_t gemm_iter_func; + elemwise_f elemwise_func; +}; + +using ref_rnn_fwd_f32_t = _ref_rnn_common_t; +using ref_rnn_bwd_f32_t = _ref_rnn_common_t; +using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t; +} +} +} +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp new file mode 100644 index 0000000000..597c63e3f8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp @@ -0,0 +1,380 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#ifndef CPU_RNN_REORDERS_HPP +#define CPU_RNN_REORDERS_HPP + +#include + +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "simple_q10n.hpp" +#include "cpu_reorder_pd.hpp" +#include "../gemm/os_blas.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct rnn_data_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && id.matches_one_of_tag(format_tag::tnc, format_tag::ldsnc) + && od == id; + if (!args_ok) return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + return safe_ptr_assign(*reorder_pd, _pd); + } + }; + +private: + typedef typename prec_traits::type in_data_t; + typedef typename prec_traits::type out_data_t; + + rnn_data_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const size_t nelems = input_d.nelems(); + const float scale = pd()->attr()->rnn_data_qparams_.scale_; + const float shift = pd()->attr()->rnn_data_qparams_.shift_; + + parallel_nd(nelems, [&](size_t i) { + float in = (float)input[input_d.off_l(i)] * scale + shift; + output[output_d.off_l(i)] = qz_a1b0()(in); + }); + + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct rnn_weights_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { +#if !USE_MKL_PACKED_GEMM + return status::unimplemented; +#endif + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && od.format_kind() == format_kind::rnn_packed + && od.rnn_packed_desc().format == mkldnn_ldigo_p + && od.rnn_packed_desc().n_parts == 1 + && attr != nullptr; + if (!args_ok) return status::invalid_arguments; + + format_tag_t itag = id.matches_one_of_tag( + format_tag::ldigo, format_tag::ldgoi); + if (itag == format_tag::undef) return status::invalid_arguments; + + const int mask = attr->rnn_weights_qparams_.mask_; + if (!utils::one_of(mask, 0, 3)) return status::unimplemented; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + _pd->itag_ = itag; + if (_pd->init() != success) { delete _pd; return unimplemented; } + return safe_ptr_assign(*reorder_pd, _pd); + } + + status_t init() { + status_t status = cpu_reorder_pd_t::init(); + if (status != status::success) return status; + + init_scratchpad(); + + return status::success; + } + + format_tag_t itag_; + + private: + void init_scratchpad() { + const memory_desc_wrapper id(src_md()); + const size_t nelems = id.nelems(); + const auto &dims = id.dims(); + + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + size_t quantization_size = sizeof(int8_t) * nelems; + size_t reduction_size = itag_ == ldigo + ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0] + * dims[1] * dims[3] * dims[4] + : 0; + scratchpad.book( + key_reorder_rnn_weights_quantization, quantization_size); + scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size); + } + }; + +private: + typedef typename prec_traits::type in_data_t; + typedef typename prec_traits::type out_data_t; + + rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { +#if USE_MKL_PACKED_GEMM + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const auto &dims = input_d.dims(); + + const int L = dims[0]; + const int D = dims[1]; + const int I = dims[2]; + const int G = dims[3]; + const int O = dims[4]; + + const bool is_igo = pd()->itag_ == format_tag::ldigo; + + /* Quantize input & compute compensation */ + auto quantized = (int8_t * __restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_rnn_weights_quantization); + auto reduction = (int32_t * __restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_rnn_weights_reduction); + float *comp = reinterpret_cast( + output + output_d.rnn_packed_desc().offset_compensation); + const float *scales = pd()->attr()->rnn_weights_qparams_.scales_; + const int mask = pd()->attr()->rnn_weights_qparams_.mask_; + + if (is_igo) { + int nthr = mkldnn_get_max_threads(); + int LD_nthr = nstl::min(L * D, nthr); + int I_nthr = nstl::min(I, nthr / LD_nthr); + parallel(nthr, [&](const int ithr, const int nthr) { + int LD_ithr = -1, LD_s = -1, LD_e = -1; + int I_ithr = -1, I_s = -1, I_e = -1; + if (ithr < LD_nthr * I_nthr) { + LD_ithr = ithr % LD_nthr; + I_ithr = ithr / LD_nthr; + balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e); + balance211(I, I_nthr, I_ithr, I_s, I_e); + } + int32_t *comp_ithr = reduction + I_ithr * L * D * G * O; + for (int ld = LD_s; ld < LD_e; ld++) { + for (int go = 0; go < G * O; go++) + comp_ithr[ld * G * O + go] = 0; + for (int i = I_s; i < I_e; i++) { + PRAGMA_OMP_SIMD() + for (int go = 0; go < G * O; go++) { + const float s = scales[(mask == 0) ? 0 : go]; + int8_t q = qz_b0()( + input[ld * I * G * O + i * G * O + go], s); + quantized[ld * I * G * O + i * G * O + go] + = (int32_t)q; + comp_ithr[ld * G * O + go] += (int32_t)q; + } + } + } + }); + parallel_nd(L * D * G * O, + [&](int s) { comp[s] = saturate(reduction[s]); }); + for (int i = 1; i < I_nthr; i++) { + parallel_nd(L * D * G * O, [&](int s) { + comp[s] += saturate( + reduction[i * L * D * G * O + s]); + }); + } + } else { + parallel_nd(L * D, G * O, [&](int ld, int go) { + int32_t compensation = 0; + const float s = scales[(mask == 0) ? 0 : go]; + PRAGMA_OMP_SIMD() + for (int i = 0; i < I; i++) { + int8_t q = qz_b0()( + input[ld * G * O * I + go * I + i], s); + compensation += (int32_t)q; + quantized[ld * G * O * I + go * I + i] = q; + } + comp[ld * G * O + go] = saturate(compensation); + }); + } + + /* Pack */ + auto off_igo = [&](int l, int d, int i, int g, int o) { + return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; + }; + auto off_goi = [&](int l, int d, int i, int g, int o) { + return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; + }; + int n_parts = output_d.rnn_packed_desc().n_parts; + const size_t *size_packed_cell + = output_d.rnn_packed_desc().part_pack_size; + const int *parts = output_d.rnn_packed_desc().parts; + const int n = output_d.rnn_packed_desc().n; + char *to_pack = output; + for (int l = 0; l < L; l++) { + for (int d = 0; d < D; d++) { + for (int p = 0; p < n_parts; p++) { + int g = (p > 0) ? parts[p - 1] : 0; + int m_p = parts[p] * O; + int k_p = I; + cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix, + is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p, + &quantized[is_igo ? off_igo(l, d, 0, g, 0) : + off_goi(l, d, g, 0, 0)], + is_igo ? G * O : I, to_pack); + to_pack += size_packed_cell[p]; + } + } + } +#endif + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <> +struct rnn_weights_reorder_t + : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { +#if !USE_MKL_PACKED_GEMM + return status::unimplemented; +#endif + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == data_type::f32 + && od.data_type() == data_type::f32 + && od.format_kind() == format_kind::rnn_packed + && utils::one_of(od.rnn_packed_desc().format, + mkldnn_ldigo_p, mkldnn_ldgoi_p) + && attr->has_default_values(); + if (!args_ok) return status::invalid_arguments; + + format_tag_t itag = id.matches_one_of_tag( + format_tag::ldigo, format_tag::ldgoi); + if (itag == format_tag::undef) return status::invalid_arguments; + + const int mask = attr->rnn_weights_qparams_.mask_; + if (!utils::one_of(mask, 0, 3)) return status::unimplemented; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + _pd->itag_ = itag; + return safe_ptr_assign(*reorder_pd, _pd); + } + + format_tag_t itag_; + }; + +private: + rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { +#if USE_MKL_PACKED_GEMM + auto input = CTX_IN_MEM(const float *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(float *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const auto &dims = input_d.dims(); + const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc(); + const int L = dims[0]; + const int D = dims[1]; + const int I = dims[2]; + const int G = dims[3]; + const int O = dims[4]; + + /* Pack */ + bool cross_case = false + || (pd()->itag_ == format_tag::ldigo + && rnn_pdata.format == mkldnn_ldgoi_p) + || (pd()->itag_ == format_tag::ldgoi + && rnn_pdata.format == mkldnn_ldigo_p); + auto trans = cross_case ? CblasTrans : CblasNoTrans; + int n_parts = rnn_pdata.n_parts; + const size_t *size_packed_cell = rnn_pdata.part_pack_size; + const int *parts = rnn_pdata.parts; + const int n = rnn_pdata.n; + + const bool is_igo = pd()->itag_ == format_tag::ldigo; + auto off_igo = [&](int l, int d, int i, int g, int o) { + return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; + }; + auto off_goi = [&](int l, int d, int i, int g, int o) { + return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; + }; + for (int l = 0; l < L; l++) { + for (int d = 0; d < D; d++) { + for (int p = 0; p < n_parts; p++) { + int g = (p > 0) ? parts[p - 1] : 0; + int m_p = is_igo ? parts[p] * O : I; + int k_p = is_igo ? I : parts[p] * O; + int ld = is_igo ? G * O : I; + cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n, + k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) : + off_goi(l, d, 0, g, 0)], + ld, output); + output += size_packed_cell[p] / sizeof(float); + } + } + } +#endif + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp new file mode 100644 index 0000000000..1d60415cbc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp @@ -0,0 +1,426 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" +#include "rnn_utils.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace rnn_utils; +using namespace format_tag; +using namespace rnn_packed_format; +using namespace data_type; + +bool rnn_utils::is_ldigo(const memory_desc_wrapper &md) { + if (md.format_kind() != format_kind::blocked) + return false; + + auto blk = md.blocking_desc(); + auto str = blk.strides; + auto dims = md.dims(); + return md.ndims() == 5 && blk.inner_nblks == 0 && str[4] == 1 + && str[3] == dims[4] && str[1] == str[2] * dims[2] + && str[0] == str[1] * dims[1]; +}; + +bool rnn_utils::is_ldgoi(const memory_desc_wrapper &md) { + if (md.format_kind() != format_kind::blocked) + return false; + + auto blk = md.blocking_desc(); + auto str = blk.strides; + auto dims = md.dims(); + return md.ndims() == 5 && blk.inner_nblks == 0 && str[2] == 1 + && str[3] == dims[4] * str[4] && str[1] == str[3] * dims[3] + && str[0] == str[1] * dims[1]; +}; + +void rnn_utils::init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &src_layer_d, + const memory_desc_wrapper &src_iter_d, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &dst_layer_d) { + rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + rnn.is_training = utils::one_of( + rd.prop_kind, prop_kind::forward_training, prop_kind::backward); + rnn.is_lbr = rd.cell_desc.cell_kind == mkldnn_gru_linear_before_reset; + + switch (rd.direction) { + case mkldnn_unidirectional_left2right: rnn.exec_dir = l2r; break; + case mkldnn_unidirectional_right2left: rnn.exec_dir = r2l; break; + case mkldnn_bidirectional_concat: rnn.exec_dir = bi_concat; break; + case mkldnn_bidirectional_sum: rnn.exec_dir = bi_sum; break; + default: break; + } + + if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(), + weights_layer_d.data_type())) + rnn.dt_conf = all_f32; + else if (dst_layer_d.data_type() == u8) { + if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) + rnn.dt_conf = u8u8u8u8; + else + rnn.dt_conf = f32u8f32u8; + } else { + if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) + rnn.dt_conf = u8u8u8f32; + else + rnn.dt_conf = f32u8f32f32; + } + + rnn.n_layer = weights_layer_d.dims()[0]; + rnn.n_iter = src_layer_d.dims()[0]; + rnn.n_dir = weights_layer_d.dims()[1]; + rnn.n_gates = weights_layer_d.dims()[3]; + rnn.n_states = mkldnn_rnn_cell_get_states_count(&rd.cell_desc); + rnn.n_bias = rnn.n_gates + rnn.is_lbr; + rnn.mb = src_layer_d.dims()[1]; + rnn.sic = weights_iter_d.dims()[2]; + rnn.slc = weights_layer_d.dims()[2]; + rnn.dic = weights_layer_d.dims()[4]; + rnn.dlc = dst_layer_d.dims()[2]; + + rnn.gates_ld = rnn.dic * rnn.n_gates; + rnn.gates_nld = rnn.mb; + rnn.states_nld = rnn.mb; + + /* Set the correct number of weights parts */ + bool is_orig_gru = rd.cell_desc.cell_kind == alg_kind::vanilla_gru; + rnn.n_parts_weights_layer = 1; + rnn.parts_weights_layer[0] = rnn.n_gates; + rnn.parts_weights_layer[1] = 0; + + rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1; + rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates; + rnn.parts_weights_iter[1] = is_orig_gru ? 1 : 0; + + rnn.n_parts_bias = 1; + rnn.parts_bias[0] = rnn.n_bias; + rnn.parts_bias[1] = 0; + + /* Decide wich gemm implementation to use: packed/nonpacked jit/cblas + * and if to mergre gemm across iterations */ + bool is_int8 = rnn.dt_conf != all_f32; + rnn.merge_gemm_layer = ((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd) + || is_int8; + bool is_gru = utils::one_of(rd.cell_desc.cell_kind, alg_kind::vanilla_gru, + alg_kind::gru_linear_before_reset); + rnn.merge_gemm_iter = !(rnn.is_fwd || is_gru) || is_int8; + bool is_inference = !rnn.is_training; + + rnn.use_jit_gemm = !mayiuse(avx512_mic) + && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100)) + || (rnn.is_training && rnn.dic < 500)); + + /* Decide to copy bias */ + rnn.copy_bias = rnn.dt_conf != all_f32; + +#if USE_MKL_PACKED_GEMM + rnn.use_layer_packed_gemm + = (weights_layer_d.format_kind() == format_kind::any + && rnn.slc > 760 && rnn.dic > 760 && is_inference) + || is_int8; // packed gemm is the only supported option for int8 + rnn.use_iter_packed_gemm + = (weights_iter_d.format_kind() == format_kind::any && rnn.sic > 760 + && rnn.dic > 760 && is_inference) + || is_int8; +#else + rnn.use_layer_packed_gemm = false; + rnn.use_iter_packed_gemm = false; +#endif + + /* Set packed gemm sizes */ + if (rnn.use_layer_packed_gemm) { + rnn.weights_layer_pack_size = 0; + for (int p = 0; p < rnn.n_parts_weights_layer; p++) { + int m_p = rnn.is_fwd + ? (rnn.parts_weights_layer[p] * rnn.dic) + : rnn.slc; + int k_p = rnn.is_fwd + ? rnn.slc + : (rnn.parts_weights_layer[p] * rnn.dic); + int n_p = rnn.merge_gemm_layer ? rnn.mb * rnn.n_iter : rnn.mb; + +#if USE_MKL_PACKED_GEMM + if (rnn.dt_conf == all_f32) + rnn.part_weights_layer_pack_size[p] = cblas_sgemm_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); + else + rnn.part_weights_layer_pack_size[p] + = cblas_gemm_s8u8s32_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); +#else + UNUSED(m_p); + UNUSED(k_p); + UNUSED(n_p); + rnn.part_weights_layer_pack_size[p] = 0; +#endif + rnn.weights_layer_pack_size += rnn.n_layer * rnn.n_dir + * rnn.part_weights_layer_pack_size[p]; + } + rnn.weights_layer_comp_offset = rnn.weights_layer_pack_size; + rnn.weights_layer_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer + * rnn.n_dir * rnn.n_gates * rnn.dlc * sizeof(float); + } + + if (rnn.use_iter_packed_gemm) { + rnn.weights_iter_pack_size = 0; + for (int p = 0; p < rnn.n_parts_weights_iter; p++) { + int m_p = rnn.is_fwd ? (rnn.parts_weights_iter[p] * rnn.dic) : + rnn.sic; + int k_p = rnn.is_fwd ? rnn.sic : + (rnn.parts_weights_iter[p] * rnn.dic); + int n_p = rnn.merge_gemm_iter ? rnn.mb * rnn.n_iter : rnn.mb; + +#if USE_MKL_PACKED_GEMM + if (rnn.dt_conf == all_f32) + rnn.part_weights_iter_pack_size[p] = cblas_sgemm_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); + else + rnn.part_weights_iter_pack_size[p] + = cblas_gemm_s8u8s32_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); +#else + UNUSED(m_p); + UNUSED(k_p); + UNUSED(n_p); + rnn.part_weights_iter_pack_size[p] = 0; +#endif + rnn.weights_iter_pack_size += rnn.n_layer * rnn.n_dir + * rnn.part_weights_iter_pack_size[p]; + } + rnn.weights_iter_comp_offset = rnn.weights_iter_pack_size; + rnn.weights_iter_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer + * rnn.n_dir * rnn.n_gates * rnn.dic * sizeof(float); + } + +} + +void rnn_utils::set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &diff_weights_layer_d, + const memory_desc_wrapper &diff_weights_iter_d) { + + /* Set leading dimensions for input weights arrays depending on input format + */ + rnn.weights_layer_is_packed + = weights_layer_d.format_kind() == format_kind::rnn_packed; + rnn.weights_iter_is_packed + = weights_iter_d.format_kind() == format_kind::rnn_packed; + + auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) { + ld = 0; nld = 0; + if (md.is_blocking_desc()) { + if (is_ldigo(md)) { + ld = (int)md.blocking_desc().strides[2]; + nld = md.dims()[2]; + } else if (is_ldgoi(md)) { + ld = (int)md.blocking_desc().strides[4]; + nld = md.dims()[3] * md.dims()[4]; + } else + assert(!"unsupported weights format"); + } + }; + set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld); + set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld); + if (!rnn.is_fwd) { + set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld, + rnn.diff_weights_layer_nld); + set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld, + rnn.diff_weights_iter_nld); + } + + int sizeof_states_dt + = rnn.dt_conf == all_f32 ? sizeof(float) : sizeof(uint8_t); + rnn.states_ws_ld + = get_good_ld(nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dic)), + sizeof_states_dt); + rnn.gates_ws_ld = get_good_ld(rnn.gates_ld, sizeof(float)); + + /* Set workspace sizes to store: + * states to copmute a pass + * diff states to copmute bwd pass (training only) + * intermediate results from the gates + */ + rnn.use_workspace = rnn.is_training; + rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir + * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * sizeof_states_dt; + bool is_lstm = rd.cell_desc.cell_kind == mkldnn_vanilla_lstm; + rnn.ws_c_states_size = is_lstm + ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb + * rnn.states_ws_ld * sizeof(float) + : 0; + rnn.ws_diff_states_size = rnn.is_training + ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) + * (rnn.n_states + 1) * rnn.mb * rnn.states_ws_ld + * sizeof(float) + : (size_t)0; + rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb + * rnn.gates_ws_ld * sizeof(float); + + /* set other sizes */ + rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dic * sizeof(float); + rnn.ws_cell_comp_size + = rnn.is_lbr || rnn.dt_conf != all_f32 + ? (size_t) rnn.gates_nld * rnn.gates_ws_ld * sizeof(float) + : 0; + rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer + * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float); + rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic + * sizeof(float); +} + +int rnn_utils::get_good_ld(int dim, int sizeof_dt) { + // we want matrices leading dimentions to be 64-byte aligned, + // and not divisible by 256 to avoid 4K aliasing effects + int ld = rnd_up(dim, 64 / sizeof_dt); + return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld; +} + +void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, + size_t &ws_states_offset, size_t &ws_c_states_offset, + size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, + size_t &ws_cell_comp_offset, size_t &ws_bias_offset, + size_t &scratchpad_size, size_t &workspace_size) { + + const size_t page_size = 4096; // 2097152; + size_t current_offset; + /* Mandatory workspaces: go to workspace if use_workspace, scratchpad + * otherwise */ + current_offset = 0; // assumes the workspace base pointer is page aligned + ws_gates_offset = current_offset; + current_offset += rnn.ws_gates_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_states_offset = current_offset; + current_offset += rnn.ws_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_c_states_offset = current_offset; + current_offset += rnn.ws_c_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_diff_states_offset = current_offset; + current_offset += rnn.ws_diff_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_grid_comp_offset = current_offset; + current_offset += rnn.ws_grid_comp_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_cell_comp_offset = current_offset; + current_offset += rnn.ws_cell_comp_size; + + workspace_size = rnn.use_workspace ? current_offset : 0; + + /* Optional scratchpads */ + // Assumes the scratchpad base pointer is page aligned. + // If use_workspace, the following goes to scratchpad alone, + // otherwise, all goes to scratchpad and continue incrementing offset + current_offset = rnn.use_workspace ? 0 : current_offset; + + if (rnn.copy_bias) { + current_offset = utils::rnd_up(current_offset, page_size); + ws_bias_offset = current_offset; + current_offset += rnn.ws_bias_size; + } + + scratchpad_size = current_offset; +} + +void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, + size_t &scratchpad_size, size_t &workspace_size) { + size_t ws_gates_offset, ws_states_offset, ws_c_states_offset, + ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, + ws_bias_offset; + set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_diff_states_offset, + ws_c_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, + ws_bias_offset, scratchpad_size, workspace_size); +} + +status_t rnn_utils::set_good_strides( + memory_desc_t &weights_md, format_tag_t tag) { + auto &strides = weights_md.format_desc.blocking.strides; + auto dims = weights_md.dims; + + if (tag == ldigo) { + strides[2] = rnn_utils::get_good_ld((int)strides[2], + (int)types::data_type_size(weights_md.data_type)); + strides[1] = dims[2] * strides[2]; + strides[0] = dims[1] * strides[1]; + } else if (tag == ldgoi) { + strides[4] = rnn_utils::get_good_ld((int)strides[4], + (int)types::data_type_size(weights_md.data_type)); + strides[3] = dims[4] * strides[4]; + strides[1] = dims[3] * strides[3]; + strides[0] = dims[1] * strides[1]; + } else + return status::unimplemented; + + return status::success; +} + +status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn, + memory_desc_t &weights_md, bool is_iter) { + using namespace format_tag; + bool use_packed_gemm = is_iter + ? rnn.use_iter_packed_gemm + : rnn.use_layer_packed_gemm; + if (use_packed_gemm) { + weights_md.format_kind = format_kind::rnn_packed; + rnn_packed_desc_t &rnn_pdata = weights_md.format_desc.rnn_packed_desc; + rnn_pdata.format = rnn.is_fwd ? mkldnn_ldigo_p : mkldnn_ldgoi_p; + if (is_iter) { + rnn_pdata.n = rnn.mb; + rnn_pdata.n_parts = rnn.n_parts_weights_iter; + array_copy(rnn_pdata.parts, rnn.parts_weights_iter, + MKLDNN_RNN_MAX_N_PARTS); + array_copy(rnn_pdata.part_pack_size, + rnn.part_weights_iter_pack_size, MKLDNN_RNN_MAX_N_PARTS); + rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset; + rnn_pdata.size = rnn.weights_iter_pack_size; + } else { + rnn_pdata.n = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb; + rnn_pdata.n_parts = rnn.n_parts_weights_layer; + array_copy(rnn_pdata.parts, rnn.parts_weights_layer, + MKLDNN_RNN_MAX_N_PARTS); + array_copy(rnn_pdata.part_pack_size, + rnn.part_weights_layer_pack_size, MKLDNN_RNN_MAX_N_PARTS); + rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset; + rnn_pdata.size = rnn.weights_layer_pack_size; + } + } else { + CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + // Adjust strides for good leading dimension in GEMM + CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + } + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp new file mode 100644 index 0000000000..99eb787a64 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp @@ -0,0 +1,225 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef RNN_UTILS_HPP +#define RNN_UTILS_HPP + +#include "mkldnn.h" + +#include "cpu_rnn_pd.hpp" + + +#define rnn_elemwise_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, acc_data_t *ws_gates_, \ + src_data_t *states_t_l_, float *c_states_t_l_, \ + src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ + float *diff_states_t_l_, float *diff_states_t_lp1_, \ + float *diff_states_tp1_l_, float *bias_, float *ws_grid_, \ + float *ws_cell_) const + +#define rnn_cell_execution_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, src_data_t *states_t_l_, \ + float *c_states_t_l_, float *diff_states_t_l_, \ + weights_data_t **w_layer_, weights_data_t **w_iter_, \ + float **bias_, src_data_t *states_t_lm1_, \ + src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ + float *diff_states_t_lp1_, float *diff_states_tp1_l_, \ + float *diff_w_layer_, float *diff_w_iter_, float *diff_bias_, \ + acc_data_t *ws_gates_, float *ws_grid_, float *ws_cell_) const + +#define rnn_grid_execution_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, weights_data_t **weights_layer_, \ + weights_data_t **weights_states_, float **bias_, \ + src_data_t *ws_states_, float *ws_c_states_, \ + float *ws_diff_states_, acc_data_t *ws_gates_, float *ws_cell_, \ + float *ws_grid_, float *diff_weights_layer_, \ + float *diff_weights_iter_, float *diff_bias_) const + +#define rnn_gemm_sig(f) \ + void f(const char transA, const char transB, int m, int n, int k, \ + const float alpha, const weights_data_t *a_, const int ldA, \ + const src_data_t *b_, const int ldB, const float beta, \ + acc_data_t *c_, const int ldC) const + +#define rnn_bias_prepare_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \ + float *scratch_bias_) const + +#define rnn_bias_finalize_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \ + const float *w_iter_comp, const float *w_layer_comp) const + +#define rnn_weights_assign_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, int nld, \ + int ld, int OC_size, int IC_size, const int n_parts, \ + const int *gates_per_part, const size_t *part_weights_pack_size, \ + weights_data_t **weights_, const weights_data_t *w_, \ + float **bias_, const float *b_, float *scratch_bias_) const + + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace rnn_utils { + +using namespace mkldnn::impl::utils; + +enum execution_direction_t { + l2r, + r2l, + bi_concat, + bi_sum, +}; + +enum data_type_conf_t { + all_f32, + u8u8u8f32, + f32u8f32f32, + u8u8u8u8, + f32u8f32u8 +}; + +struct rnn_conf_t { + execution_direction_t exec_dir; + data_type_conf_t dt_conf; + int n_layer, n_iter, n_dir, n_gates, n_states; + int mb; + int slc, sic, dic, dlc; + int gates_ld, gates_nld, gates_ws_ld; + int n_parts_weights_layer, parts_weights_layer[MKLDNN_RNN_MAX_N_PARTS]; + int n_parts_weights_iter, parts_weights_iter[MKLDNN_RNN_MAX_N_PARTS]; + int n_bias, n_parts_bias, parts_bias[MKLDNN_RNN_MAX_N_PARTS]; + size_t part_weights_iter_pack_size[MKLDNN_RNN_MAX_N_PARTS], + part_weights_layer_pack_size[MKLDNN_RNN_MAX_N_PARTS]; + bool weights_layer_is_packed, weights_iter_is_packed; + /* Size of packed data in bytes */ + size_t weights_layer_comp_offset, weights_layer_pack_size, + weights_iter_comp_offset, weights_iter_pack_size; + + bool copy_bias; + int weights_layer_ld, weights_layer_nld; + int diff_weights_layer_ld, diff_weights_layer_nld; + int weights_iter_ld, weights_iter_nld; + int diff_weights_iter_ld, diff_weights_iter_nld; + int states_nld, states_ws_ld; + int weights_iter_compensation_size, weights_layer_compensation_size; + bool is_fwd, is_training, is_lbr; + bool use_workspace; + + /* Size of workspace for each tensor in bytes */ + size_t ws_gates_size, ws_states_size, ws_c_states_size, ws_diff_states_size, + ws_cell_comp_size, ws_grid_comp_size, ws_per_cell, ws_bias_size; + bool merge_gemm_iter, merge_gemm_layer, use_jit_gemm, use_layer_packed_gemm, + use_iter_packed_gemm; +}; + +bool is_ldigo(const memory_desc_wrapper &md); +bool is_ldgoi(const memory_desc_wrapper &md); + +int get_good_ld(int dim, int sizeof_dt); + +void init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &src_layer_d, + const memory_desc_wrapper &src_iter_d, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &dst_layer_d); + +void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &diff_weights_layer_d, + const memory_desc_wrapper &diff_weights_iter_d); + +void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, + size_t &ws_h_state_offset, size_t &ws_c_state_offset, + size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, + size_t &ws_cell_comp_offset, size_t &ws_bias_offset, + size_t &scratchpad_size, size_t &workspace_size); + +void get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, + size_t &scratchpad_size, size_t &workspace_size); +status_t set_expected_desc( + rnn_conf_t &rnn, memory_desc_t &weights_md, bool is_iter); +status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag); + +template +struct ws_gates_aoc { + ws_gates_aoc(const rnn_conf_t &rnn, T *data) + : gates_(data, rnn.gates_nld, rnn.gates_ws_ld), DIC_(rnn.dic) {} + T &operator()(int batch, int gate, int dic) { + return gates_(batch, gate * DIC_ + dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator gates_; + int DIC_; +}; +using ws_gates_aoc_t = ws_gates_aoc; +using ws_gates_aoc_s32_t = ws_gates_aoc; + +struct bias_aoc_t { + bias_aoc_t(const rnn_conf_t &rnn, const float *data) + : bias_(data, rnn.n_bias, rnn.dic) {} + const float &operator()(int bias_n, int dic) { return bias_(bias_n, dic); } + +private: + mkldnn::impl::utils::array_offset_calculator bias_; +}; + +template +struct ws_states_aoc { + ws_states_aoc(const rnn_conf_t &rnn, T *data) + : state_(data, rnn.states_nld, rnn.states_ws_ld) {} + T &operator()(int batch, int dic) { return state_(batch, dic); } + +private: + mkldnn::impl::utils::array_offset_calculator state_; +}; +using ws_states_aoc_t = ws_states_aoc; +using ws_states_aoc_u8_t = ws_states_aoc; + +struct ws_diff_states_aoc_t { + ws_diff_states_aoc_t(const rnn_conf_t &rnn, float *data) + : diff_states_(data, rnn.n_states + 1, rnn.n_iter + 1, rnn.states_nld, + rnn.states_ws_ld) {} + float &operator()(int state_n, int batch, int dic) { + return diff_states_(state_n, 0, batch, dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator diff_states_; +}; + +struct ws_diff_w_iter_aoc_t { + ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data) + : diff_weights_iter_( + data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld) + , DIC_(rnn.dic) {} + float &operator()(int sic, int gate, int dic) { + return diff_weights_iter_(sic, gate * DIC_ + dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator diff_weights_iter_; + int DIC_; +}; +} +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp new file mode 100644 index 0000000000..0420f87aa5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp @@ -0,0 +1,126 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" + +#include "simple_concat.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +template +status_t simple_concat_t::execute(const exec_ctx_t &ctx) const { + auto scratchpad = this->scratchpad(ctx); + auto iptrs = scratchpad.template get(key_concat_iptrs); + auto optrs = scratchpad.template get(key_concat_optrs); + auto nelems_to_copy = scratchpad.template get(key_concat_nelems); + auto is = scratchpad.template get(key_concat_istrides); + + const int num_arrs = pd()->n_inputs(); + const int *perm = pd()->perm_, *iperm = pd()->iperm_; + const int concat_dim = pd()->concat_dim(); + auto o_base_ptr = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + for (int a = 0; a < num_arrs; ++a) { + const memory_desc_wrapper i_d(pd()->src_md(a)); + const memory_desc_wrapper o_d(pd()->src_image_md(a)); + + iptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a) + + i_d.blk_off(0); + optrs[a] = o_base_ptr + o_d.blk_off(0); + nelems_to_copy[a] = pd()->nelems_to_concat(i_d); + for (int i = 0; i < MKLDNN_MAX_NDIMS; i++) { + if (i < perm[concat_dim]) + is[a][i] = size_t(i_d.blocking_desc().strides[iperm[i]]); + else + is[a][i] = 0; + } + } + + const memory_desc_wrapper o_d(pd()->src_image_md(0)); + + strides_t os = { 0 }; + for (int i = 0; i < perm[concat_dim]; i++) + os[i] = o_d.blocking_desc().strides[iperm[i]]; + + dims_t phys_dims; + for (size_t i = 0; i < sizeof(phys_dims)/sizeof(phys_dims[0]); i++) + phys_dims[i] = (i < (size_t)perm[concat_dim]) + ? o_d.dims()[iperm[i]] / pd()->blocks_[iperm[i]] : 1; + + if (perm[concat_dim] == 0) { + for (int a = 0; a < num_arrs; ++a) { + const data_t *i = &iptrs[a][0]; + data_t *o = &optrs[a][0]; + parallel_nd((ptrdiff_t)nelems_to_copy[a], + [&](ptrdiff_t e) { o[e] = i[e]; }); + } + } else { + parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3], + phys_dims[4], num_arrs, + [&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, int a) { + // XXX: this code may access uninitialized values in is[*][0-4] -- + // that's why we have to set them to zero although this is + // probably benign + size_t in_off = is[a][0] * n0 + is[a][1] * n1 + is[a][2] * n2 + + is[a][3] * n3 + is[a][4] * n4; + size_t out_off = os[0] * n0 + os[1] * n1 + os[2] * n2 + + os[3] * n3 + os[4] * n4; + const data_t *i = &iptrs[a][in_off]; + data_t *o = &optrs[a][out_off]; +#if defined(__GNUC__) && !defined(__INTEL_COMPILER) + // The code below performs data copying: o[e] = i[e] + // and uses a workaround to make GNU compilers optimize it + uint8_t *ptro = reinterpret_cast(o); + const uint8_t *ptri = reinterpret_cast(i); + const dim_t main_part = + nelems_to_copy[a] * sizeof(data_t) / sizeof(uint32_t); + const dim_t tail_part = + nelems_to_copy[a] % sizeof(data_t) / sizeof(uint32_t); + + PRAGMA_OMP_SIMD() + for (dim_t e = 0; e < main_part; ++e) { + *(reinterpret_cast(ptro)) + = *(reinterpret_cast(ptri)); + ptro += sizeof(uint32_t); + ptri += sizeof(uint32_t); + } + for (dim_t e = 0; e < tail_part; ++e) { + *ptro = *ptri; + ++ptro; + ++ptri; + } +#else + PRAGMA_OMP_SIMD() + for (dim_t e = 0; e < nelems_to_copy[a]; ++e) o[e] = i[e]; +#endif + }); + } + + return status::success; +} + +template struct simple_concat_t; +template struct simple_concat_t; +template struct simple_concat_t; +template struct simple_concat_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp new file mode 100644 index 0000000000..5177275452 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp @@ -0,0 +1,155 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SIMPLE_CONCAT_HPP +#define SIMPLE_CONCAT_HPP + +#include "memory_tracking.hpp" + +#include "cpu_concat_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct simple_concat_t: public cpu_primitive_t { + struct pd_t: public cpu_concat_pd_t { + using cpu_concat_pd_t::cpu_concat_pd_t; + + pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) { + int ndims = rhs.dst_md_.ndims; + utils::array_copy(perm_, rhs.perm_, ndims); + utils::array_copy(iperm_, rhs.iperm_, ndims); + utils::array_copy(blocks_, rhs.blocks_, ndims); + } + + DECLARE_CONCAT_PD_T("simple:any", simple_concat_t); + + status_t init() { + const memory_desc_wrapper dst_d(dst_md()); + bool ok = true + && cpu_concat_pd_t::init() == status::success + && dst_d.ndims() <= 6; + if (!ok) return status::unimplemented; + + for (size_t i = 0; i < src_mds_.size(); ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + const memory_desc_wrapper o_d(&src_image_mds_[i]); + + const int ignore_strides = 0; + + ok = ok + && utils::everyone_is(data_type, i_d.data_type(), + o_d.data_type()) + && utils::everyone_is(format_kind::blocked, + i_d.format_kind(), o_d.format_kind()) + && types::blocking_desc_is_equal(i_d.blocking_desc(), + o_d.blocking_desc(), ignore_strides) + && types::blocking_desc_is_equal(i_d.blocking_desc(), + dst_d.blocking_desc(), ignore_strides) + && !i_d.is_additional_buffer(); + if (!ok) return status::unimplemented; + } + + dst_d.compute_blocks(blocks_); + format_perm(); + + // start dim is the first dimension after which the concatenation + // would happen contiguously + const int start_dim = perm_[concat_dim()]; + + // check that contiguous part is indeed contiguous (i.e. dense) + if (nelems_to_concat(dst_d) != + dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()] + * dst_d.blocking_desc().strides[concat_dim()]) + return status::unimplemented; + + // check that all inputs have the same strides for the + // contiguous part [concat_dim .. ndims] for the *major* dims. + // the block part is already checked above + for (size_t i = 0; i < src_mds_.size(); ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + for (int d = start_dim; d < dst_d.ndims(); ++d) { + if (dst_d.blocking_desc().strides[iperm_[d]] + != i_d.blocking_desc().strides[iperm_[d]]) + return status::unimplemented; + } + } + + init_scratchpad(); + + return status::success; + } + + int perm_[MKLDNN_MAX_NDIMS]; + int iperm_[MKLDNN_MAX_NDIMS]; + dims_t blocks_; + + dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const { + const int ndims = data_d.ndims(); + + dim_t nelems = 1; + for (int i = perm_[concat_dim()]; i < ndims; i++) + nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]]; + for (int i = 0; i < ndims; i++) + nelems *= blocks_[i]; + + return nelems; + } + + private: + void format_perm() { + const memory_desc_wrapper dst_d(dst_md()); + const int ndims = dst_d.ndims(); + + strides_t strides; + utils::array_copy(strides, dst_d.blocking_desc().strides, ndims); + for (int i = 0; i < ndims; i++) iperm_[i] = i; + + utils::simultaneous_sort(strides, iperm_, ndims, + [](stride_t a, stride_t b) { return b - a; }); + + for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i; + } + + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs()); + scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs()); + scratchpad.book(key_concat_nelems, sizeof(dim_t) * n_inputs()); + scratchpad.book(key_concat_istrides, + sizeof(strides_t) * n_inputs()); + } + }; + + simple_concat_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override; + + typedef typename prec_traits::type data_t; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp new file mode 100644 index 0000000000..e6c3b8d7af --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp @@ -0,0 +1,98 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SIMPLE_Q10N_HPP +#define CPU_SIMPLE_Q10N_HPP + +#include + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::math; + +template +inline out_t round_and_saturate(float f) +{ return math::saturate(out_round(f)); } + +/* Quantization with alpha == 1 and beta == 0 */ +template +struct qz_a1b0 { + out_t operator()(in_t in) + { return round_and_saturate((float)in); } +}; + +template +struct qz_a1b0::value + && !is_subset::value + >::type> { + out_t operator()(in_t in) { return math::saturate(in); } +}; + +template +struct qz_a1b0::value>::type> { + out_t operator()(in_t in) { return (out_t)in; } +}; + +/* Quantization with alpha == 1 */ +template struct qz_a1 { + out_t operator()(in_t in, out_t out, float beta) + { return round_and_saturate((float)in + beta * out); } +}; + +template struct qz_a1 { + float operator()(in_t in, float out, float beta) + { return (float)in + beta * out; } +}; + +/* Quantization with beta == 0 */ +template struct qz_b0 { + out_t operator()(in_t in, float alpha) + { return round_and_saturate(alpha * in); } +}; + +template struct qz_b0 { + float operator()(in_t in, float alpha) { return alpha * in; } +}; + +/* Quantization */ +template struct qz { + out_t operator()(in_t in, out_t out, float alpha, float beta) { + return round_and_saturate( + alpha * in + (beta ? beta * out : 0)); + } +}; + +template struct qz { + float operator()(in_t in, float out, float alpha, float beta) + { return alpha * in + (beta ? beta * out : 0); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp new file mode 100644 index 0000000000..ff845f5bd3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp @@ -0,0 +1,1022 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SIMPLE_REORDER_HPP +#define CPU_SIMPLE_REORDER_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "tag_traits.hpp" +#include "cpu_reorder_pd.hpp" +#include "cpu_primitive.hpp" + +#include "simple_q10n.hpp" +#include "cpu_isa_traits.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::data_type; + +using bd = block_dim_t; +using ib = inner_blk_t; + +using namespace mkldnn::impl::utils; +using math::saturate; + +template +using data_t = typename prec_traits::type; + +template +using _qz_a1b0 = qz_a1b0, data_t>; + +template +using _qz = qz, data_t>; + +namespace fmt_order { + const bool keep = true; + const bool reverse = false; + const bool any = keep; +} + +namespace spec { +struct direct_copy {}; +struct direct_copy_except_dim_0 {}; +struct reference {}; +struct conv_s8s8 {}; +} + +#define SIMPLE_REORDER_TEMPL_DECL \ + impl::data_type_t type_i, impl::format_tag_t tag_i, \ + impl::data_type_t type_o, impl::format_tag_t tag_o, bool order_keep +#define SIMPLE_REORDER_TEMPL_CALL \ + type_i, tag_i, type_o, tag_o, order_keep + +#define DECLARE_COMMON_PARAMS() \ + const memory_desc_wrapper &input_d = pd->src_md(); \ + const memory_desc_wrapper &output_d = pd->dst_md(); \ + const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \ + const float beta = pd->beta(); MAYBE_UNUSED(beta); + +/* specific reorders: common template */ +template +struct simple_reorder_impl {}; + +namespace { +inline bool simple_fmt_check(bool order_keep, impl::format_tag_t tag_i, + impl::format_tag_t tag_o, const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d) { + return input_d.matches_tag(order_keep ? tag_i : tag_o) + && output_d.matches_tag(order_keep ? tag_o : tag_i); +} +inline bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support) { + if (many_scales_support) + return true; + return IMPLICATION(attr, attr->output_scales_.mask_ == 0); +} +} + +/* specific reorders: implementation */ +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const int oc = (input_d.dims()[tag_o == hwigo + 0]); + const int g = (tag_o == hwigo) ? (input_d.dims()[0]) : 1; + + return output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + static constexpr bool w_groups = tag_o == hwigo; + + const auto &dims = input_d.dims(); + const auto &pdims = output_d.padded_dims(); + + const int G = w_groups ? dims[0] : 1; + const int OC = dims[w_groups + 0]; + const int IC = dims[w_groups + 1]; + const int H = dims[w_groups + 2]; + const int W = dims[w_groups + 3]; + + const float *scales = pd->attr()->output_scales_.scales_; + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? output_d.extra().scale_adjust : 1.f; + + size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * H * W; + int32_t *cp = reinterpret_cast(output + offset); + + parallel_nd(G, OC, [&](int g, int oc) { + cp[g * OC + oc] = 0; + for (int ic = 0; ic < IC; ic++) + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) { + auto i = input[input_d.blk_off(g, oc, ic, h, w)]; + auto &o = output[output_d.blk_off(g, oc, ic, h, w)]; + const float s = scales[(D_mask == 1) ? 0 : g * OC + oc]; + + o = qz_b0, data_t>()( + i, s * adj_scale); + cp[g * OC + oc] -= (int32_t)o; + } + cp [g * OC + oc] *= 128; + }); + return success; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const bool w_groups = !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i); + const int oc = (input_d.dims()[w_groups ? 1 : 0]); + const int g = w_groups ? input_d.dims()[0] : 1; + + return input_d.matches_tag(tag_i) + && output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + static constexpr bool w_groups = + !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i); + constexpr int is_1d = + utils::one_of(tag_o, gOIw4i16o4i, OIw4i16o4i); + constexpr int blksize = tag_traits::inner_blks == ib::_4b4c + ? 4 + : tag_traits::inner_blks == ib::_2c8b4c + ? 8 + : 16; + + const auto &_g_oihw_d = order_keep ? input_d : output_d; + const auto &dims = input_d.dims(); + const auto &pdims = order_keep + ? output_d.padded_dims() + : input_d.padded_dims(); + + const int G = w_groups ? dims[0] : 1; + const int OC = dims[w_groups + 0]; + const int NB_OC = pdims[w_groups + 0] / blksize; + const int IC = dims[w_groups + 1]; + const int NB_IC = pdims[w_groups + 1] / blksize; + const int H = is_1d ? 1 : dims[w_groups + 2]; + const int W = dims[w_groups + 3 - is_1d]; + + const float *scales = pd->attr()->output_scales_.scales_; + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? output_d.extra().scale_adjust : 1.f; + + auto ker = [&](const data_t *inp, data_t *out, + int32_t *c, const float *s, const int oc_block, const int ic_block) { +# define index AB_or_BC_blk_off::inner_blks> + + for (int ic = 0; ic < ic_block; ++ic) { + for (int oc = 0; oc < oc_block; ++oc) { + const auto _g_oihw_off = + oc * _g_oihw_d.blocking_desc().strides[w_groups + 0] + + ic * _g_oihw_d.blocking_desc().strides[w_groups + 1]; + out[index(oc, ic)] + = qz_b0, data_t>()( + inp[_g_oihw_off], s[oc] * adj_scale); + c[oc] -= (128 * (int32_t)(out[index(oc, ic)])); + } + } +# undef index + }; + + constexpr int i_mult = blksize; + constexpr int o_mult = 1; + + size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W; + int32_t *cp = reinterpret_cast(output + offset); + parallel_nd(G * NB_OC * blksize, [&](int i) { + cp[i] = 0; + }); + +# define wei_blk_off(md, g, o, i, h, w) \ + (is_1d ? (md).blk_off(g, o, i, w) \ + : (md).blk_off(g, o, i, h, w)) + + parallel_nd(G, NB_OC, [&](int g, int O) { + for (int I = 0; I < NB_IC; I++) + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) { + auto i = &input[wei_blk_off( + input_d, g, i_mult * O, i_mult * I, h, w)]; + auto o = &output[wei_blk_off( + output_d, g, o_mult * O, o_mult * I, h, w)]; + const int oc_block = nstl::min(blksize, OC - O * blksize); + const int ic_block = nstl::min(blksize, IC - I * blksize); + + int _offset = (g * NB_OC + O) * blksize; + ker(i, o, (order_keep) ? &cp[_offset] : nullptr, + &scales[(D_mask == 1) ? 0 : _offset], + oc_block, ic_block); + } + }); + +# undef wei_blk_off + + return success; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const int oc = input_d.dims()[1]; + const int g = input_d.dims()[0]; + + return true + && order_keep + && input_d.matches_tag(tag_i) + && output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + constexpr bool is_1d = tag_i == goiw; + constexpr int blksize = 16; + + const auto &dims = input_d.dims(); + const auto &pdims = output_d.padded_dims(); + const int G = dims[0]; + const int Gp = pdims[0]; + const int OC = dims[1]; + const int IC = dims[2]; + const int H = is_1d ? 1 : dims[3]; + const int W = dims[4 - is_1d]; + + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + const float *scales = pd->attr()->output_scales_.scales_; + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? output_d.extra().scale_adjust : 1.f; + + auto ker = [&](const data_t *inp, data_t *out, + int32_t *cp, const float *s, const int g_block) { + PRAGMA_OMP_SIMD() + for (int g = 0; g < g_block; g++) { + const auto i_off = g * input_d.blocking_desc().strides[0]; + out[g] = qz_b0, data_t>()( + inp[i_off], s[g * OC] * adj_scale); + cp[g * OC] -= 128 * (int32_t)(out[g]); + } + }; + + size_t cp_offset = output_d.size() - output_d.additional_buffer_size(); + int32_t *cp = reinterpret_cast(output + cp_offset); + parallel_nd((Gp/blksize) * OC, [&](int ib) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < blksize; i++) + cp[ib * blksize + i] = 0; + }); + +# define wei_blk_off(md, g, o, i, h, w) \ + (is_1d ? (md).blk_off(g, o, i, w) : (md).blk_off(g, o, i, h, w)) + + parallel_nd(Gp/blksize, OC, [&](int gb, int O) { + for (int I = 0; I < IC; I++) { + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) + { + const int g_block = nstl::min(G - gb * blksize, blksize); + const auto inp = &input[wei_blk_off( + input_d, gb * blksize, O, I, h, w)]; + const auto out = &output[wei_blk_off( + output_d, gb, O, I, h, w)]; + int offset = gb * blksize + O; + ker(inp, out, &cp[offset], + &scales[(D_mask == 1) ? 0 : offset], g_block); + } + } + }); + +# undef wei_blk_off + + return success; + } +}; + +/* reorders with tail support */ + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d) + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + constexpr int is_1d = tag_i == nCw8c; + constexpr int is_3d = tag_i == nCdhw8c; + constexpr int blksize_16 = 16; + constexpr int blksize_8 = 8; + constexpr int ic_mult = order_keep ? 2 : 1; + constexpr int oc_mult = order_keep ? 1 : 2; + + const auto &dims = input_d.dims(); + const auto &pdims = order_keep ? output_d.padded_dims() + : input_d.padded_dims(); + + const int C = dims[1]; + const int D = is_3d ? dims[2] : 1; + const int H = is_1d ? 1 : dims[2 + is_3d]; + const int W = dims[3 + is_3d - is_1d]; + + auto ker = [&](const data_t *i, data_t *o, + const int block_16) { + const int nb = (block_16 - 1) / blksize_8 + 1; + if (alpha == 1.0 && beta == 0.0) { + for (int b = 0; b < nb; ++b) { + const ptrdiff_t i_off = order_keep ? b : b * blksize_8; + const ptrdiff_t o_off = order_keep ? b * blksize_8 : b; + const int block_8 = nstl::min(blksize_8, + block_16 - b * blksize_8); + for (int c = 0; c < block_8; ++c) { + o[o_off + c] = _qz_a1b0()( + i[i_off + c]); + } + } + } else { + for (int b = 0; b < nb; ++b) { + const ptrdiff_t i_off = order_keep ? b : b * blksize_8; + const ptrdiff_t o_off = order_keep ? b * blksize_8 : b; + const int block_8 = nstl::min(blksize_8, + block_16 - b * blksize_8); + for (int c = 0; c < block_8; ++c) { + o[o_off + c] = _qz()(i[i_off + c], + o[o_off + c], alpha, beta); + } + } + } + }; + +# define data_blk_off(md, n, c, d, h, w) \ + ( is_1d ? (md).blk_off(n, c, w) \ + : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w)) + + parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W, + [&](int n, int nb_c, int d, int h, int w) { + auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)]; + auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)]; + const int block_16 = nstl::min(blksize_16, C - nb_c * blksize_16); + ker(i, o, block_16); + }); + +# undef data_blk_off + + return success; + } +}; + +#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \ + static bool is_applicable(const memory_desc_wrapper &input_d, \ + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \ + return simple_attr_check(attr, false) && (order_keep \ + ? output_d.matches_tag(tag_o) && input_d.is_plain() \ + : input_d.matches_tag(tag_o) && output_d.is_plain()); \ + } + +template +struct simple_reorder_impl::block_dims == bd::_A + || tag_traits::block_dims == bd::_B) + && tag_traits::ndims >= 3 + && tag_traits::ndims <= 6 + >::type> +{ + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + const auto &flat_d = order_keep ? input_d : output_d; + const auto &block_d = order_keep ? output_d : input_d; + const auto &dims = input_d.dims(); + const auto &pdims = block_d.padded_dims(); + + constexpr int ndims = tag_traits::ndims; + constexpr int blk_idx = tag_traits::block_dims == bd::_A ? 0 : 1; + + const dim_t H0 = dims[0]; + const dim_t H1 = dims[1]; + const dim_t M0 = ndims >= 6 ? dims[ndims - 4] : 1; + const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1; + const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1; + const dim_t L = dims[ndims - 1]; + const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1]; + + constexpr int blksize = false ? 0 + : utils::one_of(tag_traits::inner_blks, ib::_4a, ib::_4b) ? 4 + : utils::one_of(tag_traits::inner_blks, ib::_8a, ib::_8b) ? 8 + : 16; + + auto ker = [&](const data_t *i, data_t *o, int block) { + if (alpha == 1.0 && beta == 0.0) { + for (int l = 0; l < L; ++l) + for (int blk = 0; blk < block; ++blk) { + const dim_t flat_off = 0 + + blk * flat_d.blocking_desc().strides[blk_idx] + + l * flat_d.blocking_desc().strides[ndims - 1]; + if (order_keep) { + o[l * l_blk_stride + blk] = _qz_a1b0()( + i[flat_off]); + } else { + o[flat_off] = _qz_a1b0()( + i[l * l_blk_stride + blk]); + } + } + } else { + for (int l = 0; l < L; ++l) + for (int blk = 0; blk < block; ++blk) { + const dim_t flat_off = 0 + + blk * flat_d.blocking_desc().strides[blk_idx] + + l * flat_d.blocking_desc().strides[ndims - 1]; + if (order_keep) { + o[l * l_blk_stride + blk] = _qz()( + i[flat_off], o[l * blksize + blk], + alpha, beta); + } else { + o[flat_off] = _qz()( + i[l * l_blk_stride + blk], o[flat_off], + alpha, beta); + } + } + } + }; + +# define off(md, h0, h1, m0, m1, m2) \ + (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \ + : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \ + : ndims >= 4 ? (md).blk_off(h0, h1, m2) \ + : /* ndims >= 3 ? */ (md).blk_off(h0, h1)) + + constexpr int i_mult = order_keep ? blksize : 1; + constexpr int o_mult = order_keep ? 1 : blksize; + + if (blk_idx == 0) { + const dim_t BH0 = pdims[0] / blksize; + parallel_nd(BH0, H1, M0, M1, M2, + [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, bh0 * i_mult, h1, m0, m1, m2)]; + auto o = &output[off(output_d, bh0 * o_mult, h1, m0, m1, m2)]; + const int block = nstl::min(blksize, H0 - bh0 * blksize); + ker(i, o, block); + }); + } else if (blk_idx == 1) { + const dim_t BH1 = pdims[1] / blksize; + parallel_nd(H0, BH1, M0, M1, M2, + [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, h0, bh1 * i_mult, m0, m1, m2)]; + auto o = &output[off(output_d, h0, bh1 * o_mult, m0, m1, m2)]; + const int block = nstl::min(blksize, H1 - bh1 * blksize); + ker(i, o, block); + }); + } else { + assert(!"unimplemented"); + } + +# undef off + + return success; + } +}; + +template +struct simple_reorder_impl::block_dims == bd::_AB + || tag_traits::block_dims == bd::_BC) + && IMPLICATION(tag_traits::block_dims == bd::_AB, + tag_traits::ndims >= 3 && tag_traits::ndims <= 5) + && IMPLICATION(tag_traits::block_dims == bd::_BC, + tag_traits::ndims >= 4 && tag_traits::ndims <= 6) + >::type> +{ + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + const auto &flat_d = order_keep ? input_d : output_d; + const auto &dims = input_d.dims(); + const auto &pdims = order_keep + ? output_d.padded_dims() + : input_d.padded_dims(); + + constexpr int ndims = tag_traits::ndims; + + static constexpr bool with_g = tag_traits::block_dims == bd::_BC; + const dim_t G = with_g ? dims[0] : 1; + + const dim_t H0 = dims[0 + with_g]; + const dim_t H1 = dims[1 + with_g]; + + const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1; + const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1; + const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1; + + constexpr int blksize_0 = false ? 0 + : utils::one_of(tag_traits::inner_blks, + ib::_4b4a, ib::_4b4c, ib::_4c4b) + ? 4 + : utils::one_of(tag_traits::inner_blks, + ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c) + ? 8 + : utils::one_of(tag_traits::inner_blks, + ib::_16a16b, ib::_16a4b, ib::_16b16a, ib::_16b4c, + ib::_16b16c, ib::_16c16b, ib::_8a16b2a, ib::_4b16a4b, + ib::_8b16a2b, ib::_8b16c2b, ib::_4c16b4c, ib::_8c16b2c) + ? 16 : INT_MIN; + + constexpr int blksize_1 = utils::one_of(tag_traits::inner_blks, + ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c) + ? 8 + : utils::one_of(tag_traits::inner_blks, + ib::_16a16b, ib::_16b16a, ib::_16b16c, ib::_16c16b, + ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, + ib::_4c16b4c, ib::_8c16b2c) + ? 16 + : utils::one_of(tag_traits::inner_blks, + ib::_4b4a, ib::_4b4c, ib::_4c4b, + ib::_16a4b, ib::_16b4c) + ? 4 + : INT_MIN; + + const dim_t NB_H0 = pdims[0 + with_g] / blksize_0; + const dim_t NB_H1 = pdims[1 + with_g] / blksize_1; + + auto ker = [&](const data_t *i, data_t *o, + const int block_h0, const int block_h1) { +# define blk_off AB_or_BC_blk_off::inner_blks> + + if (alpha == 1.0 && beta == 0.0) { + for (int h0 = 0; h0 < block_h0; ++h0) + for (int h1 = 0; h1 < block_h1; ++h1) { + const dim_t flat_off = 0 + + h0 * flat_d.blocking_desc().strides[with_g + 0] + + h1 * flat_d.blocking_desc().strides[with_g + 1]; + if (order_keep) { + o[blk_off(h0, h1)] = _qz_a1b0()( + i[flat_off]); + } else { + o[flat_off] = _qz_a1b0()( + i[blk_off(h0, h1)]); + } + } + } else { + for (int h0 = 0; h0 < block_h0; ++h0) + for (int h1 = 0; h1 < block_h1; ++h1) { + const dim_t flat_off = 0 + + h0 * flat_d.blocking_desc().strides[with_g + 0] + + h1 * flat_d.blocking_desc().strides[with_g + 1]; + if (order_keep) { + o[blk_off(h0, h1)] = _qz()(i[flat_off], + o[blk_off(h0, h1)], alpha, beta); + } else { + o[flat_off] = _qz()(i[blk_off(h0, h1)], + o[flat_off], alpha, beta); + } + } + } + +# undef blk_off + }; + + constexpr int i_mult_0 = order_keep ? blksize_0 : 1; + constexpr int o_mult_0 = order_keep ? 1 : blksize_0; + + constexpr int i_mult_1 = order_keep ? blksize_1 : 1; + constexpr int o_mult_1 = order_keep ? 1 : blksize_1; + +# define off(md, g, h0, h1, m0, m1, m2) \ + (ndims >= 5 + with_g ? (md).blk_off(g, h0, h1, m0, m1, m2) \ + : ndims >= 4 + with_g ? (md).blk_off(g, h0, h1, m1, m2) \ + : /* ndims >= 3 + with_g ? */ (md).blk_off(g, h0, h1, m2)) + + parallel_nd(G, NB_H0, NB_H1, M0, M1, M2, + [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, + g, i_mult_0 * nb_h0, i_mult_1 * nb_h1, m0, m1, m2)]; + auto o = &output[off(output_d, + g, o_mult_0 * nb_h0, o_mult_1 * nb_h1, m0, m1, m2)]; + const int block_h0 = nstl::min(blksize_0, H0 - nb_h0 * blksize_0); + const int block_h1 = nstl::min(blksize_1, H1 - nb_h1 * blksize_1); + ker(i, o, block_h0, block_h1); + }); + +# undef off + + return success; + } +}; + +/* generic and direct-copy reorders */ + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + /* FIXME: is the formula correct? */ + return input_d.similar_to(output_d, true, false, 0) + && input_d.is_dense() && output_d.is_dense() + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + assert(input_d.is_dense()); + + input += input_d.blk_off(0); + output += output_d.blk_off(0); + + const size_t nelems = input_d.nelems(); + + constexpr int block_size = 16; + const auto num_blocks = nelems / block_size; + const auto rem_elems = nelems % block_size; + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(num_blocks, nthr, ithr, start, end); + start = start * block_size; + end = end * block_size; + + if (alpha == 1.0 && beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_a1b0, data_t>() + (input[e]); + } + } else if (alpha == 1.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_a1, data_t>() + (input[e], output[e], beta); + } + } else if (beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_b0, data_t>() + (input[e], alpha); + } + } else { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz, data_t>() + (input[e], output[e], alpha, beta); + } + } + + if (rem_elems != 0 && ithr == nthr - 1){ + if (alpha == 1.0 && beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_a1b0, + data_t>()(input[e]); + } + } else if (alpha == 1.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_a1, + data_t>()(input[e], output[e], beta); + } + } else if (beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_b0, + data_t>()(input[e], alpha); + } + } else { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz, data_t>() + (input[e], output[e], alpha, beta); + } + } + } + }); + return success; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) { + return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d); + }; + /* FIXME: is the formula correct? */ + return input_d.similar_to(output_d, true, false, 1) + && is_dense_no_0(input_d) && is_dense_no_0(output_d) + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + input += input_d.blk_off(0); + output += output_d.blk_off(0); + + const int N = input_d.dims()[0]; + const dim_t is = input_d.blocking_desc().strides[0]; + const dim_t os = output_d.blocking_desc().strides[0]; + const dim_t nelems_no_d0 = nelems_no_dim_0(input_d); + const dim_t work_amount = N * nelems_no_d0; + + if (alpha == 1.0 && beta == 0.0) { + parallel(0, [&](const int ithr, const int nthr) { + dim_t n{0}, dim1_s{0}; + dim_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, N, dim1_s, nelems_no_d0); + while(start < end) { + dim_t work_rem = end - start; + dim_t dim1_e = dim1_s + work_rem > nelems_no_d0 + ? nelems_no_d0 : dim1_s + work_rem; + PRAGMA_OMP_SIMD() + for (dim_t e = dim1_s; e < dim1_e; ++e) { + output[os * n + e] = _qz_a1b0()( + input[is * n + e]); + } + nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0); + } + }); + } else { + parallel(0, [&](const int ithr, const int nthr) { + dim_t n{0}, dim1_s{0}; + dim_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, N, dim1_s, nelems_no_d0); + while(start < end) { + dim_t work_rem = end - start; + dim_t dim1_e = + dim1_s + work_rem > nelems_no_d0 ? nelems_no_d0 + : dim1_s + work_rem; + PRAGMA_OMP_SIMD() + for (dim_t e = dim1_s; e < dim1_e; ++e){ + output[os * n + e] = _qz()( + input[is * n + e], output[os * n + e], alpha, + beta); + } + nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0); + } + }); + } + + return success; + } + +private: + static dim_t nelems_no_dim_0(const memory_desc_wrapper &data_d) { + const int ndims = data_d.ndims(); + if (ndims <= 1) return 1; + return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1); + } + + static dim_t _size_no_dim_0(const memory_desc_wrapper &data_d) { + dims_t blocks; + data_d.compute_blocks(blocks); + + const auto &blk = data_d.blocking_desc(); + + dim_t blk_size = 1; + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) + blk_size *= blk.inner_blks[iblk]; + + dim_t max_size = blk_size; + for (int d = 1; d < data_d.ndims(); ++d) { + max_size = nstl::max(max_size, + data_d.padded_dims()[d] / blocks[d] * blk.strides[d]); + } + + return max_size; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + /* supported smask: 0x0...011..10...0, + * i.e. 1 should be contiguous */ + int smask = attr ? attr->output_scales_.mask_ : 0; + for (; smask > 0 && !(smask & 0x1); smask >>= 1); + for (; smask > 0 && smask & 0x1; smask >>= 1); + return true + && input_d.is_blocking_desc() + && output_d.is_blocking_desc() + && !output_d.is_additional_buffer() + && !input_d.is_additional_buffer() + && smask == 0; + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + const size_t nelems = input_d.nelems(); + + int ndims_start = 0, ndims_mask = 0; + int smask = pd->attr()->output_scales_.mask_; + for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start; + for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask; + assert(smask == 0); + + const ptrdiff_t D_start + = utils::array_product(input_d.dims(), ndims_start); + const ptrdiff_t D_mask + = utils::array_product(input_d.dims() + ndims_start, ndims_mask); + const ptrdiff_t D_rest = nelems / D_start / D_mask; + + const float *scales = pd->attr()->output_scales_.scales_; + + parallel_nd(D_start, D_mask, D_rest, + [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) { + const float scale = scales[dm]; + + const size_t e = (ds * D_mask + dm) * D_rest + dr; + const auto &i = input[input_d.off_l(e)]; + auto &o = output[output_d.off_l(e)]; + + o = _qz()(i, o, scale, beta); + }); + + return success; + } +}; + + +/* high level class declaration */ + +template +struct simple_reorder_t: public cpu_primitive_t { + struct pd_t: public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("simple:any", simple_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + bool args_ok = true + && src_md->data_type == type_i + && dst_md->data_type == type_o + && simple_reorder_impl:: + is_applicable(src_md, dst_md, attr); + if (!args_ok) + return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + return safe_ptr_assign(*reorder_pd, _pd); + } + }; + + simple_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(data_t *, MKLDNN_ARG_TO); + simple_reorder_impl::execute( + pd(), input, output); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +#undef SIMPLE_REORDER_TEMPL_DECL +#undef SIMPLE_REORDER_TEMPL_CALL + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp new file mode 100644 index 0000000000..f0947573a9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" + +#include "simple_sum.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +status_t simple_sum_t::execute(const exec_ctx_t &ctx) const { + auto output = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper o_d(pd()->dst_md()); + output += o_d.blk_off(0); + + const int num_arrs = pd()->n_inputs(); + const data_t *input_ptrs[max_num_arrs]; + const size_t nelems = o_d.nelems(); + + for (int a = 0; a < num_arrs; ++a) { + const memory_desc_wrapper i_d(pd()->src_md(a)); + input_ptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a) + + i_d.blk_off(0); + } + + const size_t block_size = 16 * 1024 / sizeof(data_type); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + + const auto scales = pd()->scales(); + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = data_t(scales[0] * input_ptrs[0][e]); + } + for (int a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += data_t(scales[a] * input_ptrs[a][e]); + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = data_t(scales[0] * input_ptrs[0][e]); + } + for (int a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += data_t(scales[a] * input_ptrs[a][e]); + } + } + } + }); + + return status::success; +} + +template struct simple_sum_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp new file mode 100644 index 0000000000..2a0187a184 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SIMPLE_SUM_HPP +#define SIMPLE_SUM_HPP + +#include "cpu_sum_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct simple_sum_t: public cpu_primitive_t { + struct pd_t: public cpu_sum_pd_t { + using cpu_sum_pd_t::cpu_sum_pd_t; + + DECLARE_SUM_PD_T("simple:any", simple_sum_t); + + status_t init() { + const int n = n_inputs(); + + bool ok = true + && cpu_sum_pd_t::init() == status::success + && n <= max_num_arrs; + if (!ok) return status::unimplemented; + + const memory_desc_wrapper o_d(dst_md()); + ok = ok + && o_d.data_type() == data_type + && o_d.is_dense(); + if (!ok) return status::unimplemented; + + for (int i = 0; i < n; ++i) { + const memory_desc_wrapper i_d(src_md(i)); + if (i_d != o_d) return status::unimplemented; + } + + return status::success; + } + }; + + simple_sum_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override; + + enum {max_num_arrs = 16 }; + typedef typename prec_traits::type data_t; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp new file mode 100644 index 0000000000..c2082d7d62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp @@ -0,0 +1,376 @@ +/******************************************************************************* + * Copyright 2017-2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#ifndef CPU_WINO_REORDER_HPP +#define CPU_WINO_REORDER_HPP + +#include "mkldnn_thread.hpp" + +#include "simple_q10n.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct wino_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("wino_reorder", wino_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && id.matches_tag(utils::pick(id.ndims() - 4, + format_tag::oihw, format_tag::goihw)) + && od.format_kind() == format_kind::wino + && utils::one_of(od.wino_desc().wino_format, + mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio, + mkldnn_wino_wei_aaOBiOo, mkldnn_wino_wei_OBaaIBOIio); + if (!args_ok) return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + return safe_ptr_assign(*reorder_pd, _pd); + } + + status_t init() { + status_t status = cpu_reorder_pd_t::init(); + if (status != status::success) return status; + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + auto &o = memory_desc_wrapper(dst_md()).wino_desc(); + size_t transform_space_size = (size_t)o.r * o.alpha * o.oc_block; + size_t plain_size = (size_t)o.alpha * o.alpha * o.oc * o.ic; + + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_reorder_wino_transform_space, + sizeof(in_data_t) * transform_space_size); + scratchpad.book(key_reorder_wino_plain, + sizeof(out_data_t) * plain_size); + } + }; + +private: + typedef typename prec_traits::type in_data_t; + typedef typename prec_traits::type out_data_t; + const int unsign_val_in_wino_domain_ = 5; + + wino_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + r_ = dst_d.wino_desc().r; + w_alpha_ = dst_d.wino_desc().alpha; + wino_format_ = dst_d.wino_desc().wino_format; + + const auto &in_dims = src_d.dims(); + int groups; + int groups_offset; + if (src_d.ndims() == 5) { + groups = in_dims[0]; + groups_offset = 1; + } else { + groups = 1; + groups_offset = 0; + } + assert(groups == 1); // groups are not supported now + MAYBE_UNUSED(groups); + + or_oc_ = in_dims[0 + groups_offset]; + or_ic_ = in_dims[1 + groups_offset]; + kh_ = in_dims[2 + groups_offset]; + kw_ = in_dims[3 + groups_offset]; + + oc_ = dst_d.wino_desc().oc; + ic_ = dst_d.wino_desc().ic; + oc_block_ = dst_d.wino_desc().oc_block; + ic_block_ = dst_d.wino_desc().ic_block; + assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0); + nb_oc_ = oc_ / oc_block_; + nb_ic_ = ic_ / ic_block_; + ic2_block_ = 1; + if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) + ic2_block_ = dst_d.wino_desc().ic2_block; + oc2_block_ = dst_d.wino_desc().oc2_block; + assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0); + + adj_scale_ = dst_d.wino_desc().adj_scale; + + size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_; + size_wspace_ = r_ * w_alpha_ * oc_block_; + } + + void transform(out_data_t *__restrict tmp_wei, + const in_data_t *__restrict input, + in_data_t *__restrict wspace) const { + const memory_desc_wrapper src_d(pd()->src_md()); + + const int smask = pd()->attr()->output_scales_.mask_; + const int ndims_mask = math::ilog2q(smask + 1); + const size_t D_mask = utils::array_product(src_d.dims(), ndims_mask); + const float *__restrict scales = pd()->attr()->output_scales_.scales_; + assert(D_mask == 1 || D_mask == (size_t)oc_); + + /* transform weights to winograd domain */ + const float G_2x2_3x3[4][3] = { { 1.0, 0.0, 0.0 }, { 0.5, 0.5, 0.5 }, + { 0.5, -0.5, 0.5 }, { 0.0, 0.0, 1.0 } }; + + const float G_4x4_3x3[6][3] = { { 1.13777777777778f, 0.f, 0.f }, + { -0.688403361344538f, -0.430252100840336f, -0.26890756302521f }, + { -0.688403361344538f, 0.430252100840336f, -0.26890756302521f }, + { 0.119514472455649f, 0.179271708683473f, 0.26890756302521f }, + { 0.119514472455649f, -0.179271708683473f, 0.26890756302521f }, + { 0.f, 0.f, 1.f } }; + + float *__restrict g; + if (utils::one_of(wino_format_, mkldnn_wino_wei_aaOIoi, + mkldnn_wino_wei_aaOio, mkldnn_wino_wei_aaOBiOo)) + g = (float *)G_2x2_3x3; + else if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) + g = (float *)G_4x4_3x3; + else { + assert("Unknown winograd weights target layout"); + return; + } + + int Z = oc_ * ic_; + assert(r_ == kh_ && r_ == kw_); + + for (int iic = 0; iic < ic_; iic++) { + for (int ob = 0; ob < nb_oc_; ob++) { + const in_data_t *__restrict _inp + = input + (ob * oc_block_ * or_ic_ + iic) * kh_ * kw_; + out_data_t *__restrict _out + = tmp_wei + (iic * nb_oc_ + ob) * oc_block_; + + for_nd(0, 1, size_wspace_, [&](int i) { wspace[i] = 0.f; }); + + for_nd(0, 1, r_, w_alpha_, oc_block_, + [&](int ih, int j, int ioc) { + for (int iw = 0; iw < r_; ++iw) { + int inp_oc = ob * oc_block_ + ioc; + int inp_ic = iic; + in_data_t inp_v = (inp_ic < or_ic_ && inp_oc < or_oc_) + ? _inp[ioc * or_ic_ * kh_ * kw_ + ih * kw_ + iw] + : 0.f; + wspace[(ih * w_alpha_ + j) * oc_block_ + ioc] + += inp_v * g[j * r_ + iw]; + } + }); + + for_nd(0, 1, w_alpha_, w_alpha_, oc_block_, + [&](int i, int j, int ioc) { + float t = 0; + for (int k = 0; k < r_; ++k) + t += g[i * r_ + k] + * wspace[(k * w_alpha_ + j) * oc_block_ + ioc]; + if (type_o == data_type::s8) { + const float scale = (D_mask == 1) + ? scales[0] + : scales[ob * oc_block_ + ioc]; + _out[(i * w_alpha_ + j) * Z + ioc] + = qz_b0()( + (in_data_t)t, scale * adj_scale_); + } else { + _out[(i * w_alpha_ + j) * Z + ioc] = (out_data_t)t; + } + }); + }} + } + + void reorder_to_aaOIoi(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int32_t *__restrict dst_bias = nullptr; + if (type_o == data_type::s8) { + const auto bias_shift = sizeof(out_data_t) * size_wino_wei_; + const size_t bias_size = w_alpha_ * w_alpha_ * oc_; + + dst_bias = (int32_t *)(output + bias_shift); + utils::array_set((int32_t *)dst_bias, 0, bias_size); + } + int index = 0; + for (int u_h = 0; u_h < w_alpha_; u_h++) { + for (int u_w = 0; u_w < w_alpha_; u_w++) { + for_nd(0, 1, nb_oc_, oc_block_, [&](int ob, int o) { + int u_h_shift = u_h * w_alpha_ * ic_ * oc_; + int u_w_shift = u_w * ic_ * oc_; + int u_h_shift_b = u_h * w_alpha_ * oc_; + int u_w_shift_b = u_w * oc_; + int oc_block_shift = ob * oc_block_ * ic_ + o * ic_block_; + for (int ib = 0; ib < nb_ic_; ib++) { + for (int i = 0; i < ic_block_; i++) { + int _i = ib * ic_block_; + int _o = ob * oc_block_; + int ic_shift = (_i + i) * oc_; + int oc_shift = (_o + o); + int ic_block_shift = ib * oc_block_ * ic_block_ + i; + int src_offset = + u_h_shift + u_w_shift + ic_shift + oc_shift; + int dst_offset = u_h_shift + u_w_shift + oc_block_shift + + ic_block_shift; + + output[dst_offset] = tmp_wei[src_offset]; + if (type_o == data_type::s8) { + int bias_offset = u_h_shift_b + u_w_shift_b + oc_shift; + if (index != unsign_val_in_wino_domain_) + dst_bias[bias_offset] + -= (128 * (int32_t)output[dst_offset]); + else + dst_bias[bias_offset] = 0; + } + }} + }); + index++; + }} + } + + void reorder_to_aaOio(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + for_nd(0, 1, w_alpha_, w_alpha_, nb_oc_, + [&](int u_h, int u_w, int ob) { + for (int ib = 0; ib < nb_ic_; ib++) { + for (int i = 0; i < ic_block_; i++) { + for (int o = 0; o < oc_block_; o++) { + int src_offset = u_h * w_alpha_ * ic_ * oc_ + u_w * ic_ * oc_ + + (ib * ic_block_ + i) * oc_ + (ob * oc_block_ + o); + + int dst_offset + = u_h * w_alpha_ * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ + + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ + + ob * nb_ic_ * ic_block_ * oc_block_ + + ib * ic_block_ * oc_block_ + i * oc_block_ + o; + output[dst_offset] = tmp_wei[src_offset]; + }}} + }); + } + + void reorder_to_aaOBiOo(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int oc_chunks = nb_oc_ / oc2_block_; + + for_nd(0, 1, w_alpha_, w_alpha_, oc_chunks, + [&](int u_h, int u_w, int occ) { + for (int ib = 0; ib < nb_ic_; ib++) { + out_data_t *__restrict wei_ptr = output + + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) * nb_ic_ + ib) + * oc2_block_ * ic_block_ * oc_block_; + int wei_offset = 0; + for (int i = 0; i < ic_block_; i++) { + for (int ob2 = 0; ob2 < oc2_block_; ob2++) { + for (int o = 0; o < oc_block_; o++) { + int icp = ib * ic_block_ + i; + int ocp = + occ * oc2_block_ * oc_block_ + ob2 * oc_block_ + o; + + int src_offset = u_h * w_alpha_ * ic_ * oc_ + + u_w * ic_ * oc_ + icp * oc_ + ocp; + wei_ptr[wei_offset + o] = tmp_wei[src_offset]; + } + wei_offset += oc_block_; + }} + } + }); + } + + void reorder_to_OBaaIBOIio(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int ic_chunks = nb_ic_ / ic2_block_; + int oc_chunks = nb_oc_ / oc2_block_; + + for_nd(0, 1, oc_chunks, w_alpha_, w_alpha_, + [&](int occ, int u_h, int u_w) { + for (int icc = 0; icc < ic_chunks; icc++) { + for (int ob = 0; ob < oc2_block_; ob++) { + int ocp = (occ * oc2_block_ + ob) * oc_block_; + for (int ib = 0; ib < ic2_block_; ib++) { + for (int i = 0; i < ic_block_; i++) { + int icp = (icc * ic2_block_ + ib) * ic_block_ + i; + + int src_offset = u_h * w_alpha_ * ic_ * oc_ + + u_w * ic_ * oc_ + icp * oc_ + ocp; + int wei_offset + = ((((((occ * w_alpha_ + u_h) * w_alpha_ + u_w) + * ic_chunks + icc) * oc2_block_ + ob) * ic2_block_ + + ib) * ic_block_ + i) * oc_block_; + for (int o = 0; o < oc_block_; o++) + output[wei_offset + o] = tmp_wei[src_offset + o]; + }} + }} + }); + } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); + + auto wspace = (in_data_t *__restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_wino_transform_space); + auto tmp_wei = (out_data_t *__restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_wino_plain); + + transform(tmp_wei, input, wspace); + + /* reorder to winograd domain */ + switch (wino_format_) { + case mkldnn_wino_wei_aaOIoi: + reorder_to_aaOIoi(output, tmp_wei); break; + case mkldnn_wino_wei_aaOio: + reorder_to_aaOio(output, tmp_wei); break; + case mkldnn_wino_wei_aaOBiOo: + reorder_to_aaOBiOo(output, tmp_wei); break; + case mkldnn_wino_wei_OBaaIBOIio: + reorder_to_OBaaIBOIio(output, tmp_wei); break; + default: assert("Unknown wino format"); break; + } + + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + int r_, w_alpha_; + int ic_, oc_, or_ic_, or_oc_, kh_, kw_; + int oc_block_, ic_block_, oc2_block_, ic2_block_; + float adj_scale_; + int nb_oc_, nb_ic_; + mkldnn_wino_memory_format_t wino_format_; + int size_wino_wei_; + int size_wspace_; +}; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT new file mode 100644 index 0000000000..66b6ea55d0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT @@ -0,0 +1,47 @@ + +Copyright (c) 2007 MITSUNARI Shigeo +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. +Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. +Neither the name of the copyright owner nor the names of its contributors may +be used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +THE POSSIBILITY OF SUCH DAMAGE. +----------------------------------------------------------------------------- +ソースコード形式かバイナリ形式か、変更するかしないかを問わず、以下の条件を満た +す場合に限り、再頒布および使用が許可されます。 + +ソースコードを再頒布する場合、上記の著作権表示、本条件一覧、および下記免責条項 +を含めること。 +バイナリ形式で再頒布する場合、頒布物に付属のドキュメント等の資料に、上記の著作 +権表示、本条件一覧、および下記免責条項を含めること。 +書面による特別の許可なしに、本ソフトウェアから派生した製品の宣伝または販売促進 +に、著作権者の名前またはコントリビューターの名前を使用してはならない。 +本ソフトウェアは、著作権者およびコントリビューターによって「現状のまま」提供さ +れており、明示黙示を問わず、商業的な使用可能性、および特定の目的に対する適合性 +に関する暗黙の保証も含め、またそれに限定されない、いかなる保証もありません。 +著作権者もコントリビューターも、事由のいかんを問わず、 損害発生の原因いかんを +問わず、かつ責任の根拠が契約であるか厳格責任であるか(過失その他の)不法行為で +あるかを問わず、仮にそのような損害が発生する可能性を知らされていたとしても、 +本ソフトウェアの使用によって発生した(代替品または代用サービスの調達、使用の +喪失、データの喪失、利益の喪失、業務の中断も含め、またそれに限定されない)直接 +損害、間接損害、偶発的な損害、特別損害、懲罰的損害、または結果損害について、 +一切責任を負わないものとします。 diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h new file mode 100644 index 0000000000..cf5771332f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h @@ -0,0 +1,2658 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +#pragma once +#ifndef XBYAK_XBYAK_H_ +#define XBYAK_XBYAK_H_ +/*! + @file xbyak.h + @brief Xbyak ; JIT assembler for x86(IA32)/x64 by C++ + @author herumi + @url https://github.com/herumi/xbyak + @note modified new BSD license + http://opensource.org/licenses/BSD-3-Clause +*/ +#ifndef XBYAK_NO_OP_NAMES + #if not +0 // trick to detect whether 'not' is operator or not + #error "use -fno-operator-names option if you want to use and(), or(), xor(), not() as function names, Or define XBYAK_NO_OP_NAMES and use and_(), or_(), xor_(), not_()." + #endif +#endif + +#include // for debug print +#include +#include +#include +#include +#ifndef NDEBUG +#include +#endif + +// #define XBYAK_DISABLE_AVX512 + +//#define XBYAK_USE_MMAP_ALLOCATOR +#if !defined(__GNUC__) || defined(__MINGW32__) + #undef XBYAK_USE_MMAP_ALLOCATOR +#endif + +#ifdef __GNUC__ + #define XBYAK_GNUC_PREREQ(major, minor) ((__GNUC__) * 100 + (__GNUC_MINOR__) >= (major) * 100 + (minor)) +#else + #define XBYAK_GNUC_PREREQ(major, minor) 0 +#endif + +// This covers -std=(gnu|c)++(0x|11|1y), -stdlib=libc++, and modern Microsoft. +#if ((defined(_MSC_VER) && (_MSC_VER >= 1600)) || defined(_LIBCPP_VERSION) ||\ + ((__cplusplus >= 201103) || defined(__GXX_EXPERIMENTAL_CXX0X__))) + #include + #define XBYAK_STD_UNORDERED_SET std::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::unordered_multimap + +/* + Clang/llvm-gcc and ICC-EDG in 'GCC-mode' always claim to be GCC 4.2, using + libstdcxx 20070719 (from GCC 4.2.1, the last GPL 2 version). +*/ +#elif XBYAK_GNUC_PREREQ(4, 5) || (XBYAK_GNUC_PREREQ(4, 2) && __GLIBCXX__ >= 20070719) || defined(__INTEL_COMPILER) || defined(__llvm__) + #include + #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap + +#elif defined(_MSC_VER) && (_MSC_VER >= 1500) && (_MSC_VER < 1600) + #include + #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap + +#else + #include + #define XBYAK_STD_UNORDERED_SET std::set + #include + #define XBYAK_STD_UNORDERED_MAP std::map + #define XBYAK_STD_UNORDERED_MULTIMAP std::multimap +#endif +#ifdef _WIN32 + #include + #include + #include +#elif defined(__GNUC__) + #include + #include + #include +#endif +#if !defined(_MSC_VER) || (_MSC_VER >= 1600) + #include +#endif + +#if defined(_WIN64) || defined(__MINGW64__) || (defined(__CYGWIN__) && defined(__x86_64__)) + #define XBYAK64_WIN +#elif defined(__x86_64__) + #define XBYAK64_GCC +#endif +#if !defined(XBYAK64) && !defined(XBYAK32) + #if defined(XBYAK64_GCC) || defined(XBYAK64_WIN) + #define XBYAK64 + #else + #define XBYAK32 + #endif +#endif + +#if (__cplusplus >= 201103) || (_MSC_VER >= 1800) + #define XBYAK_VARIADIC_TEMPLATE +#endif + +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) /* remove inline function */ + #pragma warning(disable : 4786) /* identifier is too long */ + #pragma warning(disable : 4503) /* name is too long */ + #pragma warning(disable : 4127) /* constant expresison */ +#endif + +namespace Xbyak { + +enum { + DEFAULT_MAX_CODE_SIZE = 4096, + VERSION = 0x5760 /* 0xABCD = A.BC(D) */ +}; + +#ifndef MIE_INTEGER_TYPE_DEFINED +#define MIE_INTEGER_TYPE_DEFINED +#ifdef _MSC_VER + typedef unsigned __int64 uint64; + typedef __int64 sint64; +#else + typedef uint64_t uint64; + typedef int64_t sint64; +#endif +typedef unsigned int uint32; +typedef unsigned short uint16; +typedef unsigned char uint8; +#endif + +#ifndef MIE_ALIGN + #ifdef _MSC_VER + #define MIE_ALIGN(x) __declspec(align(x)) + #else + #define MIE_ALIGN(x) __attribute__((aligned(x))) + #endif +#endif +#ifndef MIE_PACK // for shufps + #define MIE_PACK(x, y, z, w) ((x) * 64 + (y) * 16 + (z) * 4 + (w)) +#endif + +enum { + ERR_NONE = 0, + ERR_BAD_ADDRESSING, + ERR_CODE_IS_TOO_BIG, + ERR_BAD_SCALE, + ERR_ESP_CANT_BE_INDEX, + ERR_BAD_COMBINATION, + ERR_BAD_SIZE_OF_REGISTER, + ERR_IMM_IS_TOO_BIG, + ERR_BAD_ALIGN, + ERR_LABEL_IS_REDEFINED, + ERR_LABEL_IS_TOO_FAR, + ERR_LABEL_IS_NOT_FOUND, + ERR_CODE_ISNOT_COPYABLE, + ERR_BAD_PARAMETER, + ERR_CANT_PROTECT, + ERR_CANT_USE_64BIT_DISP, + ERR_OFFSET_IS_TOO_BIG, + ERR_MEM_SIZE_IS_NOT_SPECIFIED, + ERR_BAD_MEM_SIZE, + ERR_BAD_ST_COMBINATION, + ERR_OVER_LOCAL_LABEL, // not used + ERR_UNDER_LOCAL_LABEL, + ERR_CANT_ALLOC, + ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW, + ERR_BAD_PROTECT_MODE, + ERR_BAD_PNUM, + ERR_BAD_TNUM, + ERR_BAD_VSIB_ADDRESSING, + ERR_CANT_CONVERT, + ERR_LABEL_ISNOT_SET_BY_L, + ERR_LABEL_IS_ALREADY_SET_BY_L, + ERR_BAD_LABEL_STR, + ERR_MUNMAP, + ERR_OPMASK_IS_ALREADY_SET, + ERR_ROUNDING_IS_ALREADY_SET, + ERR_K0_IS_INVALID, + ERR_EVEX_IS_INVALID, + ERR_SAE_IS_INVALID, + ERR_ER_IS_INVALID, + ERR_INVALID_BROADCAST, + ERR_INVALID_OPMASK_WITH_MEMORY, + ERR_INVALID_ZERO, + ERR_INVALID_RIP_IN_AUTO_GROW, + ERR_INVALID_MIB_ADDRESS, + ERR_INTERNAL, + ERR_X2APIC_IS_NOT_SUPPORTED +}; + +class Error : public std::exception { + int err_; +public: + explicit Error(int err) : err_(err) + { + if (err_ < 0 || err_ > ERR_INTERNAL) { + fprintf(stderr, "bad err=%d in Xbyak::Error\n", err_); + //exit(1); + } + } + operator int() const { return err_; } + const char *what() const throw() + { + static const char *errTbl[] = { + "none", + "bad addressing", + "code is too big", + "bad scale", + "esp can't be index", + "bad combination", + "bad size of register", + "imm is too big", + "bad align", + "label is redefined", + "label is too far", + "label is not found", + "code is not copyable", + "bad parameter", + "can't protect", + "can't use 64bit disp(use (void*))", + "offset is too big", + "MEM size is not specified", + "bad mem size", + "bad st combination", + "over local label", + "under local label", + "can't alloc", + "T_SHORT is not supported in AutoGrow", + "bad protect mode", + "bad pNum", + "bad tNum", + "bad vsib addressing", + "can't convert", + "label is not set by L()", + "label is already set by L()", + "bad label string", + "err munmap", + "opmask is already set", + "rounding is already set", + "k0 is invalid", + "evex is invalid", + "sae(suppress all exceptions) is invalid", + "er(embedded rounding) is invalid", + "invalid broadcast", + "invalid opmask with memory", + "invalid zero", + "invalid rip in AutoGrow", + "invalid mib address", + "internal error", + "x2APIC is not supported" + }; + assert((size_t)err_ < sizeof(errTbl) / sizeof(*errTbl)); + return errTbl[err_]; + } +}; + +inline const char *ConvertErrorToString(const Error& err) +{ + return err.what(); +} + +inline void *AlignedMalloc(size_t size, size_t alignment) +{ +#ifdef __MINGW32__ + return __mingw_aligned_malloc(size, alignment); +#elif defined(_WIN32) + return _aligned_malloc(size, alignment); +#else + void *p; + int ret = posix_memalign(&p, alignment, size); + return (ret == 0) ? p : 0; +#endif +} + +inline void AlignedFree(void *p) +{ +#ifdef __MINGW32__ + __mingw_aligned_free(p); +#elif defined(_MSC_VER) + _aligned_free(p); +#else + free(p); +#endif +} + +template +inline const To CastTo(From p) throw() +{ + return (const To)(size_t)(p); +} +namespace inner { + +static const size_t ALIGN_PAGE_SIZE = 4096; + +inline bool IsInDisp8(uint32 x) { return 0xFFFFFF80 <= x || x <= 0x7F; } +inline bool IsInInt32(uint64 x) { return ~uint64(0x7fffffffu) <= x || x <= 0x7FFFFFFFU; } + +inline uint32 VerifyInInt32(uint64 x) +{ +#ifdef XBYAK64 + if (!IsInInt32(x)) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + return static_cast(x); +} + +enum LabelMode { + LasIs, // as is + Labs, // absolute + LaddTop // (addr + top) for mov(reg, label) with AutoGrow +}; + +} // inner + +/* + custom allocator +*/ +struct Allocator { + virtual uint8 *alloc(size_t size) { return reinterpret_cast(AlignedMalloc(size, inner::ALIGN_PAGE_SIZE)); } + virtual void free(uint8 *p) { AlignedFree(p); } + virtual ~Allocator() {} + /* override to return false if you call protect() manually */ + virtual bool useProtect() const { return true; } +}; + +#ifdef XBYAK_USE_MMAP_ALLOCATOR +class MmapAllocator : Allocator { + typedef XBYAK_STD_UNORDERED_MAP SizeList; + SizeList sizeList_; +public: + uint8 *alloc(size_t size) + { + const size_t alignedSizeM1 = inner::ALIGN_PAGE_SIZE - 1; + size = (size + alignedSizeM1) & ~alignedSizeM1; +#ifdef MAP_ANONYMOUS + const int mode = MAP_PRIVATE | MAP_ANONYMOUS; +#elif defined(MAP_ANON) + const int mode = MAP_PRIVATE | MAP_ANON; +#else + #error "not supported" +#endif + void *p = mmap(NULL, size, PROT_READ | PROT_WRITE, mode, -1, 0); + if (p == MAP_FAILED) throw Error(ERR_CANT_ALLOC); + assert(p); + sizeList_[(uintptr_t)p] = size; + return (uint8*)p; + } + void free(uint8 *p) + { + if (p == 0) return; + SizeList::iterator i = sizeList_.find((uintptr_t)p); + if (i == sizeList_.end()) throw Error(ERR_BAD_PARAMETER); + if (munmap((void*)i->first, i->second) < 0) throw Error(ERR_MUNMAP); + sizeList_.erase(i); + } +}; +#endif + +class Address; +class Reg; + +class Operand { + static const uint8 EXT8BIT = 0x20; + unsigned int idx_:6; // 0..31 + EXT8BIT = 1 if spl/bpl/sil/dil + unsigned int kind_:9; + unsigned int bit_:10; +protected: + unsigned int zero_:1; + unsigned int mask_:3; + unsigned int rounding_:3; + void setIdx(int idx) { idx_ = idx; } +public: + enum Kind { + NONE = 0, + MEM = 1 << 0, + REG = 1 << 1, + MMX = 1 << 2, + FPU = 1 << 3, + XMM = 1 << 4, + YMM = 1 << 5, + ZMM = 1 << 6, + OPMASK = 1 << 7, + BNDREG = 1 << 8 + }; + enum Code { +#ifdef XBYAK64 + RAX = 0, RCX, RDX, RBX, RSP, RBP, RSI, RDI, R8, R9, R10, R11, R12, R13, R14, R15, + R8D = 8, R9D, R10D, R11D, R12D, R13D, R14D, R15D, + R8W = 8, R9W, R10W, R11W, R12W, R13W, R14W, R15W, + R8B = 8, R9B, R10B, R11B, R12B, R13B, R14B, R15B, + SPL = 4, BPL, SIL, DIL, +#endif + EAX = 0, ECX, EDX, EBX, ESP, EBP, ESI, EDI, + AX = 0, CX, DX, BX, SP, BP, SI, DI, + AL = 0, CL, DL, BL, AH, CH, DH, BH + }; + Operand() : idx_(0), kind_(0), bit_(0), zero_(0), mask_(0), rounding_(0) { } + Operand(int idx, Kind kind, int bit, bool ext8bit = 0) + : idx_(static_cast(idx | (ext8bit ? EXT8BIT : 0))) + , kind_(kind) + , bit_(bit) + , zero_(0), mask_(0), rounding_(0) + { + assert((bit_ & (bit_ - 1)) == 0); // bit must be power of two + } + Kind getKind() const { return static_cast(kind_); } + int getIdx() const { return idx_ & (EXT8BIT - 1); } + bool isNone() const { return kind_ == 0; } + bool isMMX() const { return is(MMX); } + bool isXMM() const { return is(XMM); } + bool isYMM() const { return is(YMM); } + bool isZMM() const { return is(ZMM); } + bool isXMEM() const { return is(XMM | MEM); } + bool isYMEM() const { return is(YMM | MEM); } + bool isZMEM() const { return is(ZMM | MEM); } + bool isOPMASK() const { return is(OPMASK); } + bool isBNDREG() const { return is(BNDREG); } + bool isREG(int bit = 0) const { return is(REG, bit); } + bool isMEM(int bit = 0) const { return is(MEM, bit); } + bool isFPU() const { return is(FPU); } + bool isExt8bit() const { return (idx_ & EXT8BIT) != 0; } + bool isExtIdx() const { return (getIdx() & 8) != 0; } + bool isExtIdx2() const { return (getIdx() & 16) != 0; } + bool hasEvex() const { return isZMM() || isExtIdx2() || getOpmaskIdx() || getRounding(); } + bool hasRex() const { return isExt8bit() || isREG(64) || isExtIdx(); } + bool hasZero() const { return zero_; } + int getOpmaskIdx() const { return mask_; } + int getRounding() const { return rounding_; } + void setKind(Kind kind) + { + if ((kind & (XMM|YMM|ZMM)) == 0) return; + kind_ = kind; + bit_ = kind == XMM ? 128 : kind == YMM ? 256 : 512; + } + void setBit(int bit) { bit_ = bit; } + void setOpmaskIdx(int idx, bool ignore_idx0 = false) + { + if (!ignore_idx0 && idx == 0) throw Error(ERR_K0_IS_INVALID); + if (mask_) throw Error(ERR_OPMASK_IS_ALREADY_SET); + mask_ = idx; + } + void setRounding(int idx) + { + if (rounding_) throw Error(ERR_ROUNDING_IS_ALREADY_SET); + rounding_ = idx; + } + void setZero() { zero_ = true; } + // ah, ch, dh, bh? + bool isHigh8bit() const + { + if (!isBit(8)) return false; + if (isExt8bit()) return false; + const int idx = getIdx(); + return AH <= idx && idx <= BH; + } + // any bit is accetable if bit == 0 + bool is(int kind, uint32 bit = 0) const + { + return (kind == 0 || (kind_ & kind)) && (bit == 0 || (bit_ & bit)); // cf. you can set (8|16) + } + bool isBit(uint32 bit) const { return (bit_ & bit) != 0; } + uint32 getBit() const { return bit_; } + const char *toString() const + { + const int idx = getIdx(); + if (kind_ == REG) { + if (isExt8bit()) { + static const char *tbl[4] = { "spl", "bpl", "sil", "dil" }; + return tbl[idx - 4]; + } + static const char *tbl[4][16] = { + { "al", "cl", "dl", "bl", "ah", "ch", "dh", "bh", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b" }, + { "ax", "cx", "dx", "bx", "sp", "bp", "si", "di", "r8w", "r9w", "r10w", "r11w", "r12w", "r13w", "r14w", "r15w" }, + { "eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi", "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d" }, + { "rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15" }, + }; + return tbl[bit_ == 8 ? 0 : bit_ == 16 ? 1 : bit_ == 32 ? 2 : 3][idx]; + } else if (isOPMASK()) { + static const char *tbl[8] = { "k0", "k1", "k2", "k3", "k4", "k5", "k6", "k7" }; + return tbl[idx]; + } else if (isZMM()) { + static const char *tbl[32] = { + "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31" + }; + return tbl[idx]; + } else if (isYMM()) { + static const char *tbl[32] = { + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", + "ymm16", "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", "ymm31" + }; + return tbl[idx]; + } else if (isXMM()) { + static const char *tbl[32] = { + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", + "xmm16", "xmm17", "xmm18", "xmm19", "xmm20", "xmm21", "xmm22", "xmm23", "xmm24", "xmm25", "xmm26", "xmm27", "xmm28", "xmm29", "xmm30", "xmm31" + }; + return tbl[idx]; + } else if (isMMX()) { + static const char *tbl[8] = { "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7" }; + return tbl[idx]; + } else if (isFPU()) { + static const char *tbl[8] = { "st0", "st1", "st2", "st3", "st4", "st5", "st6", "st7" }; + return tbl[idx]; + } else if (isBNDREG()) { + static const char *tbl[4] = { "bnd0", "bnd1", "bnd2", "bnd3" }; + return tbl[idx]; + } + throw Error(ERR_INTERNAL); + } + bool isEqualIfNotInherited(const Operand& rhs) const { return idx_ == rhs.idx_ && kind_ == rhs.kind_ && bit_ == rhs.bit_ && zero_ == rhs.zero_ && mask_ == rhs.mask_ && rounding_ == rhs.rounding_; } + bool operator==(const Operand& rhs) const; + bool operator!=(const Operand& rhs) const { return !operator==(rhs); } + const Address& getAddress() const; + const Reg& getReg() const; +}; + +class Label; + +struct Reg8; +struct Reg16; +struct Reg32; +#ifdef XBYAK64 +struct Reg64; +#endif +class Reg : public Operand { +public: + Reg() { } + Reg(int idx, Kind kind, int bit = 0, bool ext8bit = false) : Operand(idx, kind, bit, ext8bit) { } + Reg changeBit(int bit) const { return Reg(getIdx(), getKind(), bit, isExt8bit()); } + uint8 getRexW() const { return isREG(64) ? 8 : 0; } + uint8 getRexR() const { return isExtIdx() ? 4 : 0; } + uint8 getRexX() const { return isExtIdx() ? 2 : 0; } + uint8 getRexB() const { return isExtIdx() ? 1 : 0; } + uint8 getRex(const Reg& base = Reg()) const + { + uint8 rex = getRexW() | getRexR() | base.getRexW() | base.getRexB(); + if (rex || isExt8bit() || base.isExt8bit()) rex |= 0x40; + return rex; + } + Reg8 cvt8() const; + Reg16 cvt16() const; + Reg32 cvt32() const; +#ifdef XBYAK64 + Reg64 cvt64() const; +#endif +}; + +inline const Reg& Operand::getReg() const +{ + assert(!isMEM()); + return static_cast(*this); +} + +struct Reg8 : public Reg { + explicit Reg8(int idx = 0, bool ext8bit = false) : Reg(idx, Operand::REG, 8, ext8bit) { } +}; + +struct Reg16 : public Reg { + explicit Reg16(int idx = 0) : Reg(idx, Operand::REG, 16) { } +}; + +struct Mmx : public Reg { + explicit Mmx(int idx = 0, Kind kind = Operand::MMX, int bit = 64) : Reg(idx, kind, bit) { } +}; + +struct EvexModifierRounding { + enum { + T_RN_SAE = 1, + T_RD_SAE = 2, + T_RU_SAE = 3, + T_RZ_SAE = 4, + T_SAE = 5 + }; + explicit EvexModifierRounding(int rounding) : rounding(rounding) {} + int rounding; +}; +struct EvexModifierZero{EvexModifierZero() {}}; + +struct Xmm : public Mmx { + explicit Xmm(int idx = 0, Kind kind = Operand::XMM, int bit = 128) : Mmx(idx, kind, bit) { } + Xmm(Kind kind, int idx) : Mmx(idx, kind, kind == XMM ? 128 : kind == YMM ? 256 : 512) { } + Xmm operator|(const EvexModifierRounding& emr) const { Xmm r(*this); r.setRounding(emr.rounding); return r; } + Xmm copyAndSetIdx(int idx) const { Xmm ret(*this); ret.setIdx(idx); return ret; } + Xmm copyAndSetKind(Operand::Kind kind) const { Xmm ret(*this); ret.setKind(kind); return ret; } +}; + +struct Ymm : public Xmm { + explicit Ymm(int idx = 0, Kind kind = Operand::YMM, int bit = 256) : Xmm(idx, kind, bit) { } + Ymm operator|(const EvexModifierRounding& emr) const { Ymm r(*this); r.setRounding(emr.rounding); return r; } +}; + +struct Zmm : public Ymm { + explicit Zmm(int idx = 0) : Ymm(idx, Operand::ZMM, 512) { } + Zmm operator|(const EvexModifierRounding& emr) const { Zmm r(*this); r.setRounding(emr.rounding); return r; } +}; + +struct Opmask : public Reg { + explicit Opmask(int idx = 0) : Reg(idx, Operand::OPMASK, 64) {} +}; + +struct BoundsReg : public Reg { + explicit BoundsReg(int idx = 0) : Reg(idx, Operand::BNDREG, 128) {} +}; + +templateT operator|(const T& x, const Opmask& k) { T r(x); r.setOpmaskIdx(k.getIdx()); return r; } +templateT operator|(const T& x, const EvexModifierZero&) { T r(x); r.setZero(); return r; } +templateT operator|(const T& x, const EvexModifierRounding& emr) { T r(x); r.setRounding(emr.rounding); return r; } + +struct Fpu : public Reg { + explicit Fpu(int idx = 0) : Reg(idx, Operand::FPU, 32) { } +}; + +struct Reg32e : public Reg { + explicit Reg32e(int idx, int bit) : Reg(idx, Operand::REG, bit) {} +}; +struct Reg32 : public Reg32e { + explicit Reg32(int idx = 0) : Reg32e(idx, 32) {} +}; +#ifdef XBYAK64 +struct Reg64 : public Reg32e { + explicit Reg64(int idx = 0) : Reg32e(idx, 64) {} +}; +struct RegRip { + sint64 disp_; + const Label* label_; + bool isAddr_; + explicit RegRip(sint64 disp = 0, const Label* label = 0, bool isAddr = false) : disp_(disp), label_(label), isAddr_(isAddr) {} + friend const RegRip operator+(const RegRip& r, int disp) { + return RegRip(r.disp_ + disp, r.label_, r.isAddr_); + } + friend const RegRip operator-(const RegRip& r, int disp) { + return RegRip(r.disp_ - disp, r.label_, r.isAddr_); + } + friend const RegRip operator+(const RegRip& r, sint64 disp) { + return RegRip(r.disp_ + disp, r.label_, r.isAddr_); + } + friend const RegRip operator-(const RegRip& r, sint64 disp) { + return RegRip(r.disp_ - disp, r.label_, r.isAddr_); + } + friend const RegRip operator+(const RegRip& r, const Label& label) { + if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING); + return RegRip(r.disp_, &label); + } + friend const RegRip operator+(const RegRip& r, const void *addr) { + if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING); + return RegRip(r.disp_ + (sint64)addr, 0, true); + } +}; +#endif + +inline Reg8 Reg::cvt8() const +{ + const int idx = getIdx(); + if (isBit(8)) return Reg8(idx, isExt8bit()); +#ifdef XBYAK32 + if (idx >= 4) throw Error(ERR_CANT_CONVERT); +#endif + return Reg8(idx, 4 <= idx && idx < 8); +} + +inline Reg16 Reg::cvt16() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg16(idx); +} + +inline Reg32 Reg::cvt32() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg32(idx); +} + +#ifdef XBYAK64 +inline Reg64 Reg::cvt64() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg64(idx); +} +#endif + +#ifndef XBYAK_DISABLE_SEGMENT +// not derived from Reg +class Segment { + int idx_; +public: + enum { + es, cs, ss, ds, fs, gs + }; + explicit Segment(int idx) : idx_(idx) { assert(0 <= idx_ && idx_ < 6); } + int getIdx() const { return idx_; } + const char *toString() const + { + static const char tbl[][3] = { + "es", "cs", "ss", "ds", "fs", "gs" + }; + return tbl[idx_]; + } +}; +#endif + +class RegExp { +public: +#ifdef XBYAK64 + enum { i32e = 32 | 64 }; +#else + enum { i32e = 32 }; +#endif + RegExp(size_t disp = 0) : scale_(0), disp_(disp) { } + RegExp(const Reg& r, int scale = 1) + : scale_(scale) + , disp_(0) + { + if (!r.isREG(i32e) && !r.is(Reg::XMM|Reg::YMM|Reg::ZMM)) throw Error(ERR_BAD_SIZE_OF_REGISTER); + if (scale == 0) return; + if (scale != 1 && scale != 2 && scale != 4 && scale != 8) throw Error(ERR_BAD_SCALE); + if (r.getBit() >= 128 || scale != 1) { // xmm/ymm is always index + index_ = r; + } else { + base_ = r; + } + } + bool isVsib(int bit = 128 | 256 | 512) const { return index_.isBit(bit); } + RegExp optimize() const + { + RegExp exp = *this; + // [reg * 2] => [reg + reg] + if (index_.isBit(i32e) && !base_.getBit() && scale_ == 2) { + exp.base_ = index_; + exp.scale_ = 1; + } + return exp; + } + bool operator==(const RegExp& rhs) const + { + return base_ == rhs.base_ && index_ == rhs.index_ && disp_ == rhs.disp_ && scale_ == rhs.scale_; + } + const Reg& getBase() const { return base_; } + const Reg& getIndex() const { return index_; } + int getScale() const { return scale_; } + size_t getDisp() const { return disp_; } + void verify() const + { + if (base_.getBit() >= 128) throw Error(ERR_BAD_SIZE_OF_REGISTER); + if (index_.getBit() && index_.getBit() <= 64) { + if (index_.getIdx() == Operand::ESP) throw Error(ERR_ESP_CANT_BE_INDEX); + if (base_.getBit() && base_.getBit() != index_.getBit()) throw Error(ERR_BAD_SIZE_OF_REGISTER); + } + } + friend RegExp operator+(const RegExp& a, const RegExp& b); + friend RegExp operator-(const RegExp& e, size_t disp); + uint8 getRex() const + { + uint8 rex = index_.getRexX() | base_.getRexB(); + return rex ? uint8(rex | 0x40) : 0; + } +private: + /* + [base_ + index_ * scale_ + disp_] + base : Reg32e, index : Reg32e(w/o esp), Xmm, Ymm + */ + Reg base_; + Reg index_; + int scale_; + size_t disp_; +}; + +inline RegExp operator+(const RegExp& a, const RegExp& b) +{ + if (a.index_.getBit() && b.index_.getBit()) throw Error(ERR_BAD_ADDRESSING); + RegExp ret = a; + if (!ret.index_.getBit()) { ret.index_ = b.index_; ret.scale_ = b.scale_; } + if (b.base_.getBit()) { + if (ret.base_.getBit()) { + if (ret.index_.getBit()) throw Error(ERR_BAD_ADDRESSING); + // base + base => base + index * 1 + ret.index_ = b.base_; + // [reg + esp] => [esp + reg] + if (ret.index_.getIdx() == Operand::ESP) std::swap(ret.base_, ret.index_); + ret.scale_ = 1; + } else { + ret.base_ = b.base_; + } + } + ret.disp_ += b.disp_; + return ret; +} +inline RegExp operator*(const Reg& r, int scale) +{ + return RegExp(r, scale); +} +inline RegExp operator-(const RegExp& e, size_t disp) +{ + RegExp ret = e; + ret.disp_ -= disp; + return ret; +} + +// 2nd parameter for constructor of CodeArray(maxSize, userPtr, alloc) +void *const AutoGrow = (void*)1; //-V566 +void *const DontSetProtectRWE = (void*)2; //-V566 + +class CodeArray { + enum Type { + USER_BUF = 1, // use userPtr(non alignment, non protect) + ALLOC_BUF, // use new(alignment, protect) + AUTO_GROW // automatically move and grow memory if necessary + }; + CodeArray(const CodeArray& rhs); + void operator=(const CodeArray&); + bool isAllocType() const { return type_ == ALLOC_BUF || type_ == AUTO_GROW; } + struct AddrInfo { + size_t codeOffset; // position to write + size_t jmpAddr; // value to write + int jmpSize; // size of jmpAddr + inner::LabelMode mode; + AddrInfo(size_t _codeOffset, size_t _jmpAddr, int _jmpSize, inner::LabelMode _mode) + : codeOffset(_codeOffset), jmpAddr(_jmpAddr), jmpSize(_jmpSize), mode(_mode) {} + uint64 getVal(const uint8 *top) const + { + uint64 disp = (mode == inner::LaddTop) ? jmpAddr + size_t(top) : (mode == inner::LasIs) ? jmpAddr : jmpAddr - size_t(top); + if (jmpSize == 4) disp = inner::VerifyInInt32(disp); + return disp; + } + }; + typedef std::list AddrInfoList; + AddrInfoList addrInfoList_; + const Type type_; +#ifdef XBYAK_USE_MMAP_ALLOCATOR + MmapAllocator defaultAllocator_; +#else + Allocator defaultAllocator_; +#endif + Allocator *alloc_; +protected: + size_t maxSize_; + uint8 *top_; + size_t size_; + bool isCalledCalcJmpAddress_; + + bool useProtect() const { return alloc_->useProtect(); } + /* + allocate new memory and copy old data to the new area + */ + void growMemory() + { + const size_t newSize = (std::max)(DEFAULT_MAX_CODE_SIZE, maxSize_ * 2); + uint8 *newTop = alloc_->alloc(newSize); + if (newTop == 0) throw Error(ERR_CANT_ALLOC); + for (size_t i = 0; i < size_; i++) newTop[i] = top_[i]; + alloc_->free(top_); + top_ = newTop; + maxSize_ = newSize; + } + /* + calc jmp address for AutoGrow mode + */ + void calcJmpAddress() + { + if (isCalledCalcJmpAddress_) return; + for (AddrInfoList::const_iterator i = addrInfoList_.begin(), ie = addrInfoList_.end(); i != ie; ++i) { + uint64 disp = i->getVal(top_); + rewrite(i->codeOffset, disp, i->jmpSize); + } + isCalledCalcJmpAddress_ = true; + } +public: + enum ProtectMode { + PROTECT_RW = 0, // read/write + PROTECT_RWE = 1, // read/write/exec + PROTECT_RE = 2 // read/exec + }; + explicit CodeArray(size_t maxSize, void *userPtr = 0, Allocator *allocator = 0) + : type_(userPtr == AutoGrow ? AUTO_GROW : (userPtr == 0 || userPtr == DontSetProtectRWE) ? ALLOC_BUF : USER_BUF) + , alloc_(allocator ? allocator : (Allocator*)&defaultAllocator_) + , maxSize_(maxSize) + , top_(type_ == USER_BUF ? reinterpret_cast(userPtr) : alloc_->alloc((std::max)(maxSize, 1))) + , size_(0) + , isCalledCalcJmpAddress_(false) + { + if (maxSize_ > 0 && top_ == 0) throw Error(ERR_CANT_ALLOC); + if ((type_ == ALLOC_BUF && userPtr != DontSetProtectRWE && useProtect()) && !setProtectMode(PROTECT_RWE, false)) { + alloc_->free(top_); + throw Error(ERR_CANT_PROTECT); + } + } + virtual ~CodeArray() + { + if (isAllocType()) { + if (useProtect()) setProtectModeRW(false); + alloc_->free(top_); + } + } + bool setProtectMode(ProtectMode mode, bool throwException = true) + { + bool isOK = protect(top_, maxSize_, mode); + if (isOK) return true; + if (throwException) throw Error(ERR_CANT_PROTECT); + return false; + } + bool setProtectModeRE(bool throwException = true) { return setProtectMode(PROTECT_RE, throwException); } + bool setProtectModeRW(bool throwException = true) { return setProtectMode(PROTECT_RW, throwException); } + void resetSize() + { + size_ = 0; + addrInfoList_.clear(); + isCalledCalcJmpAddress_ = false; + } + void db(int code) + { + if (size_ >= maxSize_) { + if (type_ == AUTO_GROW) { + growMemory(); + } else { + throw Error(ERR_CODE_IS_TOO_BIG); + } + } + top_[size_++] = static_cast(code); + } + void db(const uint8 *code, size_t codeSize) + { + for (size_t i = 0; i < codeSize; i++) db(code[i]); + } + void db(uint64 code, size_t codeSize) + { + if (codeSize > 8) throw Error(ERR_BAD_PARAMETER); + for (size_t i = 0; i < codeSize; i++) db(static_cast(code >> (i * 8))); + } + void dw(uint32 code) { db(code, 2); } + void dd(uint32 code) { db(code, 4); } + void dq(uint64 code) { db(code, 8); } + const uint8 *getCode() const { return top_; } + template + const F getCode() const { return reinterpret_cast(top_); } + const uint8 *getCurr() const { return &top_[size_]; } + template + const F getCurr() const { return reinterpret_cast(&top_[size_]); } + size_t getSize() const { return size_; } + void setSize(size_t size) + { + if (size > maxSize_) throw Error(ERR_OFFSET_IS_TOO_BIG); + size_ = size; + } + void dump() const + { + const uint8 *p = getCode(); + size_t bufSize = getSize(); + size_t remain = bufSize; + for (int i = 0; i < 4; i++) { + size_t disp = 16; + if (remain < 16) { + disp = remain; + } + for (size_t j = 0; j < 16; j++) { + if (j < disp) { + printf("%02X", p[i * 16 + j]); + } + } + putchar('\n'); + remain -= disp; + if (remain == 0) { + break; + } + } + } + /* + @param offset [in] offset from top + @param disp [in] offset from the next of jmp + @param size [in] write size(1, 2, 4, 8) + */ + void rewrite(size_t offset, uint64 disp, size_t size) + { + assert(offset < maxSize_); + if (size != 1 && size != 2 && size != 4 && size != 8) throw Error(ERR_BAD_PARAMETER); + uint8 *const data = top_ + offset; + for (size_t i = 0; i < size; i++) { + data[i] = static_cast(disp >> (i * 8)); + } + } + void save(size_t offset, size_t val, int size, inner::LabelMode mode) + { + addrInfoList_.push_back(AddrInfo(offset, val, size, mode)); + } + bool isAutoGrow() const { return type_ == AUTO_GROW; } + bool isCalledCalcJmpAddress() const { return isCalledCalcJmpAddress_; } + /** + change exec permission of memory + @param addr [in] buffer address + @param size [in] buffer size + @param protectMode [in] mode(RW/RWE/RE) + @return true(success), false(failure) + */ + static inline bool protect(const void *addr, size_t size, int protectMode) + { +#if defined(_WIN32) + const DWORD c_rw = PAGE_READWRITE; + const DWORD c_rwe = PAGE_EXECUTE_READWRITE; + const DWORD c_re = PAGE_EXECUTE_READ; + DWORD mode; +#else + const int c_rw = PROT_READ | PROT_WRITE; + const int c_rwe = PROT_READ | PROT_WRITE | PROT_EXEC; + const int c_re = PROT_READ | PROT_EXEC; + int mode; +#endif + switch (protectMode) { + case PROTECT_RW: mode = c_rw; break; + case PROTECT_RWE: mode = c_rwe; break; + case PROTECT_RE: mode = c_re; break; + default: + return false; + } +#if defined(_WIN32) + DWORD oldProtect; + return VirtualProtect(const_cast(addr), size, mode, &oldProtect) != 0; +#elif defined(__GNUC__) + size_t pageSize = sysconf(_SC_PAGESIZE); + size_t iaddr = reinterpret_cast(addr); + size_t roundAddr = iaddr & ~(pageSize - static_cast(1)); +#ifndef NDEBUG + if (pageSize != 4096) fprintf(stderr, "large page(%zd) is used. not tested enough.\n", pageSize); +#endif + return mprotect(reinterpret_cast(roundAddr), size + (iaddr - roundAddr), mode) == 0; +#else + return true; +#endif + } + /** + get aligned memory pointer + @param addr [in] address + @param alignedSize [in] power of two + @return aligned addr by alingedSize + */ + static inline uint8 *getAlignedAddress(uint8 *addr, size_t alignedSize = 16) + { + return reinterpret_cast((reinterpret_cast(addr) + alignedSize - 1) & ~(alignedSize - static_cast(1))); + } +}; + +class Address : public Operand { +public: + enum Mode { + M_ModRM, + M_64bitDisp, + M_rip, + M_ripAddr + }; + Address(uint32 sizeBit, bool broadcast, const RegExp& e) + : Operand(0, MEM, sizeBit), e_(e), label_(0), mode_(M_ModRM), broadcast_(broadcast) + { + e_.verify(); + } +#ifdef XBYAK64 + explicit Address(size_t disp) + : Operand(0, MEM, 64), e_(disp), label_(0), mode_(M_64bitDisp), broadcast_(false){ } + Address(uint32 sizeBit, bool broadcast, const RegRip& addr) + : Operand(0, MEM, sizeBit), e_(addr.disp_), label_(addr.label_), mode_(addr.isAddr_ ? M_ripAddr : M_rip), broadcast_(broadcast) { } +#endif + RegExp getRegExp(bool optimize = true) const + { + return optimize ? e_.optimize() : e_; + } + Mode getMode() const { return mode_; } + bool is32bit() const { return e_.getBase().getBit() == 32 || e_.getIndex().getBit() == 32; } + bool isOnlyDisp() const { return !e_.getBase().getBit() && !e_.getIndex().getBit(); } // for mov eax + size_t getDisp() const { return e_.getDisp(); } + uint8 getRex() const + { + if (mode_ != M_ModRM) return 0; + return getRegExp().getRex(); + } + bool is64bitDisp() const { return mode_ == M_64bitDisp; } // for moffset + bool isBroadcast() const { return broadcast_; } + const Label* getLabel() const { return label_; } + bool operator==(const Address& rhs) const + { + return getBit() == rhs.getBit() && e_ == rhs.e_ && label_ == rhs.label_ && mode_ == rhs.mode_ && broadcast_ == rhs.broadcast_; + } + bool operator!=(const Address& rhs) const { return !operator==(rhs); } + bool isVsib() const { return e_.isVsib(); } +private: + RegExp e_; + const Label* label_; + Mode mode_; + bool broadcast_; +}; + +inline const Address& Operand::getAddress() const +{ + assert(isMEM()); + return static_cast(*this); +} + +inline bool Operand::operator==(const Operand& rhs) const +{ + if (isMEM() && rhs.isMEM()) return this->getAddress() == rhs.getAddress(); + return isEqualIfNotInherited(rhs); +} + +class AddressFrame { + void operator=(const AddressFrame&); + AddressFrame(const AddressFrame&); +public: + const uint32 bit_; + const bool broadcast_; + explicit AddressFrame(uint32 bit, bool broadcast = false) : bit_(bit), broadcast_(broadcast) { } + Address operator[](const RegExp& e) const + { + return Address(bit_, broadcast_, e); + } + Address operator[](const void *disp) const + { + return Address(bit_, broadcast_, RegExp(reinterpret_cast(disp))); + } +#ifdef XBYAK64 + Address operator[](uint64 disp) const { return Address(disp); } + Address operator[](const RegRip& addr) const { return Address(bit_, broadcast_, addr); } +#endif +}; + +struct JmpLabel { + size_t endOfJmp; /* offset from top to the end address of jmp */ + int jmpSize; + inner::LabelMode mode; + size_t disp; // disp for [rip + disp] + explicit JmpLabel(size_t endOfJmp = 0, int jmpSize = 0, inner::LabelMode mode = inner::LasIs, size_t disp = 0) + : endOfJmp(endOfJmp), jmpSize(jmpSize), mode(mode), disp(disp) + { + } +}; + +class LabelManager; + +class Label { + mutable LabelManager *mgr; + mutable int id; + friend class LabelManager; +public: + Label() : mgr(0), id(0) {} + Label(const Label& rhs); + Label& operator=(const Label& rhs); + ~Label(); + void clear() { mgr = 0; id = 0; } + int getId() const { return id; } + const uint8 *getAddress() const; + + // backward compatibility + static inline std::string toStr(int num) + { + char buf[16]; +#if defined(_MSC_VER) && (_MSC_VER < 1900) + _snprintf_s +#else + snprintf +#endif + (buf, sizeof(buf), ".%08x", num); + return buf; + } +}; + +class LabelManager { + // for string label + struct SlabelVal { + size_t offset; + SlabelVal(size_t offset) : offset(offset) {} + }; + typedef XBYAK_STD_UNORDERED_MAP SlabelDefList; + typedef XBYAK_STD_UNORDERED_MULTIMAP SlabelUndefList; + struct SlabelState { + SlabelDefList defList; + SlabelUndefList undefList; + }; + typedef std::list StateList; + // for Label class + struct ClabelVal { + ClabelVal(size_t offset = 0) : offset(offset), refCount(1) {} + size_t offset; + int refCount; + }; + typedef XBYAK_STD_UNORDERED_MAP ClabelDefList; + typedef XBYAK_STD_UNORDERED_MULTIMAP ClabelUndefList; + typedef XBYAK_STD_UNORDERED_SET LabelPtrList; + + CodeArray *base_; + // global : stateList_.front(), local : stateList_.back() + StateList stateList_; + mutable int labelId_; + ClabelDefList clabelDefList_; + ClabelUndefList clabelUndefList_; + LabelPtrList labelPtrList_; + + int getId(const Label& label) const + { + if (label.id == 0) label.id = labelId_++; + return label.id; + } + template + void define_inner(DefList& defList, UndefList& undefList, const T& labelId, size_t addrOffset) + { + // add label + typename DefList::value_type item(labelId, addrOffset); + std::pair ret = defList.insert(item); + if (!ret.second) throw Error(ERR_LABEL_IS_REDEFINED); + // search undefined label + for (;;) { + typename UndefList::iterator itr = undefList.find(labelId); + if (itr == undefList.end()) break; + const JmpLabel *jmp = &itr->second; + const size_t offset = jmp->endOfJmp - jmp->jmpSize; + size_t disp; + if (jmp->mode == inner::LaddTop) { + disp = addrOffset; + } else if (jmp->mode == inner::Labs) { + disp = size_t(base_->getCurr()); + } else { + disp = addrOffset - jmp->endOfJmp + jmp->disp; +#ifdef XBYAK64 + if (jmp->jmpSize <= 4 && !inner::IsInInt32(disp)) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + if (jmp->jmpSize == 1 && !inner::IsInDisp8((uint32)disp)) throw Error(ERR_LABEL_IS_TOO_FAR); + } + if (base_->isAutoGrow()) { + base_->save(offset, disp, jmp->jmpSize, jmp->mode); + } else { + base_->rewrite(offset, disp, jmp->jmpSize); + } + undefList.erase(itr); + } + } + template + bool getOffset_inner(const DefList& defList, size_t *offset, const T& label) const + { + typename DefList::const_iterator i = defList.find(label); + if (i == defList.end()) return false; + *offset = i->second.offset; + return true; + } + friend class Label; + void incRefCount(int id, Label *label) + { + clabelDefList_[id].refCount++; + labelPtrList_.insert(label); + } + void decRefCount(int id, Label *label) + { + labelPtrList_.erase(label); + ClabelDefList::iterator i = clabelDefList_.find(id); + if (i == clabelDefList_.end()) return; + if (i->second.refCount == 1) { + clabelDefList_.erase(id); + } else { + --i->second.refCount; + } + } + template + bool hasUndefinedLabel_inner(const T& list) const + { +#ifndef NDEBUG + for (typename T::const_iterator i = list.begin(); i != list.end(); ++i) { + std::cerr << "undefined label:" << i->first << std::endl; + } +#endif + return !list.empty(); + } + // detach all labels linked to LabelManager + void resetLabelPtrList() + { + for (LabelPtrList::iterator i = labelPtrList_.begin(), ie = labelPtrList_.end(); i != ie; ++i) { + (*i)->clear(); + } + labelPtrList_.clear(); + } +public: + LabelManager() + { + reset(); + } + ~LabelManager() + { + resetLabelPtrList(); + } + void reset() + { + base_ = 0; + labelId_ = 1; + stateList_.clear(); + stateList_.push_back(SlabelState()); + stateList_.push_back(SlabelState()); + clabelDefList_.clear(); + clabelUndefList_.clear(); + resetLabelPtrList(); + } + void enterLocal() + { + stateList_.push_back(SlabelState()); + } + void leaveLocal() + { + if (stateList_.size() <= 2) throw Error(ERR_UNDER_LOCAL_LABEL); + if (hasUndefinedLabel_inner(stateList_.back().undefList)) throw Error(ERR_LABEL_IS_NOT_FOUND); + stateList_.pop_back(); + } + void set(CodeArray *base) { base_ = base; } + void defineSlabel(std::string label) + { + if (label == "@b" || label == "@f") throw Error(ERR_BAD_LABEL_STR); + if (label == "@@") { + SlabelDefList& defList = stateList_.front().defList; + SlabelDefList::iterator i = defList.find("@f"); + if (i != defList.end()) { + defList.erase(i); + label = "@b"; + } else { + i = defList.find("@b"); + if (i != defList.end()) { + defList.erase(i); + } + label = "@f"; + } + } + SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + define_inner(st.defList, st.undefList, label, base_->getSize()); + } + void defineClabel(Label& label) + { + define_inner(clabelDefList_, clabelUndefList_, getId(label), base_->getSize()); + label.mgr = this; + labelPtrList_.insert(&label); + } + void assign(Label& dst, const Label& src) + { + ClabelDefList::const_iterator i = clabelDefList_.find(src.id); + if (i == clabelDefList_.end()) throw Error(ERR_LABEL_ISNOT_SET_BY_L); + define_inner(clabelDefList_, clabelUndefList_, dst.id, i->second.offset); + dst.mgr = this; + labelPtrList_.insert(&dst); + } + bool getOffset(size_t *offset, std::string& label) const + { + const SlabelDefList& defList = stateList_.front().defList; + if (label == "@b") { + if (defList.find("@f") != defList.end()) { + label = "@f"; + } else if (defList.find("@b") == defList.end()) { + throw Error(ERR_LABEL_IS_NOT_FOUND); + } + } else if (label == "@f") { + if (defList.find("@f") != defList.end()) { + label = "@b"; + } + } + const SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + return getOffset_inner(st.defList, offset, label); + } + bool getOffset(size_t *offset, const Label& label) const + { + return getOffset_inner(clabelDefList_, offset, getId(label)); + } + void addUndefinedLabel(const std::string& label, const JmpLabel& jmp) + { + SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + st.undefList.insert(SlabelUndefList::value_type(label, jmp)); + } + void addUndefinedLabel(const Label& label, const JmpLabel& jmp) + { + clabelUndefList_.insert(ClabelUndefList::value_type(label.id, jmp)); + } + bool hasUndefSlabel() const + { + for (StateList::const_iterator i = stateList_.begin(), ie = stateList_.end(); i != ie; ++i) { + if (hasUndefinedLabel_inner(i->undefList)) return true; + } + return false; + } + bool hasUndefClabel() const { return hasUndefinedLabel_inner(clabelUndefList_); } + const uint8 *getCode() const { return base_->getCode(); } + bool isReady() const { return !base_->isAutoGrow() || base_->isCalledCalcJmpAddress(); } +}; + +inline Label::Label(const Label& rhs) +{ + id = rhs.id; + mgr = rhs.mgr; + if (mgr) mgr->incRefCount(id, this); +} +inline Label& Label::operator=(const Label& rhs) +{ + if (id) throw Error(ERR_LABEL_IS_ALREADY_SET_BY_L); + id = rhs.id; + mgr = rhs.mgr; + if (mgr) mgr->incRefCount(id, this); + return *this; +} +inline Label::~Label() +{ + if (id && mgr) mgr->decRefCount(id, this); +} +inline const uint8* Label::getAddress() const +{ + if (mgr == 0 || !mgr->isReady()) return 0; + size_t offset; + if (!mgr->getOffset(&offset, *this)) return 0; + return mgr->getCode() + offset; +} + +class CodeGenerator : public CodeArray { +public: + enum LabelType { + T_SHORT, + T_NEAR, + T_AUTO // T_SHORT if possible + }; +private: + CodeGenerator operator=(const CodeGenerator&); // don't call +#ifdef XBYAK64 + enum { i32e = 32 | 64, BIT = 64 }; + static const size_t dummyAddr = (size_t(0x11223344) << 32) | 55667788; + typedef Reg64 NativeReg; +#else + enum { i32e = 32, BIT = 32 }; + static const size_t dummyAddr = 0x12345678; + typedef Reg32 NativeReg; +#endif + // (XMM, XMM|MEM) + static inline bool isXMM_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isXMM() || op2.isMEM()); + } + // (MMX, MMX|MEM) or (XMM, XMM|MEM) + static inline bool isXMMorMMX_MEM(const Operand& op1, const Operand& op2) + { + return (op1.isMMX() && (op2.isMMX() || op2.isMEM())) || isXMM_XMMorMEM(op1, op2); + } + // (XMM, MMX|MEM) + static inline bool isXMM_MMXorMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isMMX() || op2.isMEM()); + } + // (MMX, XMM|MEM) + static inline bool isMMX_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isMMX() && (op2.isXMM() || op2.isMEM()); + } + // (XMM, REG32|MEM) + static inline bool isXMM_REG32orMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isREG(i32e) || op2.isMEM()); + } + // (REG32, XMM|MEM) + static inline bool isREG32_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isREG(i32e) && (op2.isXMM() || op2.isMEM()); + } + // (REG32, REG32|MEM) + static inline bool isREG32_REG32orMEM(const Operand& op1, const Operand& op2) + { + return op1.isREG(i32e) && ((op2.isREG(i32e) && op1.getBit() == op2.getBit()) || op2.isMEM()); + } + void rex(const Operand& op1, const Operand& op2 = Operand()) + { + uint8 rex = 0; + const Operand *p1 = &op1, *p2 = &op2; + if (p1->isMEM()) std::swap(p1, p2); + if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION); + if (p2->isMEM()) { + const Address& addr = p2->getAddress(); + if (BIT == 64 && addr.is32bit()) db(0x67); + rex = addr.getRex() | p1->getReg().getRex(); + } else { + // ModRM(reg, base); + rex = op2.getReg().getRex(op1.getReg()); + } + // except movsx(16bit, 32/64bit) + if ((op1.isBit(16) && !op2.isBit(i32e)) || (op2.isBit(16) && !op1.isBit(i32e))) db(0x66); + if (rex) db(rex); + } + enum AVXtype { + // low 3 bit + T_N1 = 1, + T_N2 = 2, + T_N4 = 3, + T_N8 = 4, + T_N16 = 5, + T_N32 = 6, + T_NX_MASK = 7, + // + T_N_VL = 1 << 3, // N * (1, 2, 4) for VL + T_DUP = 1 << 4, // N = (8, 32, 64) + T_66 = 1 << 5, + T_F3 = 1 << 6, + T_F2 = 1 << 7, + T_0F = 1 << 8, + T_0F38 = 1 << 9, + T_0F3A = 1 << 10, + T_L0 = 1 << 11, + T_L1 = 1 << 12, + T_W0 = 1 << 13, + T_W1 = 1 << 14, + T_EW0 = 1 << 15, + T_EW1 = 1 << 16, + T_YMM = 1 << 17, // support YMM, ZMM + T_EVEX = 1 << 18, + T_ER_X = 1 << 19, // xmm{er} + T_ER_Y = 1 << 20, // ymm{er} + T_ER_Z = 1 << 21, // zmm{er} + T_SAE_X = 1 << 22, // xmm{sae} + T_SAE_Y = 1 << 23, // ymm{sae} + T_SAE_Z = 1 << 24, // zmm{sae} + T_MUST_EVEX = 1 << 25, // contains T_EVEX + T_B32 = 1 << 26, // m32bcst + T_B64 = 1 << 27, // m64bcst + T_M_K = 1 << 28, // mem{k} + T_VSIB = 1 << 29, + T_MEM_EVEX = 1 << 30, // use evex if mem + T_XXX + }; + void vex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false) + { + int w = (type & T_W1) ? 1 : 0; + bool is256 = (type & T_L1) ? true : (type & T_L0) ? false : reg.isYMM(); + bool r = reg.isExtIdx(); + bool b = base.isExtIdx(); + int idx = v ? v->getIdx() : 0; + if ((idx | reg.getIdx() | base.getIdx()) >= 16) throw Error(ERR_BAD_COMBINATION); + uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0; + uint32 vvvv = (((~idx) & 15) << 3) | (is256 ? 4 : 0) | pp; + if (!b && !x && !w && (type & T_0F)) { + db(0xC5); db((r ? 0 : 0x80) | vvvv); + } else { + uint32 mmmm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0; + db(0xC4); db((r ? 0 : 0x80) | (x ? 0 : 0x40) | (b ? 0 : 0x20) | mmmm); db((w << 7) | vvvv); + } + db(code); + } + void verifySAE(const Reg& r, int type) const + { + if (((type & T_SAE_X) && r.isXMM()) || ((type & T_SAE_Y) && r.isYMM()) || ((type & T_SAE_Z) && r.isZMM())) return; + throw Error(ERR_SAE_IS_INVALID); + } + void verifyER(const Reg& r, int type) const + { + if (((type & T_ER_X) && r.isXMM()) || ((type & T_ER_Y) && r.isYMM()) || ((type & T_ER_Z) && r.isZMM())) return; + throw Error(ERR_ER_IS_INVALID); + } + // (a, b, c) contains non zero two or three values then err + int verifyDuplicate(int a, int b, int c, int err) + { + int v = a | b | c; + if ((a > 0 && a != v) + (b > 0 && b != v) + (c > 0 && c != v) > 0) return Error(err); + return v; + } + int evex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false, bool b = false, int aaa = 0, uint32 VL = 0, bool Hi16Vidx = false) + { + if (!(type & (T_EVEX | T_MUST_EVEX))) throw Error(ERR_EVEX_IS_INVALID); + int w = (type & T_EW1) ? 1 : 0; + uint32 mm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0; + uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0; + + int idx = v ? v->getIdx() : 0; + uint32 vvvv = ~idx; + + bool R = !reg.isExtIdx(); + bool X = x ? false : !base.isExtIdx2(); + bool B = !base.isExtIdx(); + bool Rp = !reg.isExtIdx2(); + int LL; + int rounding = verifyDuplicate(reg.getRounding(), base.getRounding(), v ? v->getRounding() : 0, ERR_ROUNDING_IS_ALREADY_SET); + int disp8N = 1; + if (rounding) { + if (rounding == EvexModifierRounding::T_SAE) { + verifySAE(base, type); LL = 0; + } else { + verifyER(base, type); LL = rounding - 1; + } + b = true; + } else { + if (v) VL = (std::max)(VL, v->getBit()); + VL = (std::max)((std::max)(reg.getBit(), base.getBit()), VL); + LL = (VL == 512) ? 2 : (VL == 256) ? 1 : 0; + if (b) { + disp8N = (type & T_B32) ? 4 : 8; + } else if (type & T_DUP) { + disp8N = VL == 128 ? 8 : VL == 256 ? 32 : 64; + } else { + if ((type & (T_NX_MASK | T_N_VL)) == 0) { + type |= T_N16 | T_N_VL; // default + } + int low = type & T_NX_MASK; + if (low > 0) { + disp8N = 1 << (low - 1); + if (type & T_N_VL) disp8N *= (VL == 512 ? 4 : VL == 256 ? 2 : 1); + } + } + } + bool Vp = !((v ? v->isExtIdx2() : 0) | Hi16Vidx); + bool z = reg.hasZero() || base.hasZero() || (v ? v->hasZero() : false); + if (aaa == 0) aaa = verifyDuplicate(base.getOpmaskIdx(), reg.getOpmaskIdx(), (v ? v->getOpmaskIdx() : 0), ERR_OPMASK_IS_ALREADY_SET); + db(0x62); + db((R ? 0x80 : 0) | (X ? 0x40 : 0) | (B ? 0x20 : 0) | (Rp ? 0x10 : 0) | (mm & 3)); + db((w == 1 ? 0x80 : 0) | ((vvvv & 15) << 3) | 4 | (pp & 3)); + db((z ? 0x80 : 0) | ((LL & 3) << 5) | (b ? 0x10 : 0) | (Vp ? 8 : 0) | (aaa & 7)); + db(code); + return disp8N; + } + void setModRM(int mod, int r1, int r2) + { + db(static_cast((mod << 6) | ((r1 & 7) << 3) | (r2 & 7))); + } + void setSIB(const RegExp& e, int reg, int disp8N = 0) + { + size_t disp64 = e.getDisp(); +#ifdef XBYAK64 + size_t high = disp64 >> 32; + if (high != 0 && high != 0xFFFFFFFF) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + uint32 disp = static_cast(disp64); + const Reg& base = e.getBase(); + const Reg& index = e.getIndex(); + const int baseIdx = base.getIdx(); + const int baseBit = base.getBit(); + const int indexBit = index.getBit(); + enum { + mod00 = 0, mod01 = 1, mod10 = 2 + }; + int mod = mod10; // disp32 + if (!baseBit || ((baseIdx & 7) != Operand::EBP && disp == 0)) { + mod = mod00; + } else { + if (disp8N == 0) { + if (inner::IsInDisp8(disp)) { + mod = mod01; + } + } else { + // disp must be casted to signed + uint32 t = static_cast(static_cast(disp) / disp8N); + if ((disp % disp8N) == 0 && inner::IsInDisp8(t)) { + disp = t; + mod = mod01; + } + } + } + const int newBaseIdx = baseBit ? (baseIdx & 7) : Operand::EBP; + /* ModR/M = [2:3:3] = [Mod:reg/code:R/M] */ + bool hasSIB = indexBit || (baseIdx & 7) == Operand::ESP; +#ifdef XBYAK64 + if (!baseBit && !indexBit) hasSIB = true; +#endif + if (hasSIB) { + setModRM(mod, reg, Operand::ESP); + /* SIB = [2:3:3] = [SS:index:base(=rm)] */ + const int idx = indexBit ? (index.getIdx() & 7) : Operand::ESP; + const int scale = e.getScale(); + const int SS = (scale == 8) ? 3 : (scale == 4) ? 2 : (scale == 2) ? 1 : 0; + setModRM(SS, idx, newBaseIdx); + } else { + setModRM(mod, reg, newBaseIdx); + } + if (mod == mod01) { + db(disp); + } else if (mod == mod10 || (mod == mod00 && !baseBit)) { + dd(disp); + } + } + LabelManager labelMgr_; + bool isInDisp16(uint32 x) const { return 0xFFFF8000 <= x || x <= 0x7FFF; } + void opModR(const Reg& reg1, const Reg& reg2, int code0, int code1 = NONE, int code2 = NONE) + { + rex(reg2, reg1); + db(code0 | (reg1.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2); + setModRM(3, reg1.getIdx(), reg2.getIdx()); + } + void opModM(const Address& addr, const Reg& reg, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + rex(addr, reg); + db(code0 | (reg.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2); + opAddr(addr, reg.getIdx(), immSize); + } + void opMIB(const Address& addr, const Reg& reg, int code0, int code1) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + if (addr.getMode() != Address::M_ModRM) throw Error(ERR_INVALID_MIB_ADDRESS); + if (BIT == 64 && addr.is32bit()) db(0x67); + const RegExp& regExp = addr.getRegExp(false); + uint8 rex = regExp.getRex(); + if (rex) db(rex); + db(code0); db(code1); + setSIB(regExp, reg.getIdx()); + } + void makeJmp(uint32 disp, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref) + { + const int shortJmpSize = 2; + const int longHeaderSize = longPref ? 2 : 1; + const int longJmpSize = longHeaderSize + 4; + if (type != T_NEAR && inner::IsInDisp8(disp - shortJmpSize)) { + db(shortCode); db(disp - shortJmpSize); + } else { + if (type == T_SHORT) throw Error(ERR_LABEL_IS_TOO_FAR); + if (longPref) db(longPref); + db(longCode); dd(disp - longJmpSize); + } + } + template + void opJmp(T& label, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref) + { + if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); /* avoid splitting code of jmp */ + size_t offset = 0; + if (labelMgr_.getOffset(&offset, label)) { /* label exists */ + makeJmp(inner::VerifyInInt32(offset - size_), type, shortCode, longCode, longPref); + } else { + int jmpSize = 0; + if (type == T_NEAR) { + jmpSize = 4; + if (longPref) db(longPref); + db(longCode); dd(0); + } else { + jmpSize = 1; + db(shortCode); db(0); + } + JmpLabel jmp(size_, jmpSize, inner::LasIs); + labelMgr_.addUndefinedLabel(label, jmp); + } + } + void opJmpAbs(const void *addr, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref = 0) + { + if (isAutoGrow()) { + if (type != T_NEAR) throw Error(ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW); + if (size_ + 16 >= maxSize_) growMemory(); + if (longPref) db(longPref); + db(longCode); + dd(0); + save(size_ - 4, size_t(addr) - size_, 4, inner::Labs); + } else { + makeJmp(inner::VerifyInInt32(reinterpret_cast(addr) - getCurr()), type, shortCode, longCode, longPref); + } + + } + // reg is reg field of ModRM + // immSize is the size for immediate value + // disp8N = 0(normal), disp8N = 1(force disp32), disp8N = {2, 4, 8} ; compressed displacement + void opAddr(const Address &addr, int reg, int immSize = 0, int disp8N = 0, bool permitVisb = false) + { + if (!permitVisb && addr.isVsib()) throw Error(ERR_BAD_VSIB_ADDRESSING); + if (addr.getMode() == Address::M_ModRM) { + setSIB(addr.getRegExp(), reg, disp8N); + } else if (addr.getMode() == Address::M_rip || addr.getMode() == Address::M_ripAddr) { + setModRM(0, reg, 5); + if (addr.getLabel()) { // [rip + Label] + putL_inner(*addr.getLabel(), true, addr.getDisp() - immSize); + } else { + size_t disp = addr.getDisp(); + if (addr.getMode() == Address::M_ripAddr) { + if (isAutoGrow()) throw Error(ERR_INVALID_RIP_IN_AUTO_GROW); + disp -= (size_t)getCurr() + 4 + immSize; + } + dd(inner::VerifyInInt32(disp)); + } + } + } + /* preCode is for SSSE3/SSE4 */ + void opGen(const Operand& reg, const Operand& op, int code, int pref, bool isValid(const Operand&, const Operand&), int imm8 = NONE, int preCode = NONE) + { + if (isValid && !isValid(reg, op)) throw Error(ERR_BAD_COMBINATION); + if (pref != NONE) db(pref); + if (op.isMEM()) { + opModM(op.getAddress(), reg.getReg(), 0x0F, preCode, code, (imm8 != NONE) ? 1 : 0); + } else { + opModR(reg.getReg(), op.getReg(), 0x0F, preCode, code); + } + if (imm8 != NONE) db(imm8); + } + void opMMX_IMM(const Mmx& mmx, int imm8, int code, int ext) + { + if (mmx.isXMM()) db(0x66); + opModR(Reg32(ext), mmx, 0x0F, code); + db(imm8); + } + void opMMX(const Mmx& mmx, const Operand& op, int code, int pref = 0x66, int imm8 = NONE, int preCode = NONE) + { + opGen(mmx, op, code, mmx.isXMM() ? pref : NONE, isXMMorMMX_MEM, imm8, preCode); + } + void opMovXMM(const Operand& op1, const Operand& op2, int code, int pref) + { + if (pref != NONE) db(pref); + if (op1.isXMM() && op2.isMEM()) { + opModM(op2.getAddress(), op1.getReg(), 0x0F, code); + } else if (op1.isMEM() && op2.isXMM()) { + opModM(op1.getAddress(), op2.getReg(), 0x0F, code | 1); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opExt(const Operand& op, const Mmx& mmx, int code, int imm, bool hasMMX2 = false) + { + if (hasMMX2 && op.isREG(i32e)) { /* pextrw is special */ + if (mmx.isXMM()) db(0x66); + opModR(op.getReg(), mmx, 0x0F, 0xC5); db(imm); + } else { + opGen(mmx, op, code, 0x66, isXMM_REG32orMEM, imm, 0x3A); + } + } + void opR_ModM(const Operand& op, int bit, int ext, int code0, int code1 = NONE, int code2 = NONE, bool disableRex = false, int immSize = 0) + { + int opBit = op.getBit(); + if (disableRex && opBit == 64) opBit = 32; + if (op.isREG(bit)) { + opModR(Reg(ext, Operand::REG, opBit), op.getReg().changeBit(opBit), code0, code1, code2); + } else if (op.isMEM()) { + opModM(op.getAddress(), Reg(ext, Operand::REG, opBit), code0, code1, code2, immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opShift(const Operand& op, int imm, int ext) + { + verifyMemHasSize(op); + opR_ModM(op, 0, ext, (0xC0 | ((imm == 1 ? 1 : 0) << 4)), NONE, NONE, false, (imm != 1) ? 1 : 0); + if (imm != 1) db(imm); + } + void opShift(const Operand& op, const Reg8& _cl, int ext) + { + if (_cl.getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION); + opR_ModM(op, 0, ext, 0xD2); + } + void opModRM(const Operand& op1, const Operand& op2, bool condR, bool condM, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0) + { + if (condR) { + opModR(op1.getReg(), op2.getReg(), code0, code1, code2); + } else if (condM) { + opModM(op2.getAddress(), op1.getReg(), code0, code1, code2, immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opShxd(const Operand& op, const Reg& reg, uint8 imm, int code, const Reg8 *_cl = 0) + { + if (_cl && _cl->getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION); + opModRM(reg, op, (op.isREG(16 | i32e) && op.getBit() == reg.getBit()), op.isMEM() && (reg.isREG(16 | i32e)), 0x0F, code | (_cl ? 1 : 0), NONE, _cl ? 0 : 1); + if (!_cl) db(imm); + } + // (REG, REG|MEM), (MEM, REG) + void opRM_RM(const Operand& op1, const Operand& op2, int code) + { + if (op1.isREG() && op2.isMEM()) { + opModM(op2.getAddress(), op1.getReg(), code | 2); + } else { + opModRM(op2, op1, op1.isREG() && op1.getKind() == op2.getKind(), op1.isMEM() && op2.isREG(), code); + } + } + // (REG|MEM, IMM) + void opRM_I(const Operand& op, uint32 imm, int code, int ext) + { + verifyMemHasSize(op); + uint32 immBit = inner::IsInDisp8(imm) ? 8 : isInDisp16(imm) ? 16 : 32; + if (op.isBit(8)) immBit = 8; + if (op.getBit() < immBit) throw Error(ERR_IMM_IS_TOO_BIG); + if (op.isBit(32|64) && immBit == 16) immBit = 32; /* don't use MEM16 if 32/64bit mode */ + if (op.isREG() && op.getIdx() == 0 && (op.getBit() == immBit || (op.isBit(64) && immBit == 32))) { // rax, eax, ax, al + rex(op); + db(code | 4 | (immBit == 8 ? 0 : 1)); + } else { + int tmp = immBit < (std::min)(op.getBit(), 32U) ? 2 : 0; + opR_ModM(op, 0, ext, 0x80 | tmp, NONE, NONE, false, immBit / 8); + } + db(imm, immBit / 8); + } + void opIncDec(const Operand& op, int code, int ext) + { + verifyMemHasSize(op); +#ifndef XBYAK64 + if (op.isREG() && !op.isBit(8)) { + rex(op); db(code | op.getIdx()); + return; + } +#endif + code = 0xFE; + if (op.isREG()) { + opModR(Reg(ext, Operand::REG, op.getBit()), op.getReg(), code); + } else { + opModM(op.getAddress(), Reg(ext, Operand::REG, op.getBit()), code); + } + } + void opPushPop(const Operand& op, int code, int ext, int alt) + { + int bit = op.getBit(); + if (bit == 16 || bit == BIT) { + if (bit == 16) db(0x66); + if (op.isREG()) { + if (op.getReg().getIdx() >= 8) db(0x41); + db(alt | (op.getIdx() & 7)); + return; + } + if (op.isMEM()) { + opModM(op.getAddress(), Reg(ext, Operand::REG, 32), code); + return; + } + } + throw Error(ERR_BAD_COMBINATION); + } + void verifyMemHasSize(const Operand& op) const + { + if (op.isMEM() && op.getBit() == 0) throw Error(ERR_MEM_SIZE_IS_NOT_SPECIFIED); + } + /* + mov(r, imm) = db(imm, mov_imm(r, imm)) + */ + int mov_imm(const Reg& reg, size_t imm) + { + int bit = reg.getBit(); + const int idx = reg.getIdx(); + int code = 0xB0 | ((bit == 8 ? 0 : 1) << 3); + if (bit == 64 && (imm & ~size_t(0xffffffffu)) == 0) { + rex(Reg32(idx)); + bit = 32; + } else { + rex(reg); + if (bit == 64 && inner::IsInInt32(imm)) { + db(0xC7); + code = 0xC0; + bit = 32; + } + } + db(code | (idx & 7)); + return bit / 8; + } + template + void putL_inner(T& label, bool relative = false, size_t disp = 0) + { + const int jmpSize = relative ? 4 : (int)sizeof(size_t); + if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); + size_t offset = 0; + if (labelMgr_.getOffset(&offset, label)) { + if (relative) { + db(inner::VerifyInInt32(offset + disp - size_ - jmpSize), jmpSize); + } else if (isAutoGrow()) { + db(uint64(0), jmpSize); + save(size_ - jmpSize, offset, jmpSize, inner::LaddTop); + } else { + db(size_t(top_) + offset, jmpSize); + } + return; + } + db(uint64(0), jmpSize); + JmpLabel jmp(size_, jmpSize, (relative ? inner::LasIs : isAutoGrow() ? inner::LaddTop : inner::Labs), disp); + labelMgr_.addUndefinedLabel(label, jmp); + } + void opMovxx(const Reg& reg, const Operand& op, uint8 code) + { + if (op.isBit(32)) throw Error(ERR_BAD_COMBINATION); + int w = op.isBit(16); +#ifdef XBYAK64 + if (op.isHigh8bit()) throw Error(ERR_BAD_COMBINATION); +#endif + bool cond = reg.isREG() && (reg.getBit() > op.getBit()); + opModRM(reg, op, cond && op.isREG(), cond && op.isMEM(), 0x0F, code | w); + } + void opFpuMem(const Address& addr, uint8 m16, uint8 m32, uint8 m64, uint8 ext, uint8 m64ext) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + uint8 code = addr.isBit(16) ? m16 : addr.isBit(32) ? m32 : addr.isBit(64) ? m64 : 0; + if (!code) throw Error(ERR_BAD_MEM_SIZE); + if (m64ext && addr.isBit(64)) ext = m64ext; + + rex(addr, st0); + db(code); + opAddr(addr, ext); + } + // use code1 if reg1 == st0 + // use code2 if reg1 != st0 && reg2 == st0 + void opFpuFpu(const Fpu& reg1, const Fpu& reg2, uint32 code1, uint32 code2) + { + uint32 code = reg1.getIdx() == 0 ? code1 : reg2.getIdx() == 0 ? code2 : 0; + if (!code) throw Error(ERR_BAD_ST_COMBINATION); + db(uint8(code >> 8)); + db(uint8(code | (reg1.getIdx() | reg2.getIdx()))); + } + void opFpu(const Fpu& reg, uint8 code1, uint8 code2) + { + db(code1); db(code2 | reg.getIdx()); + } + void opVex(const Reg& r, const Operand *p1, const Operand& op2, int type, int code, int imm8 = NONE) + { + if (op2.isMEM()) { + const Address& addr = op2.getAddress(); + const RegExp& regExp = addr.getRegExp(); + const Reg& base = regExp.getBase(); + const Reg& index = regExp.getIndex(); + if (BIT == 64 && addr.is32bit()) db(0x67); + int disp8N = 0; + bool x = index.isExtIdx(); + if ((type & (T_MUST_EVEX|T_MEM_EVEX)) || r.hasEvex() || (p1 && p1->hasEvex()) || addr.isBroadcast() || addr.getOpmaskIdx()) { + int aaa = addr.getOpmaskIdx(); + if (aaa && !(type & T_M_K)) throw Error(ERR_INVALID_OPMASK_WITH_MEMORY); + bool b = false; + if (addr.isBroadcast()) { + if (!(type & (T_B32 | T_B64))) throw Error(ERR_INVALID_BROADCAST); + b = true; + } + int VL = regExp.isVsib() ? index.getBit() : 0; + disp8N = evex(r, base, p1, type, code, x, b, aaa, VL, index.isExtIdx2()); + } else { + vex(r, base, p1, type, code, x); + } + opAddr(addr, r.getIdx(), (imm8 != NONE) ? 1 : 0, disp8N, (type & T_VSIB) != 0); + } else { + const Reg& base = op2.getReg(); + if ((type & T_MUST_EVEX) || r.hasEvex() || (p1 && p1->hasEvex()) || base.hasEvex()) { + evex(r, base, p1, type, code); + } else { + vex(r, base, p1, type, code); + } + setModRM(3, r.getIdx(), base.getIdx()); + } + if (imm8 != NONE) db(imm8); + } + // (r, r, r/m) if isR_R_RM + // (r, r/m, r) + void opGpr(const Reg32e& r, const Operand& op1, const Operand& op2, int type, uint8 code, bool isR_R_RM, int imm8 = NONE) + { + const Operand *p1 = &op1; + const Operand *p2 = &op2; + if (!isR_R_RM) std::swap(p1, p2); + const unsigned int bit = r.getBit(); + if (p1->getBit() != bit || (p2->isREG() && p2->getBit() != bit)) throw Error(ERR_BAD_COMBINATION); + type |= (bit == 64) ? T_W1 : T_W0; + opVex(r, p1, *p2, type, code, imm8); + } + void opAVX_X_X_XM(const Xmm& x1, const Operand& op1, const Operand& op2, int type, int code0, int imm8 = NONE) + { + const Xmm *x2 = static_cast(&op1); + const Operand *op = &op2; + if (op2.isNone()) { // (x1, op1) -> (x1, x1, op1) + x2 = &x1; + op = &op1; + } + // (x1, x2, op) + if (!((x1.isXMM() && x2->isXMM()) || ((type & T_YMM) && ((x1.isYMM() && x2->isYMM()) || (x1.isZMM() && x2->isZMM()))))) throw Error(ERR_BAD_COMBINATION); + opVex(x1, x2, *op, type, code0, imm8); + } + void opAVX_K_X_XM(const Opmask& k, const Xmm& x2, const Operand& op3, int type, int code0, int imm8 = NONE) + { + if (!op3.isMEM() && (x2.getKind() != op3.getKind())) throw Error(ERR_BAD_COMBINATION); + opVex(k, &x2, op3, type, code0, imm8); + } + // (x, x/m), (y, x/m256), (z, y/m) + void checkCvt1(const Operand& x, const Operand& op) const + { + if (!op.isMEM() && !(x.is(Operand::XMM | Operand::YMM) && op.isXMM()) && !(x.isZMM() && op.isYMM())) throw Error(ERR_BAD_COMBINATION); + } + // (x, x/m), (x, y/m256), (y, z/m) + void checkCvt2(const Xmm& x, const Operand& op) const + { + if (!(x.isXMM() && op.is(Operand::XMM | Operand::YMM | Operand::MEM)) && !(x.isYMM() && op.is(Operand::ZMM | Operand::MEM))) throw Error(ERR_BAD_COMBINATION); + } + void opCvt2(const Xmm& x, const Operand& op, int type, int code) + { + checkCvt2(x, op); + Operand::Kind kind = x.isXMM() ? (op.isBit(256) ? Operand::YMM : Operand::XMM) : Operand::ZMM; + opVex(x.copyAndSetKind(kind), &xm0, op, type, code); + } + void opCvt3(const Xmm& x1, const Xmm& x2, const Operand& op, int type, int type64, int type32, uint8 code) + { + if (!(x1.isXMM() && x2.isXMM() && (op.isREG(i32e) || op.isMEM()))) throw Error(ERR_BAD_SIZE_OF_REGISTER); + Xmm x(op.getIdx()); + const Operand *p = op.isREG() ? &x : &op; + opVex(x1, &x2, *p, type | (op.isBit(64) ? type64 : type32), code); + } + const Xmm& cvtIdx0(const Operand& x) const + { + return x.isZMM() ? zm0 : x.isYMM() ? ym0 : xm0; + } + // support (x, x/m, imm), (y, y/m, imm) + void opAVX_X_XM_IMM(const Xmm& x, const Operand& op, int type, int code, int imm8 = NONE) + { + opAVX_X_X_XM(x, cvtIdx0(x), op, type, code, imm8); + } + // QQQ:need to refactor + void opSp1(const Reg& reg, const Operand& op, uint8 pref, uint8 code0, uint8 code1) + { + if (reg.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); + bool is16bit = reg.isREG(16) && (op.isREG(16) || op.isMEM()); + if (!is16bit && !(reg.isREG(i32e) && (op.isREG(reg.getBit()) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); + if (is16bit) db(0x66); + db(pref); opModRM(reg.changeBit(i32e == 32 ? 32 : reg.getBit()), op, op.isREG(), true, code0, code1); + } + void opGather(const Xmm& x1, const Address& addr, const Xmm& x2, int type, uint8 code, int mode) + { + const RegExp& regExp = addr.getRegExp(); + if (!regExp.isVsib(128 | 256)) throw Error(ERR_BAD_VSIB_ADDRESSING); + const int y_vx_y = 0; + const int y_vy_y = 1; +// const int x_vy_x = 2; + const bool isAddrYMM = regExp.getIndex().getBit() == 256; + if (!x1.isXMM() || isAddrYMM || !x2.isXMM()) { + bool isOK = false; + if (mode == y_vx_y) { + isOK = x1.isYMM() && !isAddrYMM && x2.isYMM(); + } else if (mode == y_vy_y) { + isOK = x1.isYMM() && isAddrYMM && x2.isYMM(); + } else { // x_vy_x + isOK = !x1.isYMM() && isAddrYMM && !x2.isYMM(); + } + if (!isOK) throw Error(ERR_BAD_VSIB_ADDRESSING); + } + opAVX_X_X_XM(isAddrYMM ? Ymm(x1.getIdx()) : x1, isAddrYMM ? Ymm(x2.getIdx()) : x2, addr, type, code); + } + enum { + xx_yy_zz = 0, + xx_yx_zy = 1, + xx_xy_yz = 2 + }; + void checkGather2(const Xmm& x1, const Reg& x2, int mode) const + { + if (x1.isXMM() && x2.isXMM()) return; + switch (mode) { + case xx_yy_zz: if ((x1.isYMM() && x2.isYMM()) || (x1.isZMM() && x2.isZMM())) return; + break; + case xx_yx_zy: if ((x1.isYMM() && x2.isXMM()) || (x1.isZMM() && x2.isYMM())) return; + break; + case xx_xy_yz: if ((x1.isXMM() && x2.isYMM()) || (x1.isYMM() && x2.isZMM())) return; + break; + } + throw Error(ERR_BAD_VSIB_ADDRESSING); + } + void opGather2(const Xmm& x, const Address& addr, int type, uint8 code, int mode) + { + if (x.hasZero()) throw Error(ERR_INVALID_ZERO); + checkGather2(x, addr.getRegExp().getIndex(), mode); + opVex(x, 0, addr, type, code); + } + /* + xx_xy_yz ; mode = true + xx_xy_xz ; mode = false + */ + void opVmov(const Operand& op, const Xmm& x, int type, uint8 code, bool mode) + { + if (mode) { + if (!op.isMEM() && !((op.isXMM() && x.isXMM()) || (op.isXMM() && x.isYMM()) || (op.isYMM() && x.isZMM()))) throw Error(ERR_BAD_COMBINATION); + } else { + if (!op.isMEM() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); + } + opVex(x, 0, op, type, code); + } + void opGatherFetch(const Address& addr, const Xmm& x, int type, uint8 code, Operand::Kind kind) + { + if (addr.hasZero()) throw Error(ERR_INVALID_ZERO); + if (addr.getRegExp().getIndex().getKind() != kind) throw Error(ERR_BAD_VSIB_ADDRESSING); + opVex(x, 0, addr, type, code); + } +public: + unsigned int getVersion() const { return VERSION; } + using CodeArray::db; + const Mmx mm0, mm1, mm2, mm3, mm4, mm5, mm6, mm7; + const Xmm xmm0, xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7; + const Ymm ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7; + const Zmm zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + const Xmm &xm0, &xm1, &xm2, &xm3, &xm4, &xm5, &xm6, &xm7; + const Ymm &ym0, &ym1, &ym2, &ym3, &ym4, &ym5, &ym6, &ym7; + const Ymm &zm0, &zm1, &zm2, &zm3, &zm4, &zm5, &zm6, &zm7; + const Reg32 eax, ecx, edx, ebx, esp, ebp, esi, edi; + const Reg16 ax, cx, dx, bx, sp, bp, si, di; + const Reg8 al, cl, dl, bl, ah, ch, dh, bh; + const AddressFrame ptr, byte, word, dword, qword, xword, yword, zword; // xword is same as oword of NASM + const AddressFrame ptr_b, xword_b, yword_b, zword_b; // broadcast such as {1to2}, {1to4}, {1to8}, {1to16}, {b} + const Fpu st0, st1, st2, st3, st4, st5, st6, st7; + const Opmask k0, k1, k2, k3, k4, k5, k6, k7; + const BoundsReg bnd0, bnd1, bnd2, bnd3; + const EvexModifierRounding T_sae, T_rn_sae, T_rd_sae, T_ru_sae, T_rz_sae; // {sae}, {rn-sae}, {rd-sae}, {ru-sae}, {rz-sae} + const EvexModifierZero T_z; // {z} +#ifdef XBYAK64 + const Reg64 rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi, r8, r9, r10, r11, r12, r13, r14, r15; + const Reg32 r8d, r9d, r10d, r11d, r12d, r13d, r14d, r15d; + const Reg16 r8w, r9w, r10w, r11w, r12w, r13w, r14w, r15w; + const Reg8 r8b, r9b, r10b, r11b, r12b, r13b, r14b, r15b; + const Reg8 spl, bpl, sil, dil; + const Xmm xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15; + const Xmm xmm16, xmm17, xmm18, xmm19, xmm20, xmm21, xmm22, xmm23; + const Xmm xmm24, xmm25, xmm26, xmm27, xmm28, xmm29, xmm30, xmm31; + const Ymm ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15; + const Ymm ymm16, ymm17, ymm18, ymm19, ymm20, ymm21, ymm22, ymm23; + const Ymm ymm24, ymm25, ymm26, ymm27, ymm28, ymm29, ymm30, ymm31; + const Zmm zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15; + const Zmm zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23; + const Zmm zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31; + const Xmm &xm8, &xm9, &xm10, &xm11, &xm12, &xm13, &xm14, &xm15; // for my convenience + const Xmm &xm16, &xm17, &xm18, &xm19, &xm20, &xm21, &xm22, &xm23; + const Xmm &xm24, &xm25, &xm26, &xm27, &xm28, &xm29, &xm30, &xm31; + const Ymm &ym8, &ym9, &ym10, &ym11, &ym12, &ym13, &ym14, &ym15; + const Ymm &ym16, &ym17, &ym18, &ym19, &ym20, &ym21, &ym22, &ym23; + const Ymm &ym24, &ym25, &ym26, &ym27, &ym28, &ym29, &ym30, &ym31; + const Zmm &zm8, &zm9, &zm10, &zm11, &zm12, &zm13, &zm14, &zm15; + const Zmm &zm16, &zm17, &zm18, &zm19, &zm20, &zm21, &zm22, &zm23; + const Zmm &zm24, &zm25, &zm26, &zm27, &zm28, &zm29, &zm30, &zm31; + const RegRip rip; +#endif +#ifndef XBYAK_DISABLE_SEGMENT + const Segment es, cs, ss, ds, fs, gs; +#endif + void L(const std::string& label) { labelMgr_.defineSlabel(label); } + void L(Label& label) { labelMgr_.defineClabel(label); } + Label L() { Label label; L(label); return label; } + void inLocalLabel() { labelMgr_.enterLocal(); } + void outLocalLabel() { labelMgr_.leaveLocal(); } + /* + assign src to dst + require + dst : does not used by L() + src : used by L() + */ + void assignL(Label& dst, const Label& src) { labelMgr_.assign(dst, src); } + /* + put address of label to buffer + @note the put size is 4(32-bit), 8(64-bit) + */ + void putL(std::string label) { putL_inner(label); } + void putL(const Label& label) { putL_inner(label); } + + void jmp(const Operand& op) { opR_ModM(op, BIT, 4, 0xFF, NONE, NONE, true); } + void jmp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } + void jmp(const char *label, LabelType type = T_AUTO) { jmp(std::string(label), type); } + void jmp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } + void jmp(const void *addr, LabelType type = T_AUTO) { opJmpAbs(addr, type, 0xEB, 0xE9); } + + void call(const Operand& op) { opR_ModM(op, 16 | i32e, 2, 0xFF, NONE, NONE, true); } + // call(string label), not const std::string& + void call(std::string label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } + void call(const char *label) { call(std::string(label)); } + void call(const Label& label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } + // call(function pointer) +#ifdef XBYAK_VARIADIC_TEMPLATE + template + void call(Ret(*func)(Params...)) { call(reinterpret_cast(func)); } +#endif + void call(const void *addr) { opJmpAbs(addr, T_NEAR, 0, 0xE8); } + + void test(const Operand& op, const Reg& reg) + { + opModRM(reg, op, op.isREG() && (op.getKind() == reg.getKind()), op.isMEM(), 0x84); + } + void test(const Operand& op, uint32 imm) + { + verifyMemHasSize(op); + int immSize = (std::min)(op.getBit() / 8, 4U); + if (op.isREG() && op.getIdx() == 0) { // al, ax, eax + rex(op); + db(0xA8 | (op.isBit(8) ? 0 : 1)); + } else { + opR_ModM(op, 0, 0, 0xF6, NONE, NONE, false, immSize); + } + db(imm, immSize); + } + void imul(const Reg& reg, const Operand& op) + { + opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x0F, 0xAF); + } + void imul(const Reg& reg, const Operand& op, int imm) + { + int s = inner::IsInDisp8(imm) ? 1 : 0; + int immSize = s ? 1 : reg.isREG(16) ? 2 : 4; + opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x69 | (s << 1), NONE, NONE, immSize); + db(imm, immSize); + } + void push(const Operand& op) { opPushPop(op, 0xFF, 6, 0x50); } + void pop(const Operand& op) { opPushPop(op, 0x8F, 0, 0x58); } + void push(const AddressFrame& af, uint32 imm) + { + if (af.bit_ == 8 && inner::IsInDisp8(imm)) { + db(0x6A); db(imm); + } else if (af.bit_ == 16 && isInDisp16(imm)) { + db(0x66); db(0x68); dw(imm); + } else { + db(0x68); dd(imm); + } + } + /* use "push(word, 4)" if you want "push word 4" */ + void push(uint32 imm) + { + if (inner::IsInDisp8(imm)) { + push(byte, imm); + } else { + push(dword, imm); + } + } + void mov(const Operand& reg1, const Operand& reg2) + { + const Reg *reg = 0; + const Address *addr = 0; + uint8 code = 0; + if (reg1.isREG() && reg1.getIdx() == 0 && reg2.isMEM()) { // mov eax|ax|al, [disp] + reg = ®1.getReg(); + addr= ®2.getAddress(); + code = 0xA0; + } else + if (reg1.isMEM() && reg2.isREG() && reg2.getIdx() == 0) { // mov [disp], eax|ax|al + reg = ®2.getReg(); + addr= ®1.getAddress(); + code = 0xA2; + } +#ifdef XBYAK64 + if (addr && addr->is64bitDisp()) { + if (code) { + rex(*reg); + db(reg1.isREG(8) ? 0xA0 : reg1.isREG() ? 0xA1 : reg2.isREG(8) ? 0xA2 : 0xA3); + db(addr->getDisp(), 8); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } else +#else + if (code && addr->isOnlyDisp()) { + rex(*reg, *addr); + db(code | (reg->isBit(8) ? 0 : 1)); + dd(static_cast(addr->getDisp())); + } else +#endif + { + opRM_RM(reg1, reg2, 0x88); + } + } + void mov(const Operand& op, size_t imm) + { + if (op.isREG()) { + const int size = mov_imm(op.getReg(), imm); + db(imm, size); + } else if (op.isMEM()) { + verifyMemHasSize(op); + int immSize = op.getBit() / 8; + if (immSize <= 4) { + sint64 s = sint64(imm) >> (immSize * 8); + if (s != 0 && s != -1) throw Error(ERR_IMM_IS_TOO_BIG); + } else { + if (!inner::IsInInt32(imm)) throw Error(ERR_IMM_IS_TOO_BIG); + immSize = 4; + } + opModM(op.getAddress(), Reg(0, Operand::REG, op.getBit()), 0xC6, NONE, NONE, immSize); + db(static_cast(imm), immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void mov(const NativeReg& reg, const char *label) // can't use std::string + { + if (label == 0) { + mov(static_cast(reg), 0); // call imm + return; + } + mov_imm(reg, dummyAddr); + putL(label); + } + void mov(const NativeReg& reg, const Label& label) + { + mov_imm(reg, dummyAddr); + putL(label); + } + void xchg(const Operand& op1, const Operand& op2) + { + const Operand *p1 = &op1, *p2 = &op2; + if (p1->isMEM() || (p2->isREG(16 | i32e) && p2->getIdx() == 0)) { + p1 = &op2; p2 = &op1; + } + if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION); + if (p2->isREG() && (p1->isREG(16 | i32e) && p1->getIdx() == 0) +#ifdef XBYAK64 + && (p2->getIdx() != 0 || !p1->isREG(32)) +#endif + ) { + rex(*p2, *p1); db(0x90 | (p2->getIdx() & 7)); + return; + } + opModRM(*p1, *p2, (p1->isREG() && p2->isREG() && (p1->getBit() == p2->getBit())), p2->isMEM(), 0x86 | (p1->isBit(8) ? 0 : 1)); + } + +#ifndef XBYAK_DISABLE_SEGMENT + void push(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x06); break; + case Segment::cs: db(0x0E); break; + case Segment::ss: db(0x16); break; + case Segment::ds: db(0x1E); break; + case Segment::fs: db(0x0F); db(0xA0); break; + case Segment::gs: db(0x0F); db(0xA8); break; + default: + assert(0); + } + } + void pop(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x07); break; + case Segment::cs: throw Error(ERR_BAD_COMBINATION); + case Segment::ss: db(0x17); break; + case Segment::ds: db(0x1F); break; + case Segment::fs: db(0x0F); db(0xA1); break; + case Segment::gs: db(0x0F); db(0xA9); break; + default: + assert(0); + } + } + void putSeg(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x2E); break; + case Segment::cs: db(0x36); break; + case Segment::ss: db(0x3E); break; + case Segment::ds: db(0x26); break; + case Segment::fs: db(0x64); break; + case Segment::gs: db(0x65); break; + default: + assert(0); + } + } + void mov(const Operand& op, const Segment& seg) + { + opModRM(Reg8(seg.getIdx()), op, op.isREG(16|i32e), op.isMEM(), 0x8C); + } + void mov(const Segment& seg, const Operand& op) + { + opModRM(Reg8(seg.getIdx()), op.isREG(16|i32e) ? static_cast(op.getReg().cvt32()) : op, op.isREG(16|i32e), op.isMEM(), 0x8E); + } +#endif + + enum { NONE = 256 }; + // constructor + CodeGenerator(size_t maxSize = DEFAULT_MAX_CODE_SIZE, void *userPtr = 0, Allocator *allocator = 0) + : CodeArray(maxSize, userPtr, allocator) + , mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7) + , xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7) + , ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7) + , zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7) + // for my convenience + , xm0(xmm0), xm1(xmm1), xm2(xmm2), xm3(xmm3), xm4(xmm4), xm5(xmm5), xm6(xmm6), xm7(xmm7) + , ym0(ymm0), ym1(ymm1), ym2(ymm2), ym3(ymm3), ym4(ymm4), ym5(ymm5), ym6(ymm6), ym7(ymm7) + , zm0(zmm0), zm1(zmm1), zm2(zmm2), zm3(zmm3), zm4(zmm4), zm5(zmm5), zm6(zmm6), zm7(zmm7) + + , eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI) + , ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI) + , al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH) + , ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512) + , ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true) + , st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7) + , k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7) + , bnd0(0), bnd1(1), bnd2(2), bnd3(3) + , T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE) + , T_z() +#ifdef XBYAK64 + , rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15) + , r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15) + , r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15) + , r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15) + , spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true) + , xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15) + , xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23) + , xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31) + , ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15) + , ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23) + , ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31) + , zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15) + , zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23) + , zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31) + // for my convenience + , xm8(xmm8), xm9(xmm9), xm10(xmm10), xm11(xmm11), xm12(xmm12), xm13(xmm13), xm14(xmm14), xm15(xmm15) + , xm16(xmm16), xm17(xmm17), xm18(xmm18), xm19(xmm19), xm20(xmm20), xm21(xmm21), xm22(xmm22), xm23(xmm23) + , xm24(xmm24), xm25(xmm25), xm26(xmm26), xm27(xmm27), xm28(xmm28), xm29(xmm29), xm30(xmm30), xm31(xmm31) + , ym8(ymm8), ym9(ymm9), ym10(ymm10), ym11(ymm11), ym12(ymm12), ym13(ymm13), ym14(ymm14), ym15(ymm15) + , ym16(ymm16), ym17(ymm17), ym18(ymm18), ym19(ymm19), ym20(ymm20), ym21(ymm21), ym22(ymm22), ym23(ymm23) + , ym24(ymm24), ym25(ymm25), ym26(ymm26), ym27(ymm27), ym28(ymm28), ym29(ymm29), ym30(ymm30), ym31(ymm31) + , zm8(zmm8), zm9(zmm9), zm10(zmm10), zm11(zmm11), zm12(zmm12), zm13(zmm13), zm14(zmm14), zm15(zmm15) + , zm16(zmm16), zm17(zmm17), zm18(zmm18), zm19(zmm19), zm20(zmm20), zm21(zmm21), zm22(zmm22), zm23(zmm23) + , zm24(zmm24), zm25(zmm25), zm26(zmm26), zm27(zmm27), zm28(zmm28), zm29(zmm29), zm30(zmm30), zm31(zmm31) + , rip() +#endif +#ifndef XBYAK_DISABLE_SEGMENT + , es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs) +#endif + { + labelMgr_.set(this); + } + void reset() + { + resetSize(); + labelMgr_.reset(); + labelMgr_.set(this); + } + bool hasUndefinedLabel() const { return labelMgr_.hasUndefSlabel() || labelMgr_.hasUndefClabel(); } + /* + MUST call ready() to complete generating code if you use AutoGrow mode. + It is not necessary for the other mode if hasUndefinedLabel() is true. + */ + void ready(ProtectMode mode = PROTECT_RWE) + { + if (hasUndefinedLabel()) throw Error(ERR_LABEL_IS_NOT_FOUND); + if (isAutoGrow()) { + calcJmpAddress(); + if (useProtect()) setProtectMode(mode); + } + } + // set read/exec + void readyRE() { return ready(PROTECT_RE); } +#ifdef XBYAK_TEST + void dump(bool doClear = true) + { + CodeArray::dump(); + if (doClear) size_ = 0; + } +#endif + +#ifdef XBYAK_UNDEF_JNL + #undef jnl +#endif + + /* + use single byte nop if useMultiByteNop = false + */ + void nop(size_t size = 1, bool useMultiByteNop = true) + { + if (!useMultiByteNop) { + for (size_t i = 0; i < size; i++) { + db(0x90); + } + return; + } + /* + Intel Architectures Software Developer's Manual Volume 2 + recommended multi-byte sequence of NOP instruction + AMD and Intel seem to agree on the same sequences for up to 9 bytes: + https://support.amd.com/TechDocs/55723_SOG_Fam_17h_Processors_3.00.pdf + */ + static const uint8 nopTbl[9][9] = { + {0x90}, + {0x66, 0x90}, + {0x0F, 0x1F, 0x00}, + {0x0F, 0x1F, 0x40, 0x00}, + {0x0F, 0x1F, 0x44, 0x00, 0x00}, + {0x66, 0x0F, 0x1F, 0x44, 0x00, 0x00}, + {0x0F, 0x1F, 0x80, 0x00, 0x00, 0x00, 0x00}, + {0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x66, 0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, + }; + const size_t n = sizeof(nopTbl) / sizeof(nopTbl[0]); + while (size > 0) { + size_t len = (std::min)(n, size); + const uint8 *seq = nopTbl[len - 1]; + db(seq, len); + size -= len; + } + } + +#ifndef XBYAK_DONT_READ_LIST +#include "xbyak_mnemonic.h" + /* + use single byte nop if useMultiByteNop = false + */ + void align(size_t x = 16, bool useMultiByteNop = true) + { + if (x == 1) return; + if (x < 1 || (x & (x - 1))) throw Error(ERR_BAD_ALIGN); + if (isAutoGrow() && x > inner::ALIGN_PAGE_SIZE) fprintf(stderr, "warning:autoGrow mode does not support %d align\n", (int)x); + size_t remain = size_t(getCurr()) % x; + if (remain) { + nop(x - remain, useMultiByteNop); + } + } +#endif +}; + +namespace util { +static const Mmx mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7); +static const Xmm xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7); +static const Ymm ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7); +static const Zmm zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7); +static const Reg32 eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI); +static const Reg16 ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI); +static const Reg8 al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH); +static const AddressFrame ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512); +static const AddressFrame ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true); +static const Fpu st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7); +static const Opmask k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7); +static const BoundsReg bnd0(0), bnd1(1), bnd2(2), bnd3(3); +static const EvexModifierRounding T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE); +static const EvexModifierZero T_z; +#ifdef XBYAK64 +static const Reg64 rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15); +static const Reg32 r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15); +static const Reg16 r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15); +static const Reg8 r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15), spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true); +static const Xmm xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15); +static const Xmm xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23); +static const Xmm xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31); +static const Ymm ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15); +static const Ymm ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23); +static const Ymm ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31); +static const Zmm zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15); +static const Zmm zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23); +static const Zmm zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31); +static const RegRip rip; +#endif +#ifndef XBYAK_DISABLE_SEGMENT +static const Segment es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs); +#endif +} // util + +#ifdef _MSC_VER + #pragma warning(pop) +#endif + +} // end of namespace + +#endif // XBYAK_XBYAK_H_ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h new file mode 100644 index 0000000000..a22e5224c3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h @@ -0,0 +1,303 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +enum { + B00000000= 0, + B00000001= 1, + B00000010= 2, + B00000011= 3, + B00000100= 4, + B00000101= 5, + B00000110= 6, + B00000111= 7, + B00001000= 8, + B00001001= 9, + B00001010= 10, + B00001011= 11, + B00001100= 12, + B00001101= 13, + B00001110= 14, + B00001111= 15, + B00010000= 16, + B00010001= 17, + B00010010= 18, + B00010011= 19, + B00010100= 20, + B00010101= 21, + B00010110= 22, + B00010111= 23, + B00011000= 24, + B00011001= 25, + B00011010= 26, + B00011011= 27, + B00011100= 28, + B00011101= 29, + B00011110= 30, + B00011111= 31, + B00100000= 32, + B00100001= 33, + B00100010= 34, + B00100011= 35, + B00100100= 36, + B00100101= 37, + B00100110= 38, + B00100111= 39, + B00101000= 40, + B00101001= 41, + B00101010= 42, + B00101011= 43, + B00101100= 44, + B00101101= 45, + B00101110= 46, + B00101111= 47, + B00110000= 48, + B00110001= 49, + B00110010= 50, + B00110011= 51, + B00110100= 52, + B00110101= 53, + B00110110= 54, + B00110111= 55, + B00111000= 56, + B00111001= 57, + B00111010= 58, + B00111011= 59, + B00111100= 60, + B00111101= 61, + B00111110= 62, + B00111111= 63, + B01000000= 64, + B01000001= 65, + B01000010= 66, + B01000011= 67, + B01000100= 68, + B01000101= 69, + B01000110= 70, + B01000111= 71, + B01001000= 72, + B01001001= 73, + B01001010= 74, + B01001011= 75, + B01001100= 76, + B01001101= 77, + B01001110= 78, + B01001111= 79, + B01010000= 80, + B01010001= 81, + B01010010= 82, + B01010011= 83, + B01010100= 84, + B01010101= 85, + B01010110= 86, + B01010111= 87, + B01011000= 88, + B01011001= 89, + B01011010= 90, + B01011011= 91, + B01011100= 92, + B01011101= 93, + B01011110= 94, + B01011111= 95, + B01100000= 96, + B01100001= 97, + B01100010= 98, + B01100011= 99, + B01100100= 100, + B01100101= 101, + B01100110= 102, + B01100111= 103, + B01101000= 104, + B01101001= 105, + B01101010= 106, + B01101011= 107, + B01101100= 108, + B01101101= 109, + B01101110= 110, + B01101111= 111, + B01110000= 112, + B01110001= 113, + B01110010= 114, + B01110011= 115, + B01110100= 116, + B01110101= 117, + B01110110= 118, + B01110111= 119, + B01111000= 120, + B01111001= 121, + B01111010= 122, + B01111011= 123, + B01111100= 124, + B01111101= 125, + B01111110= 126, + B01111111= 127, + B10000000= 128, + B10000001= 129, + B10000010= 130, + B10000011= 131, + B10000100= 132, + B10000101= 133, + B10000110= 134, + B10000111= 135, + B10001000= 136, + B10001001= 137, + B10001010= 138, + B10001011= 139, + B10001100= 140, + B10001101= 141, + B10001110= 142, + B10001111= 143, + B10010000= 144, + B10010001= 145, + B10010010= 146, + B10010011= 147, + B10010100= 148, + B10010101= 149, + B10010110= 150, + B10010111= 151, + B10011000= 152, + B10011001= 153, + B10011010= 154, + B10011011= 155, + B10011100= 156, + B10011101= 157, + B10011110= 158, + B10011111= 159, + B10100000= 160, + B10100001= 161, + B10100010= 162, + B10100011= 163, + B10100100= 164, + B10100101= 165, + B10100110= 166, + B10100111= 167, + B10101000= 168, + B10101001= 169, + B10101010= 170, + B10101011= 171, + B10101100= 172, + B10101101= 173, + B10101110= 174, + B10101111= 175, + B10110000= 176, + B10110001= 177, + B10110010= 178, + B10110011= 179, + B10110100= 180, + B10110101= 181, + B10110110= 182, + B10110111= 183, + B10111000= 184, + B10111001= 185, + B10111010= 186, + B10111011= 187, + B10111100= 188, + B10111101= 189, + B10111110= 190, + B10111111= 191, + B11000000= 192, + B11000001= 193, + B11000010= 194, + B11000011= 195, + B11000100= 196, + B11000101= 197, + B11000110= 198, + B11000111= 199, + B11001000= 200, + B11001001= 201, + B11001010= 202, + B11001011= 203, + B11001100= 204, + B11001101= 205, + B11001110= 206, + B11001111= 207, + B11010000= 208, + B11010001= 209, + B11010010= 210, + B11010011= 211, + B11010100= 212, + B11010101= 213, + B11010110= 214, + B11010111= 215, + B11011000= 216, + B11011001= 217, + B11011010= 218, + B11011011= 219, + B11011100= 220, + B11011101= 221, + B11011110= 222, + B11011111= 223, + B11100000= 224, + B11100001= 225, + B11100010= 226, + B11100011= 227, + B11100100= 228, + B11100101= 229, + B11100110= 230, + B11100111= 231, + B11101000= 232, + B11101001= 233, + B11101010= 234, + B11101011= 235, + B11101100= 236, + B11101101= 237, + B11101110= 238, + B11101111= 239, + B11110000= 240, + B11110001= 241, + B11110010= 242, + B11110011= 243, + B11110100= 244, + B11110101= 245, + B11110110= 246, + B11110111= 247, + B11111000= 248, + B11111001= 249, + B11111010= 250, + B11111011= 251, + B11111100= 252, + B11111101= 253, + B11111110= 254, + B11111111= 255 +}; diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h new file mode 100644 index 0000000000..28d2d222f9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h @@ -0,0 +1,2017 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +const char *getVersionString() const { return "5.76"; } +void adc(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x10, 2); } +void adc(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x10); } +void adcx(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0x66, isREG32_REG32orMEM, NONE, 0x38); } +void add(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x00, 0); } +void add(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x00); } +void addpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x66, isXMM_XMMorMEM); } +void addps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x100, isXMM_XMMorMEM); } +void addsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF2, isXMM_XMMorMEM); } +void addss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF3, isXMM_XMMorMEM); } +void addsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0x66, isXMM_XMMorMEM); } +void addsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0xF2, isXMM_XMMorMEM); } +void adox(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0xF3, isREG32_REG32orMEM, NONE, 0x38); } +void aesdec(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDE, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesdeclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesenc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDC, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesenclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDD, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesimc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDB, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aeskeygenassist(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void and_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x20, 4); } +void and_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x20); } +void andn(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_0F38, 0xf2, true); } +void andnpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x66, isXMM_XMMorMEM); } +void andnps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x100, isXMM_XMMorMEM); } +void andpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x66, isXMM_XMMorMEM); } +void andps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x100, isXMM_XMMorMEM); } +void bextr(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf7, false); } +void blendpd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0D, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void blendps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0C, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void blendvpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void blendvps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void blsi(const Reg32e& r, const Operand& op) { opGpr(Reg32e(3, r.getBit()), op, r, T_0F38, 0xf3, false); } +void blsmsk(const Reg32e& r, const Operand& op) { opGpr(Reg32e(2, r.getBit()), op, r, T_0F38, 0xf3, false); } +void blsr(const Reg32e& r, const Operand& op) { opGpr(Reg32e(1, r.getBit()), op, r, T_0F38, 0xf3, false); } +void bnd() { db(0xF2); } +void bndcl(const BoundsReg& bnd, const Operand& op) { db(0xF3); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); } +void bndcn(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1B, NONE, !op.isMEM()); } +void bndcu(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); } +void bndldx(const BoundsReg& bnd, const Address& addr) { opMIB(addr, bnd, 0x0F, 0x1A); } +void bndmk(const BoundsReg& bnd, const Address& addr) { db(0xF3); opModM(addr, bnd, 0x0F, 0x1B); } +void bndmov(const Address& addr, const BoundsReg& bnd) { db(0x66); opModM(addr, bnd, 0x0F, 0x1B); } +void bndmov(const BoundsReg& bnd, const Operand& op) { db(0x66); opModRM(bnd, op, op.isBNDREG(), op.isMEM(), 0x0F, 0x1A); } +void bndstx(const Address& addr, const BoundsReg& bnd) { opMIB(addr, bnd, 0x0F, 0x1B); } +void bsf(const Reg®, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBC); } +void bsr(const Reg®, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBD); } +void bswap(const Reg32e& reg) { opModR(Reg32(1), reg, 0x0F); } +void bt(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xA3); } +void bt(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 4, 0x0f, 0xba, NONE, false, 1); db(imm); } +void btc(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xBB); } +void btc(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 7, 0x0f, 0xba, NONE, false, 1); db(imm); } +void btr(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xB3); } +void btr(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 6, 0x0f, 0xba, NONE, false, 1); db(imm); } +void bts(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xAB); } +void bts(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 5, 0x0f, 0xba, NONE, false, 1); db(imm); } +void bzhi(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf5, false); } +void cbw() { db(0x66); db(0x98); } +void cdq() { db(0x99); } +void clc() { db(0xF8); } +void cld() { db(0xFC); } +void clflush(const Address& addr) { opModM(addr, Reg32(7), 0x0F, 0xAE); } +void cli() { db(0xFA); } +void cmc() { db(0xF5); } +void cmova(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524 +void cmovae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmovbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524 +void cmovc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmove(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524 +void cmovg(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524 +void cmovge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524 +void cmovl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524 +void cmovle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524 +void cmovna(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524 +void cmovnae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmovnb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovnbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524 +void cmovnc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovne(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524 +void cmovng(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524 +void cmovnge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524 +void cmovnl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524 +void cmovnle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524 +void cmovno(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 1); }//-V524 +void cmovnp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524 +void cmovns(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 9); }//-V524 +void cmovnz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524 +void cmovo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 0); }//-V524 +void cmovp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524 +void cmovpe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524 +void cmovpo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524 +void cmovs(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 8); }//-V524 +void cmovz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524 +void cmp(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x38, 7); } +void cmp(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x38); } +void cmpeqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 0); } +void cmpeqps(const Xmm& x, const Operand& op) { cmpps(x, op, 0); } +void cmpeqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 0); } +void cmpeqss(const Xmm& x, const Operand& op) { cmpss(x, op, 0); } +void cmplepd(const Xmm& x, const Operand& op) { cmppd(x, op, 2); } +void cmpleps(const Xmm& x, const Operand& op) { cmpps(x, op, 2); } +void cmplesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 2); } +void cmpless(const Xmm& x, const Operand& op) { cmpss(x, op, 2); } +void cmpltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 1); } +void cmpltps(const Xmm& x, const Operand& op) { cmpps(x, op, 1); } +void cmpltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 1); } +void cmpltss(const Xmm& x, const Operand& op) { cmpss(x, op, 1); } +void cmpneqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 4); } +void cmpneqps(const Xmm& x, const Operand& op) { cmpps(x, op, 4); } +void cmpneqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 4); } +void cmpneqss(const Xmm& x, const Operand& op) { cmpss(x, op, 4); } +void cmpnlepd(const Xmm& x, const Operand& op) { cmppd(x, op, 6); } +void cmpnleps(const Xmm& x, const Operand& op) { cmpps(x, op, 6); } +void cmpnlesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 6); } +void cmpnless(const Xmm& x, const Operand& op) { cmpss(x, op, 6); } +void cmpnltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 5); } +void cmpnltps(const Xmm& x, const Operand& op) { cmpps(x, op, 5); } +void cmpnltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 5); } +void cmpnltss(const Xmm& x, const Operand& op) { cmpss(x, op, 5); } +void cmpordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 7); } +void cmpordps(const Xmm& x, const Operand& op) { cmpps(x, op, 7); } +void cmpordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 7); } +void cmpordss(const Xmm& x, const Operand& op) { cmpss(x, op, 7); } +void cmppd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x66, isXMM_XMMorMEM, imm8); } +void cmpps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x100, isXMM_XMMorMEM, imm8); } +void cmpsb() { db(0xA6); } +void cmpsd() { db(0xA7); } +void cmpsd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF2, isXMM_XMMorMEM, imm8); } +void cmpss(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF3, isXMM_XMMorMEM, imm8); } +void cmpsw() { db(0x66); db(0xA7); } +void cmpunordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 3); } +void cmpunordps(const Xmm& x, const Operand& op) { cmpps(x, op, 3); } +void cmpunordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 3); } +void cmpunordss(const Xmm& x, const Operand& op) { cmpss(x, op, 3); } +void cmpxchg(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xB0 | (reg.isBit(8) ? 0 : 1)); } +void cmpxchg8b(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0xC7); } +void comisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x66, isXMM_XMMorMEM); } +void comiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x100, isXMM_XMMorMEM); } +void cpuid() { db(0x0F); db(0xA2); } +void crc32(const Reg32e& reg, const Operand& op) { if (reg.isBit(32) && op.isBit(16)) db(0x66); db(0xF2); opModRM(reg, op, op.isREG(), op.isMEM(), 0x0F, 0x38, 0xF0 | (op.isBit(8) ? 0 : 1)); } +void cvtdq2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF3, isXMM_XMMorMEM); } +void cvtdq2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x100, isXMM_XMMorMEM); } +void cvtpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF2, isXMM_XMMorMEM); } +void cvtpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x66, isMMX_XMMorMEM); } +void cvtpd2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x66, isXMM_XMMorMEM); } +void cvtpi2pd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x66, isXMM_MMXorMEM); } +void cvtpi2ps(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x100, isXMM_MMXorMEM); } +void cvtps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x66, isXMM_XMMorMEM); } +void cvtps2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x100, isXMM_XMMorMEM); } +void cvtps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x100, isMMX_XMMorMEM); } +void cvtsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF2, isREG32_XMMorMEM); } +void cvtsd2ss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF2, isXMM_XMMorMEM); } +void cvtsi2sd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF2, isXMM_REG32orMEM); } +void cvtsi2ss(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF3, isXMM_REG32orMEM); } +void cvtss2sd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF3, isXMM_XMMorMEM); } +void cvtss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF3, isREG32_XMMorMEM); } +void cvttpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0x66, isXMM_XMMorMEM); } +void cvttpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x66, isMMX_XMMorMEM); } +void cvttps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0xF3, isXMM_XMMorMEM); } +void cvttps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x100, isMMX_XMMorMEM); } +void cvttsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF2, isREG32_XMMorMEM); } +void cvttss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF3, isREG32_XMMorMEM); } +void cwd() { db(0x66); db(0x99); } +void cwde() { db(0x98); } +void dec(const Operand& op) { opIncDec(op, 0x48, 1); } +void div(const Operand& op) { opR_ModM(op, 0, 6, 0xF6); } +void divpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x66, isXMM_XMMorMEM); } +void divps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x100, isXMM_XMMorMEM); } +void divsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF2, isXMM_XMMorMEM); } +void divss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF3, isXMM_XMMorMEM); } +void dppd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void dpps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void emms() { db(0x0F); db(0x77); } +void extractps(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x17, imm); } +void f2xm1() { db(0xD9); db(0xF0); } +void fabs() { db(0xD9); db(0xE1); } +void fadd(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 0, 0); } +void fadd(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C0, 0xDCC0); } +void fadd(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C0, 0xDCC0); } +void faddp() { db(0xDE); db(0xC1); } +void faddp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC0); } +void faddp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC0); } +void fchs() { db(0xD9); db(0xE0); } +void fcmovb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC0, 0x00C0); } +void fcmovb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC0, 0x00C0); } +void fcmovbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD0, 0x00D0); } +void fcmovbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD0, 0x00D0); } +void fcmove(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC8, 0x00C8); } +void fcmove(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC8, 0x00C8); } +void fcmovnb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC0, 0x00C0); } +void fcmovnb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC0, 0x00C0); } +void fcmovnbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD0, 0x00D0); } +void fcmovnbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD0, 0x00D0); } +void fcmovne(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC8, 0x00C8); } +void fcmovne(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC8, 0x00C8); } +void fcmovnu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD8, 0x00D8); } +void fcmovnu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD8, 0x00D8); } +void fcmovu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD8, 0x00D8); } +void fcmovu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD8, 0x00D8); } +void fcom() { db(0xD8); db(0xD1); } +void fcom(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 2, 0); } +void fcom(const Fpu& reg) { opFpu(reg, 0xD8, 0xD0); } +void fcomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBF0, 0x00F0); } +void fcomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBF0, 0x00F0); } +void fcomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFF0, 0x00F0); } +void fcomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFF0, 0x00F0); } +void fcomp() { db(0xD8); db(0xD9); } +void fcomp(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 3, 0); } +void fcomp(const Fpu& reg) { opFpu(reg, 0xD8, 0xD8); } +void fcompp() { db(0xDE); db(0xD9); } +void fcos() { db(0xD9); db(0xFF); } +void fdecstp() { db(0xD9); db(0xF6); } +void fdiv(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 6, 0); } +void fdiv(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F0, 0xDCF8); } +void fdiv(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F0, 0xDCF8); } +void fdivp() { db(0xDE); db(0xF9); } +void fdivp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF8); } +void fdivp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF8); } +void fdivr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 7, 0); } +void fdivr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F8, 0xDCF0); } +void fdivr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F8, 0xDCF0); } +void fdivrp() { db(0xDE); db(0xF1); } +void fdivrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF0); } +void fdivrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF0); } +void ffree(const Fpu& reg) { opFpu(reg, 0xDD, 0xC0); } +void fiadd(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 0, 0); } +void ficom(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 2, 0); } +void ficomp(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 3, 0); } +void fidiv(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 6, 0); } +void fidivr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 7, 0); } +void fild(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 0, 5); } +void fimul(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 1, 0); } +void fincstp() { db(0xD9); db(0xF7); } +void finit() { db(0x9B); db(0xDB); db(0xE3); } +void fist(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0x00, 2, 0); } +void fistp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 3, 7); } +void fisttp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDD, 1, 0); } +void fisub(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 4, 0); } +void fisubr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 5, 0); } +void fld(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 0, 0); } +void fld(const Fpu& reg) { opFpu(reg, 0xD9, 0xC0); } +void fld1() { db(0xD9); db(0xE8); } +void fldcw(const Address& addr) { opModM(addr, Reg32(5), 0xD9, 0x100); } +void fldl2e() { db(0xD9); db(0xEA); } +void fldl2t() { db(0xD9); db(0xE9); } +void fldlg2() { db(0xD9); db(0xEC); } +void fldln2() { db(0xD9); db(0xED); } +void fldpi() { db(0xD9); db(0xEB); } +void fldz() { db(0xD9); db(0xEE); } +void fmul(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 1, 0); } +void fmul(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C8, 0xDCC8); } +void fmul(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C8, 0xDCC8); } +void fmulp() { db(0xDE); db(0xC9); } +void fmulp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC8); } +void fmulp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC8); } +void fninit() { db(0xDB); db(0xE3); } +void fnop() { db(0xD9); db(0xD0); } +void fpatan() { db(0xD9); db(0xF3); } +void fprem() { db(0xD9); db(0xF8); } +void fprem1() { db(0xD9); db(0xF5); } +void fptan() { db(0xD9); db(0xF2); } +void frndint() { db(0xD9); db(0xFC); } +void fscale() { db(0xD9); db(0xFD); } +void fsin() { db(0xD9); db(0xFE); } +void fsincos() { db(0xD9); db(0xFB); } +void fsqrt() { db(0xD9); db(0xFA); } +void fst(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 2, 0); } +void fst(const Fpu& reg) { opFpu(reg, 0xDD, 0xD0); } +void fstcw(const Address& addr) { db(0x9B); opModM(addr, Reg32(7), 0xD9, NONE); } +void fstp(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 3, 0); } +void fstp(const Fpu& reg) { opFpu(reg, 0xDD, 0xD8); } +void fsub(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 4, 0); } +void fsub(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E0, 0xDCE8); } +void fsub(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E0, 0xDCE8); } +void fsubp() { db(0xDE); db(0xE9); } +void fsubp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE8); } +void fsubp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE8); } +void fsubr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 5, 0); } +void fsubr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E8, 0xDCE0); } +void fsubr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E8, 0xDCE0); } +void fsubrp() { db(0xDE); db(0xE1); } +void fsubrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE0); } +void fsubrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE0); } +void ftst() { db(0xD9); db(0xE4); } +void fucom() { db(0xDD); db(0xE1); } +void fucom(const Fpu& reg) { opFpu(reg, 0xDD, 0xE0); } +void fucomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBE8, 0x00E8); } +void fucomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBE8, 0x00E8); } +void fucomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFE8, 0x00E8); } +void fucomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFE8, 0x00E8); } +void fucomp() { db(0xDD); db(0xE9); } +void fucomp(const Fpu& reg) { opFpu(reg, 0xDD, 0xE8); } +void fucompp() { db(0xDA); db(0xE9); } +void fwait() { db(0x9B); } +void fxam() { db(0xD9); db(0xE5); } +void fxch() { db(0xD9); db(0xC9); } +void fxch(const Fpu& reg) { opFpu(reg, 0xD9, 0xC8); } +void fxtract() { db(0xD9); db(0xF4); } +void fyl2x() { db(0xD9); db(0xF1); } +void fyl2xp1() { db(0xD9); db(0xF9); } +void gf2p8affineinvqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void gf2p8affineqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCE, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void gf2p8mulb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void haddpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0x66, isXMM_XMMorMEM); } +void haddps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0xF2, isXMM_XMMorMEM); } +void hsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0x66, isXMM_XMMorMEM); } +void hsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0xF2, isXMM_XMMorMEM); } +void idiv(const Operand& op) { opR_ModM(op, 0, 7, 0xF6); } +void imul(const Operand& op) { opR_ModM(op, 0, 5, 0xF6); } +void inc(const Operand& op) { opIncDec(op, 0x40, 0); } +void insertps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void ja(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void ja(const char *label, LabelType type = T_AUTO) { ja(std::string(label), type); }//-V524 +void ja(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 +void ja(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jae(const char *label, LabelType type = T_AUTO) { jae(std::string(label), type); }//-V524 +void jae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jb(const char *label, LabelType type = T_AUTO) { jb(std::string(label), type); }//-V524 +void jb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jbe(const char *label, LabelType type = T_AUTO) { jbe(std::string(label), type); }//-V524 +void jbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 +void jbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jc(const char *label, LabelType type = T_AUTO) { jc(std::string(label), type); }//-V524 +void jc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void je(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void je(const char *label, LabelType type = T_AUTO) { je(std::string(label), type); }//-V524 +void je(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 +void je(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void jg(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jg(const char *label, LabelType type = T_AUTO) { jg(std::string(label), type); }//-V524 +void jg(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 +void jg(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jge(const char *label, LabelType type = T_AUTO) { jge(std::string(label), type); }//-V524 +void jge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 +void jge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jl(const char *label, LabelType type = T_AUTO) { jl(std::string(label), type); }//-V524 +void jl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 +void jl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jle(const char *label, LabelType type = T_AUTO) { jle(std::string(label), type); }//-V524 +void jle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 +void jle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jna(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jna(const char *label, LabelType type = T_AUTO) { jna(std::string(label), type); }//-V524 +void jna(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 +void jna(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jnae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jnae(const char *label, LabelType type = T_AUTO) { jnae(std::string(label), type); }//-V524 +void jnae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jnae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jnb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnb(const char *label, LabelType type = T_AUTO) { jnb(std::string(label), type); }//-V524 +void jnb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jnb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jnbe(const char *label, LabelType type = T_AUTO) { jnbe(std::string(label), type); }//-V524 +void jnbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 +void jnbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jnc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnc(const char *label, LabelType type = T_AUTO) { jnc(std::string(label), type); }//-V524 +void jnc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jnc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jne(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jne(const char *label, LabelType type = T_AUTO) { jne(std::string(label), type); }//-V524 +void jne(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 +void jne(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jng(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jng(const char *label, LabelType type = T_AUTO) { jng(std::string(label), type); }//-V524 +void jng(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 +void jng(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jnge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jnge(const char *label, LabelType type = T_AUTO) { jnge(std::string(label), type); }//-V524 +void jnge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 +void jnge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jnl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jnl(const char *label, LabelType type = T_AUTO) { jnl(std::string(label), type); }//-V524 +void jnl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 +void jnl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jnle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jnle(const char *label, LabelType type = T_AUTO) { jnle(std::string(label), type); }//-V524 +void jnle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 +void jnle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jno(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 +void jno(const char *label, LabelType type = T_AUTO) { jno(std::string(label), type); }//-V524 +void jno(const void *addr) { opJmpAbs(addr, T_NEAR, 0x71, 0x81, 0x0F); }//-V524 +void jno(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 +void jnp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jnp(const char *label, LabelType type = T_AUTO) { jnp(std::string(label), type); }//-V524 +void jnp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 +void jnp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jns(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 +void jns(const char *label, LabelType type = T_AUTO) { jns(std::string(label), type); }//-V524 +void jns(const void *addr) { opJmpAbs(addr, T_NEAR, 0x79, 0x89, 0x0F); }//-V524 +void jns(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 +void jnz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jnz(const char *label, LabelType type = T_AUTO) { jnz(std::string(label), type); }//-V524 +void jnz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 +void jnz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 +void jo(const char *label, LabelType type = T_AUTO) { jo(std::string(label), type); }//-V524 +void jo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x70, 0x80, 0x0F); }//-V524 +void jo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 +void jp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jp(const char *label, LabelType type = T_AUTO) { jp(std::string(label), type); }//-V524 +void jp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 +void jp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(const char *label, LabelType type = T_AUTO) { jpe(std::string(label), type); }//-V524 +void jpe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jpo(const char *label, LabelType type = T_AUTO) { jpo(std::string(label), type); }//-V524 +void jpo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 +void jpo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void js(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 +void js(const char *label, LabelType type = T_AUTO) { js(std::string(label), type); }//-V524 +void js(const void *addr) { opJmpAbs(addr, T_NEAR, 0x78, 0x88, 0x0F); }//-V524 +void js(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 +void jz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void jz(const char *label, LabelType type = T_AUTO) { jz(std::string(label), type); }//-V524 +void jz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 +void jz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void lahf() { db(0x9F); } +void lddqu(const Xmm& xmm, const Address& addr) { db(0xF2); opModM(addr, xmm, 0x0F, 0xF0); } +void ldmxcsr(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0xAE); } +void lea(const Reg& reg, const Address& addr) { if (!reg.isBit(16 | i32e)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModM(addr, reg, 0x8D); } +void lfence() { db(0x0F); db(0xAE); db(0xE8); } +void lock() { db(0xF0); } +void lzcnt(const Reg®, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBD); } +void maskmovdqu(const Xmm& reg1, const Xmm& reg2) { db(0x66); opModR(reg1, reg2, 0x0F, 0xF7); } +void maskmovq(const Mmx& reg1, const Mmx& reg2) { if (!reg1.isMMX() || !reg2.isMMX()) throw Error(ERR_BAD_COMBINATION); opModR(reg1, reg2, 0x0F, 0xF7); } +void maxpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x66, isXMM_XMMorMEM); } +void maxps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x100, isXMM_XMMorMEM); } +void maxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF2, isXMM_XMMorMEM); } +void maxss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF3, isXMM_XMMorMEM); } +void mfence() { db(0x0F); db(0xAE); db(0xF0); } +void minpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x66, isXMM_XMMorMEM); } +void minps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x100, isXMM_XMMorMEM); } +void minsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF2, isXMM_XMMorMEM); } +void minss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF3, isXMM_XMMorMEM); } +void monitor() { db(0x0F); db(0x01); db(0xC8); } +void movapd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x29); } +void movapd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x66); } +void movaps(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x29); } +void movaps(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x100); } +void movbe(const Address& addr, const Reg& reg) { opModM(addr, reg, 0x0F, 0x38, 0xF1); } +void movbe(const Reg& reg, const Address& addr) { opModM(addr, reg, 0x0F, 0x38, 0xF0); } +void movd(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x7E); } +void movd(const Mmx& mmx, const Address& addr) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x6E); } +void movd(const Mmx& mmx, const Reg32& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); } +void movd(const Reg32& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); } +void movddup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF2, isXMM_XMMorMEM, NONE, NONE); } +void movdq2q(const Mmx& mmx, const Xmm& xmm) { db(0xF2); opModR(mmx, xmm, 0x0F, 0xD6); } +void movdqa(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x7F); } +void movdqa(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0x66); } +void movdqu(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x7F); } +void movdqu(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0xF3); } +void movhlps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x12); } +void movhpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x66); } +void movhps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x100); } +void movlhps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x16); } +void movlpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x66); } +void movlps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x100); } +void movmskpd(const Reg32e& reg, const Xmm& xmm) { db(0x66); movmskps(reg, xmm); } +void movmskps(const Reg32e& reg, const Xmm& xmm) { opModR(reg, xmm, 0x0F, 0x50); } +void movntdq(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0xE7); } +void movntdqa(const Xmm& xmm, const Address& addr) { db(0x66); opModM(addr, xmm, 0x0F, 0x38, 0x2A); } +void movnti(const Address& addr, const Reg32e& reg) { opModM(addr, reg, 0x0F, 0xC3); } +void movntpd(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0x2B); } +void movntps(const Address& addr, const Xmm& xmm) { opModM(addr, Mmx(xmm.getIdx()), 0x0F, 0x2B); } +void movntq(const Address& addr, const Mmx& mmx) { if (!mmx.isMMX()) throw Error(ERR_BAD_COMBINATION); opModM(addr, mmx, 0x0F, 0xE7); } +void movq(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, mmx.isXMM() ? 0xD6 : 0x7F); } +void movq(const Mmx& mmx, const Operand& op) { if (mmx.isXMM()) db(0xF3); opModRM(mmx, op, (mmx.getKind() == op.getKind()), op.isMEM(), 0x0F, mmx.isXMM() ? 0x7E : 0x6F); } +void movq2dq(const Xmm& xmm, const Mmx& mmx) { db(0xF3); opModR(xmm, mmx, 0x0F, 0xD6); } +void movsb() { db(0xA4); } +void movsd() { db(0xA5); } +void movsd(const Address& addr, const Xmm& xmm) { db(0xF2); opModM(addr, xmm, 0x0F, 0x11); } +void movsd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF2); } +void movshdup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x16, 0xF3, isXMM_XMMorMEM, NONE, NONE); } +void movsldup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF3, isXMM_XMMorMEM, NONE, NONE); } +void movss(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x11); } +void movss(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF3); } +void movsw() { db(0x66); db(0xA5); } +void movsx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xBE); } +void movupd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x11); } +void movupd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x66); } +void movups(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x11); } +void movups(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x100); } +void movzx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xB6); } +void mpsadbw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x42, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void mul(const Operand& op) { opR_ModM(op, 0, 4, 0xF6); } +void mulpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x66, isXMM_XMMorMEM); } +void mulps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x100, isXMM_XMMorMEM); } +void mulsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF2, isXMM_XMMorMEM); } +void mulss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF3, isXMM_XMMorMEM); } +void mulx(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf6, true); } +void mwait() { db(0x0F); db(0x01); db(0xC9); } +void neg(const Operand& op) { opR_ModM(op, 0, 3, 0xF6); } +void not_(const Operand& op) { opR_ModM(op, 0, 2, 0xF6); } +void or_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x08, 1); } +void or_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x08); } +void orpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x66, isXMM_XMMorMEM); } +void orps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x100, isXMM_XMMorMEM); } +void pabsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1C, 0x66, NONE, 0x38); } +void pabsd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1E, 0x66, NONE, 0x38); } +void pabsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1D, 0x66, NONE, 0x38); } +void packssdw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6B); } +void packsswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x63); } +void packusdw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2B, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void packuswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x67); } +void paddb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFC); } +void paddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFE); } +void paddq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD4); } +void paddsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEC); } +void paddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xED); } +void paddusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDC); } +void paddusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDD); } +void paddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFD); } +void palignr(const Mmx& mmx, const Operand& op, int imm) { opMMX(mmx, op, 0x0f, 0x66, static_cast(imm), 0x3a); } +void pand(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDB); } +void pandn(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDF); } +void pause() { db(0xF3); db(0x90); } +void pavgb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE0); } +void pavgw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE3); } +void pblendvb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x10, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pblendw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0E, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void pclmulhqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x11); } +void pclmulhqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x01); } +void pclmullqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x10); } +void pclmullqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x00); } +void pclmulqdq(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x44, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void pcmpeqb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x74); } +void pcmpeqd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x76); } +void pcmpeqq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x29, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pcmpeqw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x75); } +void pcmpestri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x61, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpestrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x60, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpgtb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x64); } +void pcmpgtd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x66); } +void pcmpgtq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x37, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pcmpgtw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x65); } +void pcmpistri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x63, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpistrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x62, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pdep(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf5, true); } +void pext(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F3 | T_0F38, 0xf5, true); } +void pextrb(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x14, imm); } +void pextrd(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x16, imm); } +void pextrw(const Operand& op, const Mmx& xmm, uint8 imm) { opExt(op, xmm, 0x15, imm, true); } +void phaddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x02, 0x66, NONE, 0x38); } +void phaddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x03, 0x66, NONE, 0x38); } +void phaddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x01, 0x66, NONE, 0x38); } +void phminposuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void phsubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x06, 0x66, NONE, 0x38); } +void phsubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x07, 0x66, NONE, 0x38); } +void phsubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x05, 0x66, NONE, 0x38); } +void pinsrb(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x20, 0x66, isXMM_REG32orMEM, imm, 0x3A); } +void pinsrd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x22, 0x66, isXMM_REG32orMEM, imm, 0x3A); } +void pinsrw(const Mmx& mmx, const Operand& op, int imm) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(mmx, op, 0xC4, mmx.isXMM() ? 0x66 : NONE, 0, imm); } +void pmaddubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x04, 0x66, NONE, 0x38); } +void pmaddwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF5); } +void pmaxsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3C, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3D, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEE); } +void pmaxub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDE); } +void pmaxud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3F, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3E, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x38, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x39, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEA); } +void pminub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDA); } +void pminud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3B, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3A, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovmskb(const Reg32e& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(reg, mmx, 0x0F, 0xD7); } +void pmovsxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x22, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x20, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x25, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x23, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x24, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x31, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x32, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x30, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x35, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x33, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x34, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmuldq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x28, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmulhrsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0B, 0x66, NONE, 0x38); } +void pmulhuw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE4); } +void pmulhw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE5); } +void pmulld(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmullw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD5); } +void pmuludq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF4); } +void popcnt(const Reg®, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xB8); } +void popf() { db(0x9D); } +void por(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEB); } +void prefetchnta(const Address& addr) { opModM(addr, Reg32(0), 0x0F, 0x18); } +void prefetcht0(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x18); } +void prefetcht1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x18); } +void prefetcht2(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0x18); } +void prefetchw(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x0D); } +void prefetchwt1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x0D); } +void psadbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF6); } +void pshufb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x00, 0x66, NONE, 0x38); } +void pshufd(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x66, imm8); } +void pshufhw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF3, imm8); } +void pshuflw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF2, imm8); } +void pshufw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x00, imm8); } +void psignb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x08, 0x66, NONE, 0x38); } +void psignd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0A, 0x66, NONE, 0x38); } +void psignw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x09, 0x66, NONE, 0x38); } +void pslld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF2); } +void pslld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 6); } +void pslldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 7); } +void psllq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF3); } +void psllq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 6); } +void psllw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF1); } +void psllw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 6); } +void psrad(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE2); } +void psrad(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 4); } +void psraw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE1); } +void psraw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 4); } +void psrld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD2); } +void psrld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 2); } +void psrldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 3); } +void psrlq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD3); } +void psrlq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 2); } +void psrlw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD1); } +void psrlw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 2); } +void psubb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF8); } +void psubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFA); } +void psubq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFB); } +void psubsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE8); } +void psubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE9); } +void psubusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD8); } +void psubusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD9); } +void psubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF9); } +void ptest(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x17, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void punpckhbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x68); } +void punpckhdq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6A); } +void punpckhqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6D, 0x66, isXMM_XMMorMEM); } +void punpckhwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x69); } +void punpcklbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x60); } +void punpckldq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x62); } +void punpcklqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6C, 0x66, isXMM_XMMorMEM); } +void punpcklwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x61); } +void pushf() { db(0x9C); } +void pxor(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEF); } +void rcl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 2); } +void rcl(const Operand& op, int imm) { opShift(op, imm, 2); } +void rcpps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0x100, isXMM_XMMorMEM); } +void rcpss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0xF3, isXMM_XMMorMEM); } +void rcr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 3); } +void rcr(const Operand& op, int imm) { opShift(op, imm, 3); } +void rdmsr() { db(0x0F); db(0x32); } +void rdpmc() { db(0x0F); db(0x33); } +void rdrand(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(6, Operand::REG, r.getBit()), r, 0x0F, 0xC7); } +void rdseed(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(7, Operand::REG, r.getBit()), r, 0x0F, 0xC7); } +void rdtsc() { db(0x0F); db(0x31); } +void rdtscp() { db(0x0F); db(0x01); db(0xF9); } +void rep() { db(0xF3); } +void ret(int imm = 0) { if (imm) { db(0xC2); dw(imm); } else { db(0xC3); } } +void rol(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 0); } +void rol(const Operand& op, int imm) { opShift(op, imm, 0); } +void ror(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 1); } +void ror(const Operand& op, int imm) { opShift(op, imm, 1); } +void rorx(const Reg32e& r, const Operand& op, uint8 imm) { opGpr(r, op, Reg32e(0, r.getBit()), T_0F3A | T_F2, 0xF0, false, imm); } +void roundpd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x09, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void roundps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x08, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void roundsd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0B, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void roundss(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0A, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void rsqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0x100, isXMM_XMMorMEM); } +void rsqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0xF3, isXMM_XMMorMEM); } +void sahf() { db(0x9E); } +void sal(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); } +void sal(const Operand& op, int imm) { opShift(op, imm, 4); } +void sar(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 7); } +void sar(const Operand& op, int imm) { opShift(op, imm, 7); } +void sarx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F3 | T_0F38, 0xf7, false); } +void sbb(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x18, 3); } +void sbb(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x18); } +void scasb() { db(0xAE); } +void scasd() { db(0xAF); } +void scasw() { db(0x66); db(0xAF); } +void seta(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524 +void setae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void setbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524 +void setc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void sete(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524 +void setg(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524 +void setge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524 +void setl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524 +void setle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524 +void setna(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524 +void setnae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void setnb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setnbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524 +void setnc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setne(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524 +void setng(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524 +void setnge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524 +void setnl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524 +void setnle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524 +void setno(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 1); }//-V524 +void setnp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524 +void setns(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 9); }//-V524 +void setnz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524 +void seto(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 0); }//-V524 +void setp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524 +void setpe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524 +void setpo(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524 +void sets(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 8); }//-V524 +void setz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524 +void sfence() { db(0x0F); db(0xAE); db(0xF8); } +void sha1msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC9, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha1msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCA, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha1nexte(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC8, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha1rnds4(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, imm, 0x3A); } +void sha256msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha256msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCD, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha256rnds2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCB, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void shl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); } +void shl(const Operand& op, int imm) { opShift(op, imm, 4); } +void shld(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xA4, &_cl); } +void shld(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xA4); } +void shlx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_66 | T_0F38, 0xf7, false); } +void shr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 5); } +void shr(const Operand& op, int imm) { opShift(op, imm, 5); } +void shrd(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xAC, &_cl); } +void shrd(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xAC); } +void shrx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F2 | T_0F38, 0xf7, false); } +void shufpd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x66, isXMM_XMMorMEM, imm8); } +void shufps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x100, isXMM_XMMorMEM, imm8); } +void sqrtpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x66, isXMM_XMMorMEM); } +void sqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x100, isXMM_XMMorMEM); } +void sqrtsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF2, isXMM_XMMorMEM); } +void sqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF3, isXMM_XMMorMEM); } +void stac() { db(0x0F); db(0x01); db(0xCB); } +void stc() { db(0xF9); } +void std() { db(0xFD); } +void sti() { db(0xFB); } +void stmxcsr(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0xAE); } +void stosb() { db(0xAA); } +void stosd() { db(0xAB); } +void stosw() { db(0x66); db(0xAB); } +void sub(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x28, 5); } +void sub(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x28); } +void subpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x66, isXMM_XMMorMEM); } +void subps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x100, isXMM_XMMorMEM); } +void subsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF2, isXMM_XMMorMEM); } +void subss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF3, isXMM_XMMorMEM); } +void tzcnt(const Reg®, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBC); } +void ucomisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x66, isXMM_XMMorMEM); } +void ucomiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x100, isXMM_XMMorMEM); } +void ud2() { db(0x0F); db(0x0B); } +void unpckhpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM); } +void unpckhps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x100, isXMM_XMMorMEM); } +void unpcklpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM); } +void unpcklps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x100, isXMM_XMMorMEM); } +void vaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x58); } +void vaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x58); } +void vaddsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x58); } +void vaddss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x58); } +void vaddsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0xD0); } +void vaddsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0xD0); } +void vaesdec(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDE); } +void vaesdeclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDF); } +void vaesenc(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDC); } +void vaesenclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDD); } +void vaesimc(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_W0, 0xDB); } +void vaeskeygenassist(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0xDF, imm); } +void vandnpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x55); } +void vandnps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x55); } +void vandpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x54); } +void vandps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x54); } +void vblendpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0D, imm); } +void vblendps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0C, imm); } +void vblendvpd(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4B, x4.getIdx() << 4); } +void vblendvps(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4A, x4.getIdx() << 4); } +void vbroadcastf128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x1A); } +void vbroadcasti128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x5A); } +void vbroadcastsd(const Ymm& y, const Operand& op) { if (!op.isMEM() && !(y.isYMM() && op.isXMM()) && !(y.isZMM() && op.isXMM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(y, op, T_0F38 | T_66 | T_W0 | T_YMM | T_EVEX | T_EW1 | T_N8, 0x19); } +void vbroadcastss(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x18); } +void vcmpeq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 16); } +void vcmpeq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 16); } +void vcmpeq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 16); } +void vcmpeq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 16); } +void vcmpeq_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 8); } +void vcmpeq_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 8); } +void vcmpeq_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 8); } +void vcmpeq_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 8); } +void vcmpeq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 24); } +void vcmpeq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 24); } +void vcmpeq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 24); } +void vcmpeq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 24); } +void vcmpeqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 0); } +void vcmpeqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 0); } +void vcmpeqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 0); } +void vcmpeqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 0); } +void vcmpfalse_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 27); } +void vcmpfalse_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 27); } +void vcmpfalse_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 27); } +void vcmpfalse_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 27); } +void vcmpfalsepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 11); } +void vcmpfalseps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 11); } +void vcmpfalsesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 11); } +void vcmpfalsess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 11); } +void vcmpge_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 29); } +void vcmpge_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 29); } +void vcmpge_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 29); } +void vcmpge_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 29); } +void vcmpgepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 13); } +void vcmpgeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 13); } +void vcmpgesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 13); } +void vcmpgess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 13); } +void vcmpgt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 30); } +void vcmpgt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 30); } +void vcmpgt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 30); } +void vcmpgt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 30); } +void vcmpgtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 14); } +void vcmpgtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 14); } +void vcmpgtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 14); } +void vcmpgtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 14); } +void vcmple_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 18); } +void vcmple_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 18); } +void vcmple_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 18); } +void vcmple_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 18); } +void vcmplepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 2); } +void vcmpleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 2); } +void vcmplesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 2); } +void vcmpless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 2); } +void vcmplt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 17); } +void vcmplt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 17); } +void vcmplt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 17); } +void vcmplt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 17); } +void vcmpltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 1); } +void vcmpltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 1); } +void vcmpltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 1); } +void vcmpltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 1); } +void vcmpneq_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 12); } +void vcmpneq_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 12); } +void vcmpneq_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 12); } +void vcmpneq_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 12); } +void vcmpneq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 28); } +void vcmpneq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 28); } +void vcmpneq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 28); } +void vcmpneq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 28); } +void vcmpneq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 20); } +void vcmpneq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 20); } +void vcmpneq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 20); } +void vcmpneq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 20); } +void vcmpneqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 4); } +void vcmpneqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 4); } +void vcmpneqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 4); } +void vcmpneqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 4); } +void vcmpnge_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 25); } +void vcmpnge_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 25); } +void vcmpnge_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 25); } +void vcmpnge_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 25); } +void vcmpngepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 9); } +void vcmpngeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 9); } +void vcmpngesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 9); } +void vcmpngess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 9); } +void vcmpngt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 26); } +void vcmpngt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 26); } +void vcmpngt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 26); } +void vcmpngt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 26); } +void vcmpngtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 10); } +void vcmpngtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 10); } +void vcmpngtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 10); } +void vcmpngtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 10); } +void vcmpnle_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 22); } +void vcmpnle_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 22); } +void vcmpnle_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 22); } +void vcmpnle_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 22); } +void vcmpnlepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 6); } +void vcmpnleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 6); } +void vcmpnlesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 6); } +void vcmpnless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 6); } +void vcmpnlt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 21); } +void vcmpnlt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 21); } +void vcmpnlt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 21); } +void vcmpnlt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 21); } +void vcmpnltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 5); } +void vcmpnltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 5); } +void vcmpnltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 5); } +void vcmpnltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 5); } +void vcmpord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 23); } +void vcmpord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 23); } +void vcmpord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 23); } +void vcmpord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 23); } +void vcmpordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 7); } +void vcmpordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 7); } +void vcmpordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 7); } +void vcmpordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 7); } +void vcmppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xC2, imm); } +void vcmpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_YMM, 0xC2, imm); } +void vcmpsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F2 | T_0F, 0xC2, imm); } +void vcmpss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0xC2, imm); } +void vcmptrue_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 31); } +void vcmptrue_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 31); } +void vcmptrue_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 31); } +void vcmptrue_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 31); } +void vcmptruepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 15); } +void vcmptrueps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 15); } +void vcmptruesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 15); } +void vcmptruess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 15); } +void vcmpunord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 19); } +void vcmpunord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 19); } +void vcmpunord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 19); } +void vcmpunord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 19); } +void vcmpunordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 3); } +void vcmpunordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 3); } +void vcmpunordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 3); } +void vcmpunordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 3); } +void vcomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2F); } +void vcomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2F); } +void vcvtdq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_F3 | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0xE6); } +void vcvtdq2ps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); } +void vcvtpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_F2 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0xE6); } +void vcvtpd2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_66 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5A); } +void vcvtph2ps(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F38 | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x13); } +void vcvtps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); } +void vcvtps2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x5A); } +void vcvtps2ph(const Operand& op, const Xmm& x, uint8 imm) { checkCvt1(x, op); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x1D, imm); } +void vcvtsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_ER_X, 0x2D); } +void vcvtsd2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x5A); } +void vcvtsi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F2 | T_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } +void vcvtsi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F3 | T_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } +void vcvtss2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x5A); } +void vcvtss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_ER_X | T_N8, 0x2D); } +void vcvttpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_66 | T_0F | T_YMM | T_EVEX |T_EW1 | T_B64 | T_ER_Z, 0xE6); } +void vcvttps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX | T_SAE_Z | T_B32, 0x5B); } +void vcvttsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_SAE_X, 0x2C); } +void vcvttss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_SAE_X | T_N8, 0x2C); } +void vdivpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5E); } +void vdivps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5E); } +void vdivsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5E); } +void vdivss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5E); } +void vdppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x41, imm); } +void vdpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x40, imm); } +void vextractf128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x19, imm); } +void vextracti128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x39, imm); } +void vextractps(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_N4, 0x17, imm); } +void vfmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x98); } +void vfmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x98); } +void vfmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x99); } +void vfmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x99); } +void vfmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA8); } +void vfmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA8); } +void vfmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xA9); } +void vfmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xA9); } +void vfmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB8); } +void vfmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB8); } +void vfmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xB9); } +void vfmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xB9); } +void vfmaddsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x96); } +void vfmaddsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x96); } +void vfmaddsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA6); } +void vfmaddsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA6); } +void vfmaddsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB6); } +void vfmaddsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB6); } +void vfmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9A); } +void vfmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9A); } +void vfmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9B); } +void vfmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9B); } +void vfmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAA); } +void vfmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAA); } +void vfmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAB); } +void vfmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAB); } +void vfmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBA); } +void vfmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBA); } +void vfmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBB); } +void vfmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBB); } +void vfmsubadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x97); } +void vfmsubadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x97); } +void vfmsubadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA7); } +void vfmsubadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA7); } +void vfmsubadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB7); } +void vfmsubadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB7); } +void vfnmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9C); } +void vfnmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9C); } +void vfnmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9D); } +void vfnmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9D); } +void vfnmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAC); } +void vfnmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAC); } +void vfnmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAD); } +void vfnmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAD); } +void vfnmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBC); } +void vfnmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBC); } +void vfnmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBD); } +void vfnmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBD); } +void vfnmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9E); } +void vfnmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9E); } +void vfnmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9F); } +void vfnmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9F); } +void vfnmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAE); } +void vfnmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAE); } +void vfnmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAF); } +void vfnmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAF); } +void vfnmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBE); } +void vfnmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBE); } +void vfnmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBF); } +void vfnmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBF); } +void vgatherdpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x92, 0); } +void vgatherdps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x92, 1); } +void vgatherqpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x93, 1); } +void vgatherqps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x93, 2); } +void vgf2p8affineinvqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCF, imm); } +void vgf2p8affineqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCE, imm); } +void vgf2p8mulb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_SAE_Z, 0xCF); } +void vhaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7C); } +void vhaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7C); } +void vhsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7D); } +void vhsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7D); } +void vinsertf128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x18, imm); } +void vinserti128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x38, imm); } +void vinsertps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_W0 | T_EW0 | T_EVEX, 0x21, imm); } +void vlddqu(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, cvtIdx0(x), addr, T_0F | T_F2 | T_W0 | T_YMM, 0xF0); } +void vldmxcsr(const Address& addr) { opAVX_X_X_XM(xm2, xm0, addr, T_0F, 0xAE); } +void vmaskmovdqu(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_66, 0xF7); } +void vmaskmovpd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2F); } +void vmaskmovpd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2D); } +void vmaskmovps(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2E); } +void vmaskmovps(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2C); } +void vmaxpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5F); } +void vmaxps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5F); } +void vmaxsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5F); } +void vmaxss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5F); } +void vminpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5D); } +void vminps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5D); } +void vminsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5D); } +void vminss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5D); } +void vmovapd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x29); } +void vmovapd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x28); } +void vmovaps(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x29); } +void vmovaps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x28); } +void vmovd(const Operand& op, const Xmm& x) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x7E); } +void vmovd(const Xmm& x, const Operand& op) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x6E); } +void vmovddup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_DUP | T_F2 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_X | T_ER_Y | T_ER_Z, 0x12); } +void vmovdqa(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_YMM, 0x7F); } +void vmovdqa(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_YMM, 0x6F); } +void vmovdqu(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_F3 | T_0F | T_YMM, 0x7F); } +void vmovdqu(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM, 0x6F); } +void vmovhlps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x12); } +void vmovhpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x17); } +void vmovhpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x16); } +void vmovhps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x17); } +void vmovhps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x16); } +void vmovlhps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x16); } +void vmovlpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x13); } +void vmovlpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x12); } +void vmovlps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x13); } +void vmovlps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x12); } +void vmovmskpd(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_66 | T_W0 | T_YMM, 0x50); } +void vmovmskps(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_W0 | T_YMM, 0x50); } +void vmovntdq(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW0, 0xE7); } +void vmovntdqa(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_0F38 | T_66 | T_YMM | T_EVEX | T_EW0, 0x2A); } +void vmovntpd(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW1, 0x2B); } +void vmovntps(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_YMM | T_EVEX | T_EW0, 0x2B); } +void vmovq(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, x.getIdx() < 16 ? 0xD6 : 0x7E); } +void vmovq(const Xmm& x, const Address& addr) { int type, code; if (x.getIdx() < 16) { type = T_0F | T_F3; code = 0x7E; } else { type = T_0F | T_66 | T_EVEX | T_EW1 | T_N8; code = 0x6E; } opAVX_X_X_XM(x, xm0, addr, type, code); } +void vmovq(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_F3 | T_EVEX | T_EW1 | T_N8, 0x7E); } +void vmovsd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_M_K, 0x11); } +void vmovsd(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); } +void vmovsd(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); } +void vmovshdup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x16); } +void vmovsldup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x12); } +void vmovss(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_M_K, 0x11); } +void vmovss(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); } +void vmovss(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); } +void vmovupd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x11); } +void vmovupd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x10); } +void vmovups(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x11); } +void vmovups(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x10); } +void vmpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x42, imm); } +void vmulpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x59); } +void vmulps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x59); } +void vmulsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x59); } +void vmulss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x59); } +void vorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x56); } +void vorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x56); } +void vpabsb(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1C); } +void vpabsd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x1E); } +void vpabsw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1D); } +void vpackssdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6B); } +void vpacksswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x63); } +void vpackusdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x2B); } +void vpackuswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x67); } +void vpaddb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFC); } +void vpaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFE); } +void vpaddq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xD4); } +void vpaddsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEC); } +void vpaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xED); } +void vpaddusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDC); } +void vpaddusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDD); } +void vpaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFD); } +void vpalignr(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_YMM | T_EVEX, 0x0F, imm); } +void vpand(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDB); } +void vpandn(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDF); } +void vpavgb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE0); } +void vpavgw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE3); } +void vpblendd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x02, imm); } +void vpblendvb(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4C, x4.getIdx() << 4); } +void vpblendw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0E, imm); } +void vpbroadcastb(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x78); } +void vpbroadcastd(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x58); } +void vpbroadcastq(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX, 0x59); } +void vpbroadcastw(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x79); } +void vpclmulqdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM | T_EVEX, 0x44, imm); } +void vpcmpeqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x74); } +void vpcmpeqd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x76); } +void vpcmpeqq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x29); } +void vpcmpeqw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x75); } +void vpcmpestri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x61, imm); } +void vpcmpestrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x60, imm); } +void vpcmpgtb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x64); } +void vpcmpgtd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x66); } +void vpcmpgtq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x37); } +void vpcmpgtw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x65); } +void vpcmpistri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x63, imm); } +void vpcmpistrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x62, imm); } +void vperm2f128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x06, imm); } +void vperm2i128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x46, imm); } +void vpermd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x36); } +void vpermilpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x0D); } +void vpermilpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_EVEX | T_B64, 0x05, imm); } +void vpermilps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x0C); } +void vpermilps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_EVEX | T_B32, 0x04, imm); } +void vpermpd(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x01, imm); } +void vpermpd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x16); } +void vpermps(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x16); } +void vpermq(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x00, imm); } +void vpermq(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x36); } +void vpextrb(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(8|16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x14, imm); } +void vpextrd(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x16, imm); } +void vpextrq(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(64) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x16, imm); } +void vpextrw(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); if (op.isREG() && x.getIdx() < 16) { opAVX_X_X_XM(Xmm(op.getIdx()), xm0, x, T_0F | T_66, 0xC5, imm); } else { opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N2, 0x15, imm); } } +void vpgatherdd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x90, 1); } +void vpgatherdq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x90, 0); } +void vpgatherqd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x91, 2); } +void vpgatherqq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x91, 1); } +void vphaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x02); } +void vphaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x03); } +void vphaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x01); } +void vphminposuw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38, 0x41); } +void vphsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x06); } +void vphsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x07); } +void vphsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x05); } +void vpinsrb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x20, imm); } +void vpinsrd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x22, imm); } +void vpinsrq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(64) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x22, imm); } +void vpinsrw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F | T_66 | T_EVEX | T_N2, 0xC4, imm); } +void vpmaddubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x04); } +void vpmaddwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF5); } +void vpmaskmovd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8E); } +void vpmaskmovd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8C); } +void vpmaskmovq(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8E); } +void vpmaskmovq(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8C); } +void vpmaxsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3C); } +void vpmaxsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3D); } +void vpmaxsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEE); } +void vpmaxub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDE); } +void vpmaxud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3F); } +void vpmaxuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3E); } +void vpminsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x38); } +void vpminsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x39); } +void vpminsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEA); } +void vpminub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDA); } +void vpminud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3B); } +void vpminuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3A); } +void vpmovmskb(const Reg32e& r, const Xmm& x) { if (!x.is(Operand::XMM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(x.isYMM() ? Ymm(r.getIdx()) : Xmm(r.getIdx()), 0, x, T_0F | T_66 | T_YMM, 0xD7); } +void vpmovsxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x21); } +void vpmovsxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x22); } +void vpmovsxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x20); } +void vpmovsxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x25); } +void vpmovsxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x23); } +void vpmovsxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x24); } +void vpmovzxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x31); } +void vpmovzxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x32); } +void vpmovzxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x30); } +void vpmovzxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x35); } +void vpmovzxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x33); } +void vpmovzxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x34); } +void vpmuldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x28); } +void vpmulhrsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x0B); } +void vpmulhuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE4); } +void vpmulhw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE5); } +void vpmulld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x40); } +void vpmullw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD5); } +void vpmuludq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xF4); } +void vpor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEB); } +void vpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF6); } +void vpshufb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x00); } +void vpshufd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x70, imm); } +void vpshufhw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM | T_EVEX, 0x70, imm); } +void vpshuflw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F2 | T_0F | T_YMM | T_EVEX, 0x70, imm); } +void vpsignb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x08); } +void vpsignd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x0A); } +void vpsignw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x09); } +void vpslld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpslld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xF2); } +void vpslldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 7), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); } +void vpsllq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); } +void vpsllq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xF3); } +void vpsllvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x47); } +void vpsllvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x47); } +void vpsllw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsllw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xF1); } +void vpsrad(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpsrad(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xE2); } +void vpsravd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x46); } +void vpsraw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsraw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xE1); } +void vpsrld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpsrld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xD2); } +void vpsrldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 3), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); } +void vpsrlq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); } +void vpsrlq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xD3); } +void vpsrlvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x45); } +void vpsrlvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x45); } +void vpsrlw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsrlw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xD1); } +void vpsubb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF8); } +void vpsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFA); } +void vpsubq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xFB); } +void vpsubsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE8); } +void vpsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE9); } +void vpsubusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD8); } +void vpsubusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD9); } +void vpsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF9); } +void vptest(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x17); } +void vpunpckhbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x68); } +void vpunpckhdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6A); } +void vpunpckhqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6D); } +void vpunpckhwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x69); } +void vpunpcklbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x60); } +void vpunpckldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x62); } +void vpunpcklqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6C); } +void vpunpcklwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x61); } +void vpxor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEF); } +void vrcpps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x53); } +void vrcpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x53); } +void vroundpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x09, imm); } +void vroundps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x08, imm); } +void vroundsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0B, imm); } +void vroundss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0A, imm); } +void vrsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x52); } +void vrsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x52); } +void vshufpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xC6, imm); } +void vshufps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xC6, imm); } +void vsqrtpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x51); } +void vsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x51); } +void vsqrtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x51); } +void vsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_ER_X, 0x51); } +void vstmxcsr(const Address& addr) { opAVX_X_X_XM(xm3, xm0, addr, T_0F, 0xAE); } +void vsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5C); } +void vsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5C); } +void vsubsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5C); } +void vsubss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5C); } +void vtestpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0F); } +void vtestps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0E); } +void vucomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2E); } +void vucomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2E); } +void vunpckhpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x15); } +void vunpckhps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x15); } +void vunpcklpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x14); } +void vunpcklps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x14); } +void vxorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x57); } +void vxorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x57); } +void vzeroall() { db(0xC5); db(0xFC); db(0x77); } +void vzeroupper() { db(0xC5); db(0xF8); db(0x77); } +void wait() { db(0x9B); } +void wbinvd() { db(0x0F); db(0x09); } +void wrmsr() { db(0x0F); db(0x30); } +void xadd(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xC0 | (reg.isBit(8) ? 0 : 1)); } +void xgetbv() { db(0x0F); db(0x01); db(0xD0); } +void xlatb() { db(0xD7); } +void xor_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x30, 6); } +void xor_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x30); } +void xorpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x66, isXMM_XMMorMEM); } +void xorps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x100, isXMM_XMMorMEM); } +#ifdef XBYAK_ENABLE_OMITTED_OPERAND +void vblendpd(const Xmm& x, const Operand& op, uint8 imm) { vblendpd(x, x, op, imm); } +void vblendps(const Xmm& x, const Operand& op, uint8 imm) { vblendps(x, x, op, imm); } +void vblendvpd(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvpd(x1, x1, op, x4); } +void vblendvps(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvps(x1, x1, op, x4); } +void vcmpeq_ospd(const Xmm& x, const Operand& op) { vcmpeq_ospd(x, x, op); } +void vcmpeq_osps(const Xmm& x, const Operand& op) { vcmpeq_osps(x, x, op); } +void vcmpeq_ossd(const Xmm& x, const Operand& op) { vcmpeq_ossd(x, x, op); } +void vcmpeq_osss(const Xmm& x, const Operand& op) { vcmpeq_osss(x, x, op); } +void vcmpeq_uqpd(const Xmm& x, const Operand& op) { vcmpeq_uqpd(x, x, op); } +void vcmpeq_uqps(const Xmm& x, const Operand& op) { vcmpeq_uqps(x, x, op); } +void vcmpeq_uqsd(const Xmm& x, const Operand& op) { vcmpeq_uqsd(x, x, op); } +void vcmpeq_uqss(const Xmm& x, const Operand& op) { vcmpeq_uqss(x, x, op); } +void vcmpeq_uspd(const Xmm& x, const Operand& op) { vcmpeq_uspd(x, x, op); } +void vcmpeq_usps(const Xmm& x, const Operand& op) { vcmpeq_usps(x, x, op); } +void vcmpeq_ussd(const Xmm& x, const Operand& op) { vcmpeq_ussd(x, x, op); } +void vcmpeq_usss(const Xmm& x, const Operand& op) { vcmpeq_usss(x, x, op); } +void vcmpeqpd(const Xmm& x, const Operand& op) { vcmpeqpd(x, x, op); } +void vcmpeqps(const Xmm& x, const Operand& op) { vcmpeqps(x, x, op); } +void vcmpeqsd(const Xmm& x, const Operand& op) { vcmpeqsd(x, x, op); } +void vcmpeqss(const Xmm& x, const Operand& op) { vcmpeqss(x, x, op); } +void vcmpfalse_ospd(const Xmm& x, const Operand& op) { vcmpfalse_ospd(x, x, op); } +void vcmpfalse_osps(const Xmm& x, const Operand& op) { vcmpfalse_osps(x, x, op); } +void vcmpfalse_ossd(const Xmm& x, const Operand& op) { vcmpfalse_ossd(x, x, op); } +void vcmpfalse_osss(const Xmm& x, const Operand& op) { vcmpfalse_osss(x, x, op); } +void vcmpfalsepd(const Xmm& x, const Operand& op) { vcmpfalsepd(x, x, op); } +void vcmpfalseps(const Xmm& x, const Operand& op) { vcmpfalseps(x, x, op); } +void vcmpfalsesd(const Xmm& x, const Operand& op) { vcmpfalsesd(x, x, op); } +void vcmpfalsess(const Xmm& x, const Operand& op) { vcmpfalsess(x, x, op); } +void vcmpge_oqpd(const Xmm& x, const Operand& op) { vcmpge_oqpd(x, x, op); } +void vcmpge_oqps(const Xmm& x, const Operand& op) { vcmpge_oqps(x, x, op); } +void vcmpge_oqsd(const Xmm& x, const Operand& op) { vcmpge_oqsd(x, x, op); } +void vcmpge_oqss(const Xmm& x, const Operand& op) { vcmpge_oqss(x, x, op); } +void vcmpgepd(const Xmm& x, const Operand& op) { vcmpgepd(x, x, op); } +void vcmpgeps(const Xmm& x, const Operand& op) { vcmpgeps(x, x, op); } +void vcmpgesd(const Xmm& x, const Operand& op) { vcmpgesd(x, x, op); } +void vcmpgess(const Xmm& x, const Operand& op) { vcmpgess(x, x, op); } +void vcmpgt_oqpd(const Xmm& x, const Operand& op) { vcmpgt_oqpd(x, x, op); } +void vcmpgt_oqps(const Xmm& x, const Operand& op) { vcmpgt_oqps(x, x, op); } +void vcmpgt_oqsd(const Xmm& x, const Operand& op) { vcmpgt_oqsd(x, x, op); } +void vcmpgt_oqss(const Xmm& x, const Operand& op) { vcmpgt_oqss(x, x, op); } +void vcmpgtpd(const Xmm& x, const Operand& op) { vcmpgtpd(x, x, op); } +void vcmpgtps(const Xmm& x, const Operand& op) { vcmpgtps(x, x, op); } +void vcmpgtsd(const Xmm& x, const Operand& op) { vcmpgtsd(x, x, op); } +void vcmpgtss(const Xmm& x, const Operand& op) { vcmpgtss(x, x, op); } +void vcmple_oqpd(const Xmm& x, const Operand& op) { vcmple_oqpd(x, x, op); } +void vcmple_oqps(const Xmm& x, const Operand& op) { vcmple_oqps(x, x, op); } +void vcmple_oqsd(const Xmm& x, const Operand& op) { vcmple_oqsd(x, x, op); } +void vcmple_oqss(const Xmm& x, const Operand& op) { vcmple_oqss(x, x, op); } +void vcmplepd(const Xmm& x, const Operand& op) { vcmplepd(x, x, op); } +void vcmpleps(const Xmm& x, const Operand& op) { vcmpleps(x, x, op); } +void vcmplesd(const Xmm& x, const Operand& op) { vcmplesd(x, x, op); } +void vcmpless(const Xmm& x, const Operand& op) { vcmpless(x, x, op); } +void vcmplt_oqpd(const Xmm& x, const Operand& op) { vcmplt_oqpd(x, x, op); } +void vcmplt_oqps(const Xmm& x, const Operand& op) { vcmplt_oqps(x, x, op); } +void vcmplt_oqsd(const Xmm& x, const Operand& op) { vcmplt_oqsd(x, x, op); } +void vcmplt_oqss(const Xmm& x, const Operand& op) { vcmplt_oqss(x, x, op); } +void vcmpltpd(const Xmm& x, const Operand& op) { vcmpltpd(x, x, op); } +void vcmpltps(const Xmm& x, const Operand& op) { vcmpltps(x, x, op); } +void vcmpltsd(const Xmm& x, const Operand& op) { vcmpltsd(x, x, op); } +void vcmpltss(const Xmm& x, const Operand& op) { vcmpltss(x, x, op); } +void vcmpneq_oqpd(const Xmm& x, const Operand& op) { vcmpneq_oqpd(x, x, op); } +void vcmpneq_oqps(const Xmm& x, const Operand& op) { vcmpneq_oqps(x, x, op); } +void vcmpneq_oqsd(const Xmm& x, const Operand& op) { vcmpneq_oqsd(x, x, op); } +void vcmpneq_oqss(const Xmm& x, const Operand& op) { vcmpneq_oqss(x, x, op); } +void vcmpneq_ospd(const Xmm& x, const Operand& op) { vcmpneq_ospd(x, x, op); } +void vcmpneq_osps(const Xmm& x, const Operand& op) { vcmpneq_osps(x, x, op); } +void vcmpneq_ossd(const Xmm& x, const Operand& op) { vcmpneq_ossd(x, x, op); } +void vcmpneq_osss(const Xmm& x, const Operand& op) { vcmpneq_osss(x, x, op); } +void vcmpneq_uspd(const Xmm& x, const Operand& op) { vcmpneq_uspd(x, x, op); } +void vcmpneq_usps(const Xmm& x, const Operand& op) { vcmpneq_usps(x, x, op); } +void vcmpneq_ussd(const Xmm& x, const Operand& op) { vcmpneq_ussd(x, x, op); } +void vcmpneq_usss(const Xmm& x, const Operand& op) { vcmpneq_usss(x, x, op); } +void vcmpneqpd(const Xmm& x, const Operand& op) { vcmpneqpd(x, x, op); } +void vcmpneqps(const Xmm& x, const Operand& op) { vcmpneqps(x, x, op); } +void vcmpneqsd(const Xmm& x, const Operand& op) { vcmpneqsd(x, x, op); } +void vcmpneqss(const Xmm& x, const Operand& op) { vcmpneqss(x, x, op); } +void vcmpnge_uqpd(const Xmm& x, const Operand& op) { vcmpnge_uqpd(x, x, op); } +void vcmpnge_uqps(const Xmm& x, const Operand& op) { vcmpnge_uqps(x, x, op); } +void vcmpnge_uqsd(const Xmm& x, const Operand& op) { vcmpnge_uqsd(x, x, op); } +void vcmpnge_uqss(const Xmm& x, const Operand& op) { vcmpnge_uqss(x, x, op); } +void vcmpngepd(const Xmm& x, const Operand& op) { vcmpngepd(x, x, op); } +void vcmpngeps(const Xmm& x, const Operand& op) { vcmpngeps(x, x, op); } +void vcmpngesd(const Xmm& x, const Operand& op) { vcmpngesd(x, x, op); } +void vcmpngess(const Xmm& x, const Operand& op) { vcmpngess(x, x, op); } +void vcmpngt_uqpd(const Xmm& x, const Operand& op) { vcmpngt_uqpd(x, x, op); } +void vcmpngt_uqps(const Xmm& x, const Operand& op) { vcmpngt_uqps(x, x, op); } +void vcmpngt_uqsd(const Xmm& x, const Operand& op) { vcmpngt_uqsd(x, x, op); } +void vcmpngt_uqss(const Xmm& x, const Operand& op) { vcmpngt_uqss(x, x, op); } +void vcmpngtpd(const Xmm& x, const Operand& op) { vcmpngtpd(x, x, op); } +void vcmpngtps(const Xmm& x, const Operand& op) { vcmpngtps(x, x, op); } +void vcmpngtsd(const Xmm& x, const Operand& op) { vcmpngtsd(x, x, op); } +void vcmpngtss(const Xmm& x, const Operand& op) { vcmpngtss(x, x, op); } +void vcmpnle_uqpd(const Xmm& x, const Operand& op) { vcmpnle_uqpd(x, x, op); } +void vcmpnle_uqps(const Xmm& x, const Operand& op) { vcmpnle_uqps(x, x, op); } +void vcmpnle_uqsd(const Xmm& x, const Operand& op) { vcmpnle_uqsd(x, x, op); } +void vcmpnle_uqss(const Xmm& x, const Operand& op) { vcmpnle_uqss(x, x, op); } +void vcmpnlepd(const Xmm& x, const Operand& op) { vcmpnlepd(x, x, op); } +void vcmpnleps(const Xmm& x, const Operand& op) { vcmpnleps(x, x, op); } +void vcmpnlesd(const Xmm& x, const Operand& op) { vcmpnlesd(x, x, op); } +void vcmpnless(const Xmm& x, const Operand& op) { vcmpnless(x, x, op); } +void vcmpnlt_uqpd(const Xmm& x, const Operand& op) { vcmpnlt_uqpd(x, x, op); } +void vcmpnlt_uqps(const Xmm& x, const Operand& op) { vcmpnlt_uqps(x, x, op); } +void vcmpnlt_uqsd(const Xmm& x, const Operand& op) { vcmpnlt_uqsd(x, x, op); } +void vcmpnlt_uqss(const Xmm& x, const Operand& op) { vcmpnlt_uqss(x, x, op); } +void vcmpnltpd(const Xmm& x, const Operand& op) { vcmpnltpd(x, x, op); } +void vcmpnltps(const Xmm& x, const Operand& op) { vcmpnltps(x, x, op); } +void vcmpnltsd(const Xmm& x, const Operand& op) { vcmpnltsd(x, x, op); } +void vcmpnltss(const Xmm& x, const Operand& op) { vcmpnltss(x, x, op); } +void vcmpord_spd(const Xmm& x, const Operand& op) { vcmpord_spd(x, x, op); } +void vcmpord_sps(const Xmm& x, const Operand& op) { vcmpord_sps(x, x, op); } +void vcmpord_ssd(const Xmm& x, const Operand& op) { vcmpord_ssd(x, x, op); } +void vcmpord_sss(const Xmm& x, const Operand& op) { vcmpord_sss(x, x, op); } +void vcmpordpd(const Xmm& x, const Operand& op) { vcmpordpd(x, x, op); } +void vcmpordps(const Xmm& x, const Operand& op) { vcmpordps(x, x, op); } +void vcmpordsd(const Xmm& x, const Operand& op) { vcmpordsd(x, x, op); } +void vcmpordss(const Xmm& x, const Operand& op) { vcmpordss(x, x, op); } +void vcmppd(const Xmm& x, const Operand& op, uint8 imm) { vcmppd(x, x, op, imm); } +void vcmpps(const Xmm& x, const Operand& op, uint8 imm) { vcmpps(x, x, op, imm); } +void vcmpsd(const Xmm& x, const Operand& op, uint8 imm) { vcmpsd(x, x, op, imm); } +void vcmpss(const Xmm& x, const Operand& op, uint8 imm) { vcmpss(x, x, op, imm); } +void vcmptrue_uspd(const Xmm& x, const Operand& op) { vcmptrue_uspd(x, x, op); } +void vcmptrue_usps(const Xmm& x, const Operand& op) { vcmptrue_usps(x, x, op); } +void vcmptrue_ussd(const Xmm& x, const Operand& op) { vcmptrue_ussd(x, x, op); } +void vcmptrue_usss(const Xmm& x, const Operand& op) { vcmptrue_usss(x, x, op); } +void vcmptruepd(const Xmm& x, const Operand& op) { vcmptruepd(x, x, op); } +void vcmptrueps(const Xmm& x, const Operand& op) { vcmptrueps(x, x, op); } +void vcmptruesd(const Xmm& x, const Operand& op) { vcmptruesd(x, x, op); } +void vcmptruess(const Xmm& x, const Operand& op) { vcmptruess(x, x, op); } +void vcmpunord_spd(const Xmm& x, const Operand& op) { vcmpunord_spd(x, x, op); } +void vcmpunord_sps(const Xmm& x, const Operand& op) { vcmpunord_sps(x, x, op); } +void vcmpunord_ssd(const Xmm& x, const Operand& op) { vcmpunord_ssd(x, x, op); } +void vcmpunord_sss(const Xmm& x, const Operand& op) { vcmpunord_sss(x, x, op); } +void vcmpunordpd(const Xmm& x, const Operand& op) { vcmpunordpd(x, x, op); } +void vcmpunordps(const Xmm& x, const Operand& op) { vcmpunordps(x, x, op); } +void vcmpunordsd(const Xmm& x, const Operand& op) { vcmpunordsd(x, x, op); } +void vcmpunordss(const Xmm& x, const Operand& op) { vcmpunordss(x, x, op); } +void vcvtsd2ss(const Xmm& x, const Operand& op) { vcvtsd2ss(x, x, op); } +void vcvtsi2sd(const Xmm& x, const Operand& op) { vcvtsi2sd(x, x, op); } +void vcvtsi2ss(const Xmm& x, const Operand& op) { vcvtsi2ss(x, x, op); } +void vcvtss2sd(const Xmm& x, const Operand& op) { vcvtss2sd(x, x, op); } +void vdppd(const Xmm& x, const Operand& op, uint8 imm) { vdppd(x, x, op, imm); } +void vdpps(const Xmm& x, const Operand& op, uint8 imm) { vdpps(x, x, op, imm); } +void vinsertps(const Xmm& x, const Operand& op, uint8 imm) { vinsertps(x, x, op, imm); } +void vmpsadbw(const Xmm& x, const Operand& op, uint8 imm) { vmpsadbw(x, x, op, imm); } +void vpackssdw(const Xmm& x, const Operand& op) { vpackssdw(x, x, op); } +void vpacksswb(const Xmm& x, const Operand& op) { vpacksswb(x, x, op); } +void vpackusdw(const Xmm& x, const Operand& op) { vpackusdw(x, x, op); } +void vpackuswb(const Xmm& x, const Operand& op) { vpackuswb(x, x, op); } +void vpaddb(const Xmm& x, const Operand& op) { vpaddb(x, x, op); } +void vpaddd(const Xmm& x, const Operand& op) { vpaddd(x, x, op); } +void vpaddq(const Xmm& x, const Operand& op) { vpaddq(x, x, op); } +void vpaddsb(const Xmm& x, const Operand& op) { vpaddsb(x, x, op); } +void vpaddsw(const Xmm& x, const Operand& op) { vpaddsw(x, x, op); } +void vpaddusb(const Xmm& x, const Operand& op) { vpaddusb(x, x, op); } +void vpaddusw(const Xmm& x, const Operand& op) { vpaddusw(x, x, op); } +void vpaddw(const Xmm& x, const Operand& op) { vpaddw(x, x, op); } +void vpalignr(const Xmm& x, const Operand& op, uint8 imm) { vpalignr(x, x, op, imm); } +void vpand(const Xmm& x, const Operand& op) { vpand(x, x, op); } +void vpandn(const Xmm& x, const Operand& op) { vpandn(x, x, op); } +void vpavgb(const Xmm& x, const Operand& op) { vpavgb(x, x, op); } +void vpavgw(const Xmm& x, const Operand& op) { vpavgw(x, x, op); } +void vpblendd(const Xmm& x, const Operand& op, uint8 imm) { vpblendd(x, x, op, imm); } +void vpblendvb(const Xmm& x1, const Operand& op, const Xmm& x4) { vpblendvb(x1, x1, op, x4); } +void vpblendw(const Xmm& x, const Operand& op, uint8 imm) { vpblendw(x, x, op, imm); } +void vpclmulqdq(const Xmm& x, const Operand& op, uint8 imm) { vpclmulqdq(x, x, op, imm); } +void vpcmpeqb(const Xmm& x, const Operand& op) { vpcmpeqb(x, x, op); } +void vpcmpeqd(const Xmm& x, const Operand& op) { vpcmpeqd(x, x, op); } +void vpcmpeqq(const Xmm& x, const Operand& op) { vpcmpeqq(x, x, op); } +void vpcmpeqw(const Xmm& x, const Operand& op) { vpcmpeqw(x, x, op); } +void vpcmpgtb(const Xmm& x, const Operand& op) { vpcmpgtb(x, x, op); } +void vpcmpgtd(const Xmm& x, const Operand& op) { vpcmpgtd(x, x, op); } +void vpcmpgtq(const Xmm& x, const Operand& op) { vpcmpgtq(x, x, op); } +void vpcmpgtw(const Xmm& x, const Operand& op) { vpcmpgtw(x, x, op); } +void vphaddd(const Xmm& x, const Operand& op) { vphaddd(x, x, op); } +void vphaddsw(const Xmm& x, const Operand& op) { vphaddsw(x, x, op); } +void vphaddw(const Xmm& x, const Operand& op) { vphaddw(x, x, op); } +void vphsubd(const Xmm& x, const Operand& op) { vphsubd(x, x, op); } +void vphsubsw(const Xmm& x, const Operand& op) { vphsubsw(x, x, op); } +void vphsubw(const Xmm& x, const Operand& op) { vphsubw(x, x, op); } +void vpinsrb(const Xmm& x, const Operand& op, uint8 imm) { vpinsrb(x, x, op, imm); } +void vpinsrd(const Xmm& x, const Operand& op, uint8 imm) { vpinsrd(x, x, op, imm); } +void vpinsrq(const Xmm& x, const Operand& op, uint8 imm) { vpinsrq(x, x, op, imm); } +void vpinsrw(const Xmm& x, const Operand& op, uint8 imm) { vpinsrw(x, x, op, imm); } +void vpmaddubsw(const Xmm& x, const Operand& op) { vpmaddubsw(x, x, op); } +void vpmaddwd(const Xmm& x, const Operand& op) { vpmaddwd(x, x, op); } +void vpmaxsb(const Xmm& x, const Operand& op) { vpmaxsb(x, x, op); } +void vpmaxsd(const Xmm& x, const Operand& op) { vpmaxsd(x, x, op); } +void vpmaxsw(const Xmm& x, const Operand& op) { vpmaxsw(x, x, op); } +void vpmaxub(const Xmm& x, const Operand& op) { vpmaxub(x, x, op); } +void vpmaxud(const Xmm& x, const Operand& op) { vpmaxud(x, x, op); } +void vpmaxuw(const Xmm& x, const Operand& op) { vpmaxuw(x, x, op); } +void vpminsb(const Xmm& x, const Operand& op) { vpminsb(x, x, op); } +void vpminsd(const Xmm& x, const Operand& op) { vpminsd(x, x, op); } +void vpminsw(const Xmm& x, const Operand& op) { vpminsw(x, x, op); } +void vpminub(const Xmm& x, const Operand& op) { vpminub(x, x, op); } +void vpminud(const Xmm& x, const Operand& op) { vpminud(x, x, op); } +void vpminuw(const Xmm& x, const Operand& op) { vpminuw(x, x, op); } +void vpmuldq(const Xmm& x, const Operand& op) { vpmuldq(x, x, op); } +void vpmulhrsw(const Xmm& x, const Operand& op) { vpmulhrsw(x, x, op); } +void vpmulhuw(const Xmm& x, const Operand& op) { vpmulhuw(x, x, op); } +void vpmulhw(const Xmm& x, const Operand& op) { vpmulhw(x, x, op); } +void vpmulld(const Xmm& x, const Operand& op) { vpmulld(x, x, op); } +void vpmullw(const Xmm& x, const Operand& op) { vpmullw(x, x, op); } +void vpmuludq(const Xmm& x, const Operand& op) { vpmuludq(x, x, op); } +void vpor(const Xmm& x, const Operand& op) { vpor(x, x, op); } +void vpsadbw(const Xmm& x, const Operand& op) { vpsadbw(x, x, op); } +void vpsignb(const Xmm& x, const Operand& op) { vpsignb(x, x, op); } +void vpsignd(const Xmm& x, const Operand& op) { vpsignd(x, x, op); } +void vpsignw(const Xmm& x, const Operand& op) { vpsignw(x, x, op); } +void vpslld(const Xmm& x, const Operand& op) { vpslld(x, x, op); } +void vpslld(const Xmm& x, uint8 imm) { vpslld(x, x, imm); } +void vpslldq(const Xmm& x, uint8 imm) { vpslldq(x, x, imm); } +void vpsllq(const Xmm& x, const Operand& op) { vpsllq(x, x, op); } +void vpsllq(const Xmm& x, uint8 imm) { vpsllq(x, x, imm); } +void vpsllw(const Xmm& x, const Operand& op) { vpsllw(x, x, op); } +void vpsllw(const Xmm& x, uint8 imm) { vpsllw(x, x, imm); } +void vpsrad(const Xmm& x, const Operand& op) { vpsrad(x, x, op); } +void vpsrad(const Xmm& x, uint8 imm) { vpsrad(x, x, imm); } +void vpsraw(const Xmm& x, const Operand& op) { vpsraw(x, x, op); } +void vpsraw(const Xmm& x, uint8 imm) { vpsraw(x, x, imm); } +void vpsrld(const Xmm& x, const Operand& op) { vpsrld(x, x, op); } +void vpsrld(const Xmm& x, uint8 imm) { vpsrld(x, x, imm); } +void vpsrldq(const Xmm& x, uint8 imm) { vpsrldq(x, x, imm); } +void vpsrlq(const Xmm& x, const Operand& op) { vpsrlq(x, x, op); } +void vpsrlq(const Xmm& x, uint8 imm) { vpsrlq(x, x, imm); } +void vpsrlw(const Xmm& x, const Operand& op) { vpsrlw(x, x, op); } +void vpsrlw(const Xmm& x, uint8 imm) { vpsrlw(x, x, imm); } +void vpsubb(const Xmm& x, const Operand& op) { vpsubb(x, x, op); } +void vpsubd(const Xmm& x, const Operand& op) { vpsubd(x, x, op); } +void vpsubq(const Xmm& x, const Operand& op) { vpsubq(x, x, op); } +void vpsubsb(const Xmm& x, const Operand& op) { vpsubsb(x, x, op); } +void vpsubsw(const Xmm& x, const Operand& op) { vpsubsw(x, x, op); } +void vpsubusb(const Xmm& x, const Operand& op) { vpsubusb(x, x, op); } +void vpsubusw(const Xmm& x, const Operand& op) { vpsubusw(x, x, op); } +void vpsubw(const Xmm& x, const Operand& op) { vpsubw(x, x, op); } +void vpunpckhbw(const Xmm& x, const Operand& op) { vpunpckhbw(x, x, op); } +void vpunpckhdq(const Xmm& x, const Operand& op) { vpunpckhdq(x, x, op); } +void vpunpckhqdq(const Xmm& x, const Operand& op) { vpunpckhqdq(x, x, op); } +void vpunpckhwd(const Xmm& x, const Operand& op) { vpunpckhwd(x, x, op); } +void vpunpcklbw(const Xmm& x, const Operand& op) { vpunpcklbw(x, x, op); } +void vpunpckldq(const Xmm& x, const Operand& op) { vpunpckldq(x, x, op); } +void vpunpcklqdq(const Xmm& x, const Operand& op) { vpunpcklqdq(x, x, op); } +void vpunpcklwd(const Xmm& x, const Operand& op) { vpunpcklwd(x, x, op); } +void vpxor(const Xmm& x, const Operand& op) { vpxor(x, x, op); } +void vrcpss(const Xmm& x, const Operand& op) { vrcpss(x, x, op); } +void vroundsd(const Xmm& x, const Operand& op, uint8 imm) { vroundsd(x, x, op, imm); } +void vroundss(const Xmm& x, const Operand& op, uint8 imm) { vroundss(x, x, op, imm); } +void vrsqrtss(const Xmm& x, const Operand& op) { vrsqrtss(x, x, op); } +void vshufpd(const Xmm& x, const Operand& op, uint8 imm) { vshufpd(x, x, op, imm); } +void vshufps(const Xmm& x, const Operand& op, uint8 imm) { vshufps(x, x, op, imm); } +void vsqrtsd(const Xmm& x, const Operand& op) { vsqrtsd(x, x, op); } +void vsqrtss(const Xmm& x, const Operand& op) { vsqrtss(x, x, op); } +void vunpckhpd(const Xmm& x, const Operand& op) { vunpckhpd(x, x, op); } +void vunpckhps(const Xmm& x, const Operand& op) { vunpckhps(x, x, op); } +void vunpcklpd(const Xmm& x, const Operand& op) { vunpcklpd(x, x, op); } +void vunpcklps(const Xmm& x, const Operand& op) { vunpcklps(x, x, op); } +#endif +#ifdef XBYAK64 +void jecxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jrcxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jrcxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void cdqe() { db(0x48); db(0x98); } +void cqo() { db(0x48); db(0x99); } +void cmpsq() { db(0x48); db(0xA7); } +void movsq() { db(0x48); db(0xA5); } +void scasq() { db(0x48); db(0xAF); } +void stosq() { db(0x48); db(0xAB); } +void cmpxchg16b(const Address& addr) { opModM(addr, Reg64(1), 0x0F, 0xC7); } +void movq(const Reg64& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); } +void movq(const Mmx& mmx, const Reg64& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); } +void movsxd(const Reg64& reg, const Operand& op) { if (!op.isBit(32)) throw Error(ERR_BAD_COMBINATION); opModRM(reg, op, op.isREG(), op.isMEM(), 0x63); } +void pextrq(const Operand& op, const Xmm& xmm, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x16, 0x66, 0, imm, 0x3A); } +void pinsrq(const Xmm& xmm, const Operand& op, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x22, 0x66, 0, imm, 0x3A); } +void vcvtss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_ER_X | T_N8, 0x2D); } +void vcvttss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_SAE_X | T_N8, 0x2C); } +void vcvtsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_ER_X, 0x2D); } +void vcvttsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_SAE_X, 0x2C); } +void vmovq(const Xmm& x, const Reg64& r) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x6E); } +void vmovq(const Reg64& r, const Xmm& x) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x7E); } +#else +void jcxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jcxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void aaa() { db(0x37); } +void aad() { db(0xD5); db(0x0A); } +void aam() { db(0xD4); db(0x0A); } +void aas() { db(0x3F); } +void daa() { db(0x27); } +void das() { db(0x2F); } +void popad() { db(0x61); } +void popfd() { db(0x9D); } +void pusha() { db(0x60); } +void pushad() { db(0x60); } +void pushfd() { db(0x9C); } +void popa() { db(0x61); } +#endif +#ifndef XBYAK_NO_OP_NAMES +void and(const Operand& op1, const Operand& op2) { and_(op1, op2); } +void and(const Operand& op, uint32 imm) { and_(op, imm); } +void or(const Operand& op1, const Operand& op2) { or_(op1, op2); } +void or(const Operand& op, uint32 imm) { or_(op, imm); } +void xor(const Operand& op1, const Operand& op2) { xor_(op1, op2); } +void xor(const Operand& op, uint32 imm) { xor_(op, imm); } +void not(const Operand& op) { not_(op); } +#endif +#ifndef XBYAK_DISABLE_AVX512 +void kaddb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4A); } +void kaddd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x4A); } +void kaddq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4A); } +void kaddw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4A); } +void kandb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x41); } +void kandd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x41); } +void kandnb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x42); } +void kandnd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x42); } +void kandnq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x42); } +void kandnw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x42); } +void kandq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x41); } +void kandw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x41); } +void kmovb(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W0, 0x91); } +void kmovb(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W0, 0x90); } +void kmovb(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_66 | T_W0, 0x92); } +void kmovb(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_66 | T_W0, 0x93); } +void kmovd(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W1, 0x91); } +void kmovd(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W1, 0x90); } +void kmovd(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W0, 0x92); } +void kmovd(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W0, 0x93); } +void kmovq(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W1, 0x91); } +void kmovq(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W1, 0x90); } +void kmovw(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W0, 0x91); } +void kmovw(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W0, 0x90); } +void kmovw(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_W0, 0x92); } +void kmovw(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_W0, 0x93); } +void knotb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x44); } +void knotd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x44); } +void knotq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x44); } +void knotw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x44); } +void korb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x45); } +void kord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x45); } +void korq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x45); } +void kortestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x98); } +void kortestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x98); } +void kortestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x98); } +void kortestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x98); } +void korw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x45); } +void kshiftlb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x32, imm); } +void kshiftld(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x33, imm); } +void kshiftlq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x33, imm); } +void kshiftlw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x32, imm); } +void kshiftrb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x30, imm); } +void kshiftrd(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x31, imm); } +void kshiftrq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x31, imm); } +void kshiftrw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x30, imm); } +void ktestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x99); } +void ktestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x99); } +void ktestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x99); } +void ktestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x99); } +void kunpckbw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4B); } +void kunpckdq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4B); } +void kunpckwd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4B); } +void kxnorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x46); } +void kxnord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x46); } +void kxnorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x46); } +void kxnorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x46); } +void kxorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x47); } +void kxord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x47); } +void kxorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x47); } +void kxorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x47); } +void v4fmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x9A); } +void v4fmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0x9B); } +void v4fnmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0xAA); } +void v4fnmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0xAB); } +void valignd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x03, imm); } +void valignq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x03, imm); } +void vblendmpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x65); } +void vblendmps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x65); } +void vbroadcastf32x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x19); } +void vbroadcastf32x4(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x1A); } +void vbroadcastf32x8(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x1B); } +void vbroadcastf64x2(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x1A); } +void vbroadcastf64x4(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x1B); } +void vbroadcasti32x2(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x59); } +void vbroadcasti32x4(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x5A); } +void vbroadcasti32x8(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x5B); } +void vbroadcasti64x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x5A); } +void vbroadcasti64x4(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x5B); } +void vcmppd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpps(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpsd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N8 | T_F2 | T_0F | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpss(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N4 | T_F3 | T_0F | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcompressb(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x63); } +void vcompresspd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8A); } +void vcompressps(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8A); } +void vcompressw(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x63); } +void vcvtpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7B); } +void vcvtpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x79); } +void vcvtpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x79); } +void vcvtps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x7B); } +void vcvtps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x79); } +void vcvtps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x79); } +void vcvtqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0xE6); } +void vcvtqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5B); } +void vcvtsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); } +void vcvtss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); } +void vcvttpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x7A); } +void vcvttpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_SAE_Z, 0x78); } +void vcvttpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x78); } +void vcvttps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x7A); } +void vcvttps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x78); } +void vcvttps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x78); } +void vcvttsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); } +void vcvttss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); } +void vcvtudq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_F3 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0x7A); } +void vcvtudq2ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x7A); } +void vcvtuqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7A); } +void vcvtuqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_F2 | T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x7A); } +void vcvtusi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F2 | T_0F | T_MUST_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } +void vcvtusi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F3 | T_0F | T_MUST_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } +void vdbpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x42, imm); } +void vexp2pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xC8); } +void vexp2ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xC8); } +void vexpandpd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x88); } +void vexpandps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x88); } +void vextractf32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x19, imm); } +void vextractf32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1B, imm); } +void vextractf64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x19, imm); } +void vextractf64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1B, imm); } +void vextracti32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x39, imm); } +void vextracti32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3B, imm); } +void vextracti64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x39, imm); } +void vextracti64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3B, imm); } +void vfixupimmpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x54, imm); } +void vfixupimmps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x54, imm); } +void vfixupimmsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); } +void vfixupimmss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); } +void vfpclasspd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW1 | T_B64, 0x66, imm); } +void vfpclassps(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B32, 0x66, imm); } +void vfpclasssd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW1 | T_N8, 0x67, imm); } +void vfpclassss(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW0 | T_N4, 0x67, imm); } +void vgatherdpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 1); } +void vgatherdps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 0); } +void vgatherpf0dpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vgatherpf0dps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vgatherpf0qpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf0qps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf1dpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vgatherpf1dps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vgatherpf1qpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf1qps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherqpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 0); } +void vgatherqps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 2); } +void vgetexppd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x42); } +void vgetexpps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x42); } +void vgetexpsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x43); } +void vgetexpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x43); } +void vgetmantpd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x26, imm); } +void vgetmantps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x26, imm); } +void vgetmantsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x27, imm); } +void vgetmantss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x27, imm); } +void vinsertf32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x18, imm); } +void vinsertf32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1A, imm); } +void vinsertf64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x18, imm); } +void vinsertf64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1A, imm); } +void vinserti32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x38, imm); } +void vinserti32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3A, imm); } +void vinserti64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x38, imm); } +void vinserti64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3A, imm); } +void vmovdqa32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqa32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqa64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqa64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu16(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu16(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu8(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu8(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vp4dpwssd(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x52); } +void vp4dpwssds(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x53); } +void vpabsq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_MUST_EVEX | T_EW1 | T_B64 | T_YMM, 0x1F); } +void vpandd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDB); } +void vpandnd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDF); } +void vpandnq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDF); } +void vpandq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDB); } +void vpblendmb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x66); } +void vpblendmd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x64); } +void vpblendmq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x64); } +void vpblendmw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x66); } +void vpbroadcastb(const Xmm& x, const Reg8& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7A); } +void vpbroadcastd(const Xmm& x, const Reg32& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7C); } +void vpbroadcastmb2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1, 0x2A); } +void vpbroadcastmw2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0, 0x3A); } +void vpbroadcastw(const Xmm& x, const Reg16& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7B); } +void vpcmpb(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3F, imm); } +void vpcmpd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1F, imm); } +void vpcmpeqb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x74); } +void vpcmpeqd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_B32, 0x76); } +void vpcmpeqq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x29); } +void vpcmpeqw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x75); } +void vpcmpgtb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x64); } +void vpcmpgtd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x66); } +void vpcmpgtq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x37); } +void vpcmpgtw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x65); } +void vpcmpq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1F, imm); } +void vpcmpub(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3E, imm); } +void vpcmpud(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1E, imm); } +void vpcmpuq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1E, imm); } +void vpcmpuw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3E, imm); } +void vpcmpw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3F, imm); } +void vpcompressd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8B); } +void vpcompressq(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8B); } +void vpconflictd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xC4); } +void vpconflictq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xC4); } +void vpdpbusd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50); } +void vpdpbusds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x51); } +void vpdpwssd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x52); } +void vpdpwssds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x53); } +void vpermb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8D); } +void vpermi2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x75); } +void vpermi2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x76); } +void vpermi2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x77); } +void vpermi2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x77); } +void vpermi2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x76); } +void vpermi2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x75); } +void vpermt2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7D); } +void vpermt2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7E); } +void vpermt2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7F); } +void vpermt2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7F); } +void vpermt2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7E); } +void vpermt2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7D); } +void vpermw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8D); } +void vpexpandb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); } +void vpexpandd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x89); } +void vpexpandq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x89); } +void vpexpandw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); } +void vpgatherdd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 0); } +void vpgatherdq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 1); } +void vpgatherqd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 2); } +void vpgatherqq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 0); } +void vplzcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x44); } +void vplzcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x44); } +void vpmadd52huq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB5); } +void vpmadd52luq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB4); } +void vpmaxsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3D); } +void vpmaxuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3F); } +void vpminsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x39); } +void vpminuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3B); } +void vpmovb2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x29); } +void vpmovd2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x39); } +void vpmovdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x31, false); } +void vpmovdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x33, true); } +void vpmovm2b(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x28); } +void vpmovm2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x38); } +void vpmovm2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x38); } +void vpmovm2w(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x28); } +void vpmovq2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x39); } +void vpmovqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x32, false); } +void vpmovqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x35, true); } +void vpmovqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x34, false); } +void vpmovsdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x21, false); } +void vpmovsdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x23, true); } +void vpmovsqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x22, false); } +void vpmovsqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x25, true); } +void vpmovsqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x24, false); } +void vpmovswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x20, true); } +void vpmovusdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x11, false); } +void vpmovusdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x13, true); } +void vpmovusqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x12, false); } +void vpmovusqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x15, true); } +void vpmovusqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x14, false); } +void vpmovuswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x10, true); } +void vpmovw2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x29); } +void vpmovwb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x30, true); } +void vpmullq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x40); } +void vpmultishiftqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x83); } +void vpopcntb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); } +void vpopcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x55); } +void vpopcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x55); } +void vpopcntw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); } +void vpord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEB); } +void vporq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEB); } +void vprold(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); } +void vprolq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vprolvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x15); } +void vprolvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x15); } +void vprord(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); } +void vprorq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vprorvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x14); } +void vprorvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x14); } +void vpscatterdd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 0); } +void vpscatterdq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 1); } +void vpscatterqd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 2); } +void vpscatterqq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 0); } +void vpshldd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71, imm); } +void vpshldq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71, imm); } +void vpshldvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71); } +void vpshldvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71); } +void vpshldvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70); } +void vpshldw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70, imm); } +void vpshrdd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73, imm); } +void vpshrdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73, imm); } +void vpshrdvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73); } +void vpshrdvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73); } +void vpshrdvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72); } +void vpshrdw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72, imm); } +void vpshufbitqmb(const Opmask& k, const Xmm& x, const Operand& op) { opVex(k, &x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8F); } +void vpsllvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x12); } +void vpsraq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vpsraq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX, 0xE2); } +void vpsravq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x46); } +void vpsravw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x11); } +void vpsrlvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x10); } +void vpternlogd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x25, imm); } +void vpternlogq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x25, imm); } +void vptestmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); } +void vptestmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); } +void vptestmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestnmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestnmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); } +void vptestnmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); } +void vptestnmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); } +void vpxord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEF); } +void vpxorq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEF); } +void vrangepd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x50, imm); } +void vrangeps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50, imm); } +void vrangesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x51, imm); } +void vrangess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x51, imm); } +void vrcp14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4C); } +void vrcp14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4C); } +void vrcp14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX, 0x4D); } +void vrcp14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX, 0x4D); } +void vrcp28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCA); } +void vrcp28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCA); } +void vrcp28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCB); } +void vrcp28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCB); } +void vreducepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x56, imm); } +void vreduceps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x56, imm); } +void vreducesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x57, imm); } +void vreducess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x57, imm); } +void vrndscalepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x09, imm); } +void vrndscaleps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x08, imm); } +void vrndscalesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_MUST_EVEX, 0x0B, imm); } +void vrndscaless(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_MUST_EVEX, 0x0A, imm); } +void vrsqrt14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4E); } +void vrsqrt14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4E); } +void vrsqrt14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x4F); } +void vrsqrt14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x4F); } +void vrsqrt28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCC); } +void vrsqrt28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCC); } +void vrsqrt28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCD); } +void vrsqrt28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCD); } +void vscalefpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x2C); } +void vscalefps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x2C); } +void vscalefsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_ER_X | T_MUST_EVEX, 0x2D); } +void vscalefss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_ER_X | T_MUST_EVEX, 0x2D); } +void vscatterdpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 1); } +void vscatterdps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 0); } +void vscatterpf0dpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vscatterpf0dps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vscatterpf0qpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf0qps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf1dpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vscatterpf1dps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vscatterpf1qpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf1qps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterqpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 0); } +void vscatterqps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 2); } +void vshuff32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x23, imm); } +void vshuff64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x23, imm); } +void vshufi32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x43, imm); } +void vshufi64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x43, imm); } +#ifdef XBYAK64 +void kmovq(const Opmask& k, const Reg64& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W1, 0x92); } +void kmovq(const Reg64& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W1, 0x93); } +void vpbroadcastq(const Xmm& x, const Reg64& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7C); } +#endif +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h new file mode 100644 index 0000000000..8ef076e680 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h @@ -0,0 +1,772 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +#ifndef XBYAK_XBYAK_UTIL_H_ +#define XBYAK_XBYAK_UTIL_H_ + +/** + utility class and functions for Xbyak + Xbyak::util::Clock ; rdtsc timer + Xbyak::util::Cpu ; detect CPU + @note this header is UNDER CONSTRUCTION! +*/ +#include "xbyak.h" + +#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64) + #define XBYAK_INTEL_CPU_SPECIFIC +#endif + +#ifdef XBYAK_INTEL_CPU_SPECIFIC +#ifdef _MSC_VER + #if (_MSC_VER < 1400) && defined(XBYAK32) + static inline __declspec(naked) void __cpuid(int[4], int) + { + __asm { + push ebx + push esi + mov eax, dword ptr [esp + 4 * 2 + 8] // eaxIn + cpuid + mov esi, dword ptr [esp + 4 * 2 + 4] // data + mov dword ptr [esi], eax + mov dword ptr [esi + 4], ebx + mov dword ptr [esi + 8], ecx + mov dword ptr [esi + 12], edx + pop esi + pop ebx + ret + } + } + #else + #include // for __cpuid + #endif +#else + #ifndef __GNUC_PREREQ + #define __GNUC_PREREQ(major, minor) ((((__GNUC__) << 16) + (__GNUC_MINOR__)) >= (((major) << 16) + (minor))) + #endif + #if __GNUC_PREREQ(4, 3) && !defined(__APPLE__) + #include + #else + #if defined(__APPLE__) && defined(XBYAK32) // avoid err : can't find a register in class `BREG' while reloading `asm' + #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn)) + #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn)) + #else + #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn)) + #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn)) + #endif + #endif +#endif +#endif + +namespace Xbyak { namespace util { + +typedef enum { + SmtLevel = 1, + CoreLevel = 2 +} IntelCpuTopologyLevel; + +/** + CPU detection class +*/ +class Cpu { + uint64 type_; + //system topology + bool x2APIC_supported_; + static const size_t maxTopologyLevels = 2; + unsigned int numCores_[maxTopologyLevels]; + + static const unsigned int maxNumberCacheLevels = 10; + unsigned int dataCacheSize_[maxNumberCacheLevels]; + unsigned int coresSharignDataCache_[maxNumberCacheLevels]; + unsigned int dataCacheLevels_; + + unsigned int get32bitAsBE(const char *x) const + { + return x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); + } + unsigned int mask(int n) const + { + return (1U << n) - 1; + } + void setFamily() + { + unsigned int data[4] = {}; + getCpuid(1, data); + stepping = data[0] & mask(4); + model = (data[0] >> 4) & mask(4); + family = (data[0] >> 8) & mask(4); + // type = (data[0] >> 12) & mask(2); + extModel = (data[0] >> 16) & mask(4); + extFamily = (data[0] >> 20) & mask(8); + if (family == 0x0f) { + displayFamily = family + extFamily; + } else { + displayFamily = family; + } + if (family == 6 || family == 0x0f) { + displayModel = (extModel << 4) + model; + } else { + displayModel = model; + } + } + unsigned int extractBit(unsigned int val, unsigned int base, unsigned int end) + { + return (val >> base) & ((1u << (end - base)) - 1); + } + void setNumCores() + { + if ((type_ & tINTEL) == 0) return; + + unsigned int data[4] = {}; + + /* CAUTION: These numbers are configuration as shipped by Intel. */ + getCpuidEx(0x0, 0, data); + if (data[0] >= 0xB) { + /* + if leaf 11 exists(x2APIC is supported), + we use it to get the number of smt cores and cores on socket + + leaf 0xB can be zeroed-out by a hypervisor + */ + x2APIC_supported_ = true; + for (unsigned int i = 0; i < maxTopologyLevels; i++) { + getCpuidEx(0xB, i, data); + IntelCpuTopologyLevel level = (IntelCpuTopologyLevel)extractBit(data[2], 8, 15); + if (level == SmtLevel || level == CoreLevel) { + numCores_[level - 1] = extractBit(data[1], 0, 15); + } + } + } else { + /* + Failed to deremine num of cores without x2APIC support. + TODO: USE initial APIC ID to determine ncores. + */ + numCores_[SmtLevel - 1] = 0; + numCores_[CoreLevel - 1] = 0; + } + + } + void setCacheHierarchy() + { + if ((type_ & tINTEL) == 0) return; + const unsigned int NO_CACHE = 0; + const unsigned int DATA_CACHE = 1; +// const unsigned int INSTRUCTION_CACHE = 2; + const unsigned int UNIFIED_CACHE = 3; + unsigned int smt_width = 0; + unsigned int logical_cores = 0; + unsigned int data[4] = {}; + + if (x2APIC_supported_) { + smt_width = numCores_[0]; + logical_cores = numCores_[1]; + } + + /* + Assumptions: + the first level of data cache is not shared (which is the + case for every existing architecture) and use this to + determine the SMT width for arch not supporting leaf 11. + when leaf 4 reports a number of core less than numCores_ + on socket reported by leaf 11, then it is a correct number + of cores not an upperbound. + */ + for (int i = 0; dataCacheLevels_ < maxNumberCacheLevels; i++) { + getCpuidEx(0x4, i, data); + unsigned int cacheType = extractBit(data[0], 0, 4); + if (cacheType == NO_CACHE) break; + if (cacheType == DATA_CACHE || cacheType == UNIFIED_CACHE) { + unsigned int actual_logical_cores = extractBit(data[0], 14, 25) + 1; + if (logical_cores != 0) { // true only if leaf 0xB is supported and valid + actual_logical_cores = (std::min)(actual_logical_cores, logical_cores); + } + assert(actual_logical_cores != 0); + dataCacheSize_[dataCacheLevels_] = + (extractBit(data[1], 22, 31) + 1) + * (extractBit(data[1], 12, 21) + 1) + * (extractBit(data[1], 0, 11) + 1) + * (data[2] + 1); + if (cacheType == DATA_CACHE && smt_width == 0) smt_width = actual_logical_cores; + assert(smt_width != 0); + // FIXME: check and fix number of cores sharing L3 cache for different configurations + // (HT-, 2 sockets), (HT-, 1 socket), (HT+, 2 sockets), (HT+, 1 socket) + coresSharignDataCache_[dataCacheLevels_] = (std::max)(actual_logical_cores / smt_width, 1u); + dataCacheLevels_++; + } + } + } + +public: + int model; + int family; + int stepping; + int extModel; + int extFamily; + int displayFamily; // family + extFamily + int displayModel; // model + extModel + + unsigned int getNumCores(IntelCpuTopologyLevel level) { + if (level != SmtLevel && level != CoreLevel) throw Error(ERR_BAD_PARAMETER); + if (!x2APIC_supported_) throw Error(ERR_X2APIC_IS_NOT_SUPPORTED); + return (level == CoreLevel) + ? numCores_[level - 1] / numCores_[SmtLevel - 1] + : numCores_[level - 1]; + } + + unsigned int getDataCacheLevels() const { return dataCacheLevels_; } + unsigned int getCoresSharingDataCache(unsigned int i) const + { + if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER); + return coresSharignDataCache_[i]; + } + unsigned int getDataCacheSize(unsigned int i) const + { + if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER); + return dataCacheSize_[i]; + } + + /* + data[] = { eax, ebx, ecx, edx } + */ + static inline void getCpuid(unsigned int eaxIn, unsigned int data[4]) + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + __cpuid(reinterpret_cast(data), eaxIn); + #else + __cpuid(eaxIn, data[0], data[1], data[2], data[3]); + #endif +#else + (void)eaxIn; + (void)data; +#endif + } + static inline void getCpuidEx(unsigned int eaxIn, unsigned int ecxIn, unsigned int data[4]) + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + __cpuidex(reinterpret_cast(data), eaxIn, ecxIn); + #else + __cpuid_count(eaxIn, ecxIn, data[0], data[1], data[2], data[3]); + #endif +#else + (void)eaxIn; + (void)ecxIn; + (void)data; +#endif + } + static inline uint64 getXfeature() + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + return _xgetbv(0); + #else + unsigned int eax, edx; + // xgetvb is not support on gcc 4.2 +// __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0)); + __asm__ volatile(".byte 0x0f, 0x01, 0xd0" : "=a"(eax), "=d"(edx) : "c"(0)); + return ((uint64)edx << 32) | eax; + #endif +#else + return 0; +#endif + } + typedef uint64 Type; + + static const Type NONE = 0; + static const Type tMMX = 1 << 0; + static const Type tMMX2 = 1 << 1; + static const Type tCMOV = 1 << 2; + static const Type tSSE = 1 << 3; + static const Type tSSE2 = 1 << 4; + static const Type tSSE3 = 1 << 5; + static const Type tSSSE3 = 1 << 6; + static const Type tSSE41 = 1 << 7; + static const Type tSSE42 = 1 << 8; + static const Type tPOPCNT = 1 << 9; + static const Type tAESNI = 1 << 10; + static const Type tSSE5 = 1 << 11; + static const Type tOSXSAVE = 1 << 12; + static const Type tPCLMULQDQ = 1 << 13; + static const Type tAVX = 1 << 14; + static const Type tFMA = 1 << 15; + + static const Type t3DN = 1 << 16; + static const Type tE3DN = 1 << 17; + static const Type tSSE4a = 1 << 18; + static const Type tRDTSCP = 1 << 19; + static const Type tAVX2 = 1 << 20; + static const Type tBMI1 = 1 << 21; // andn, bextr, blsi, blsmsk, blsr, tzcnt + static const Type tBMI2 = 1 << 22; // bzhi, mulx, pdep, pext, rorx, sarx, shlx, shrx + static const Type tLZCNT = 1 << 23; + + static const Type tINTEL = 1 << 24; + static const Type tAMD = 1 << 25; + + static const Type tENHANCED_REP = 1 << 26; // enhanced rep movsb/stosb + static const Type tRDRAND = 1 << 27; + static const Type tADX = 1 << 28; // adcx, adox + static const Type tRDSEED = 1 << 29; // rdseed + static const Type tSMAP = 1 << 30; // stac + static const Type tHLE = uint64(1) << 31; // xacquire, xrelease, xtest + static const Type tRTM = uint64(1) << 32; // xbegin, xend, xabort + static const Type tF16C = uint64(1) << 33; // vcvtph2ps, vcvtps2ph + static const Type tMOVBE = uint64(1) << 34; // mobve + static const Type tAVX512F = uint64(1) << 35; + static const Type tAVX512DQ = uint64(1) << 36; + static const Type tAVX512_IFMA = uint64(1) << 37; + static const Type tAVX512IFMA = tAVX512_IFMA; + static const Type tAVX512PF = uint64(1) << 38; + static const Type tAVX512ER = uint64(1) << 39; + static const Type tAVX512CD = uint64(1) << 40; + static const Type tAVX512BW = uint64(1) << 41; + static const Type tAVX512VL = uint64(1) << 42; + static const Type tAVX512_VBMI = uint64(1) << 43; + static const Type tAVX512VBMI = tAVX512_VBMI; // changed by Intel's manual + static const Type tAVX512_4VNNIW = uint64(1) << 44; + static const Type tAVX512_4FMAPS = uint64(1) << 45; + static const Type tPREFETCHWT1 = uint64(1) << 46; + static const Type tPREFETCHW = uint64(1) << 47; + static const Type tSHA = uint64(1) << 48; + static const Type tMPX = uint64(1) << 49; + static const Type tAVX512_VBMI2 = uint64(1) << 50; + static const Type tGFNI = uint64(1) << 51; + static const Type tVAES = uint64(1) << 52; + static const Type tVPCLMULQDQ = uint64(1) << 53; + static const Type tAVX512_VNNI = uint64(1) << 54; + static const Type tAVX512_BITALG = uint64(1) << 55; + static const Type tAVX512_VPOPCNTDQ = uint64(1) << 56; + + Cpu() + : type_(NONE) + , x2APIC_supported_(false) + , numCores_() + , dataCacheSize_() + , coresSharignDataCache_() + , dataCacheLevels_(0) + { + unsigned int data[4] = {}; + const unsigned int& EAX = data[0]; + const unsigned int& EBX = data[1]; + const unsigned int& ECX = data[2]; + const unsigned int& EDX = data[3]; + getCpuid(0, data); + const unsigned int maxNum = EAX; + static const char intel[] = "ntel"; + static const char amd[] = "cAMD"; + if (ECX == get32bitAsBE(amd)) { + type_ |= tAMD; + getCpuid(0x80000001, data); + if (EDX & (1U << 31)) type_ |= t3DN; + if (EDX & (1U << 15)) type_ |= tCMOV; + if (EDX & (1U << 30)) type_ |= tE3DN; + if (EDX & (1U << 22)) type_ |= tMMX2; + if (EDX & (1U << 27)) type_ |= tRDTSCP; + } + if (ECX == get32bitAsBE(intel)) { + type_ |= tINTEL; + getCpuid(0x80000001, data); + if (EDX & (1U << 27)) type_ |= tRDTSCP; + if (ECX & (1U << 5)) type_ |= tLZCNT; + if (ECX & (1U << 8)) type_ |= tPREFETCHW; + } + getCpuid(1, data); + if (ECX & (1U << 0)) type_ |= tSSE3; + if (ECX & (1U << 9)) type_ |= tSSSE3; + if (ECX & (1U << 19)) type_ |= tSSE41; + if (ECX & (1U << 20)) type_ |= tSSE42; + if (ECX & (1U << 22)) type_ |= tMOVBE; + if (ECX & (1U << 23)) type_ |= tPOPCNT; + if (ECX & (1U << 25)) type_ |= tAESNI; + if (ECX & (1U << 1)) type_ |= tPCLMULQDQ; + if (ECX & (1U << 27)) type_ |= tOSXSAVE; + if (ECX & (1U << 30)) type_ |= tRDRAND; + if (ECX & (1U << 29)) type_ |= tF16C; + + if (EDX & (1U << 15)) type_ |= tCMOV; + if (EDX & (1U << 23)) type_ |= tMMX; + if (EDX & (1U << 25)) type_ |= tMMX2 | tSSE; + if (EDX & (1U << 26)) type_ |= tSSE2; + + if (type_ & tOSXSAVE) { + // check XFEATURE_ENABLED_MASK[2:1] = '11b' + uint64 bv = getXfeature(); + if ((bv & 6) == 6) { + if (ECX & (1U << 28)) type_ |= tAVX; + if (ECX & (1U << 12)) type_ |= tFMA; + if (((bv >> 5) & 7) == 7) { + getCpuidEx(7, 0, data); + if (EBX & (1U << 16)) type_ |= tAVX512F; + if (type_ & tAVX512F) { + if (EBX & (1U << 17)) type_ |= tAVX512DQ; + if (EBX & (1U << 21)) type_ |= tAVX512_IFMA; + if (EBX & (1U << 26)) type_ |= tAVX512PF; + if (EBX & (1U << 27)) type_ |= tAVX512ER; + if (EBX & (1U << 28)) type_ |= tAVX512CD; + if (EBX & (1U << 30)) type_ |= tAVX512BW; + if (EBX & (1U << 31)) type_ |= tAVX512VL; + if (ECX & (1U << 1)) type_ |= tAVX512_VBMI; + if (ECX & (1U << 6)) type_ |= tAVX512_VBMI2; + if (ECX & (1U << 8)) type_ |= tGFNI; + if (ECX & (1U << 9)) type_ |= tVAES; + if (ECX & (1U << 10)) type_ |= tVPCLMULQDQ; + if (ECX & (1U << 11)) type_ |= tAVX512_VNNI; + if (ECX & (1U << 12)) type_ |= tAVX512_BITALG; + if (ECX & (1U << 14)) type_ |= tAVX512_VPOPCNTDQ; + if (EDX & (1U << 2)) type_ |= tAVX512_4VNNIW; + if (EDX & (1U << 3)) type_ |= tAVX512_4FMAPS; + } + } + } + } + if (maxNum >= 7) { + getCpuidEx(7, 0, data); + if (type_ & tAVX && (EBX & (1U << 5))) type_ |= tAVX2; + if (EBX & (1U << 3)) type_ |= tBMI1; + if (EBX & (1U << 8)) type_ |= tBMI2; + if (EBX & (1U << 9)) type_ |= tENHANCED_REP; + if (EBX & (1U << 18)) type_ |= tRDSEED; + if (EBX & (1U << 19)) type_ |= tADX; + if (EBX & (1U << 20)) type_ |= tSMAP; + if (EBX & (1U << 4)) type_ |= tHLE; + if (EBX & (1U << 11)) type_ |= tRTM; + if (EBX & (1U << 14)) type_ |= tMPX; + if (EBX & (1U << 29)) type_ |= tSHA; + if (ECX & (1U << 0)) type_ |= tPREFETCHWT1; + } + setFamily(); + setNumCores(); + setCacheHierarchy(); + } + void putFamily() const + { + printf("family=%d, model=%X, stepping=%d, extFamily=%d, extModel=%X\n", + family, model, stepping, extFamily, extModel); + printf("display:family=%X, model=%X\n", displayFamily, displayModel); + } + bool has(Type type) const + { + return (type & type_) != 0; + } +}; + +class Clock { +public: + static inline uint64 getRdtsc() + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + return __rdtsc(); + #else + unsigned int eax, edx; + __asm__ volatile("rdtsc" : "=a"(eax), "=d"(edx)); + return ((uint64)edx << 32) | eax; + #endif +#else + // TODO: Need another impl of Clock or rdtsc-equivalent for non-x86 cpu + return 0; +#endif + } + Clock() + : clock_(0) + , count_(0) + { + } + void begin() + { + clock_ -= getRdtsc(); + } + void end() + { + clock_ += getRdtsc(); + count_++; + } + int getCount() const { return count_; } + uint64 getClock() const { return clock_; } + void clear() { count_ = 0; clock_ = 0; } +private: + uint64 clock_; + int count_; +}; + +#ifdef XBYAK64 +const int UseRCX = 1 << 6; +const int UseRDX = 1 << 7; + +class Pack { + static const size_t maxTblNum = 15; + const Xbyak::Reg64 *tbl_[maxTblNum]; + size_t n_; +public: + Pack() : tbl_(), n_(0) {} + Pack(const Xbyak::Reg64 *tbl, size_t n) { init(tbl, n); } + Pack(const Pack& rhs) + : n_(rhs.n_) + { + for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i]; + } + Pack& operator=(const Pack& rhs) + { + n_ = rhs.n_; + for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i]; + return *this; + } + Pack(const Xbyak::Reg64& t0) + { n_ = 1; tbl_[0] = &t0; } + Pack(const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 2; tbl_[0] = &t0; tbl_[1] = &t1; } + Pack(const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 3; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; } + Pack(const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 4; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; } + Pack(const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 5; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; } + Pack(const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 6; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; } + Pack(const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 7; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; } + Pack(const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 8; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; } + Pack(const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 9; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; } + Pack(const Xbyak::Reg64& t9, const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 10; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; tbl_[9] = &t9; } + Pack& append(const Xbyak::Reg64& t) + { + if (n_ == maxTblNum) { + fprintf(stderr, "ERR Pack::can't append\n"); + throw Error(ERR_BAD_PARAMETER); + } + tbl_[n_++] = &t; + return *this; + } + void init(const Xbyak::Reg64 *tbl, size_t n) + { + if (n > maxTblNum) { + fprintf(stderr, "ERR Pack::init bad n=%d\n", (int)n); + throw Error(ERR_BAD_PARAMETER); + } + n_ = n; + for (size_t i = 0; i < n; i++) { + tbl_[i] = &tbl[i]; + } + } + const Xbyak::Reg64& operator[](size_t n) const + { + if (n >= n_) { + fprintf(stderr, "ERR Pack bad n=%d(%d)\n", (int)n, (int)n_); + throw Error(ERR_BAD_PARAMETER); + } + return *tbl_[n]; + } + size_t size() const { return n_; } + /* + get tbl[pos, pos + num) + */ + Pack sub(size_t pos, size_t num = size_t(-1)) const + { + if (num == size_t(-1)) num = n_ - pos; + if (pos + num > n_) { + fprintf(stderr, "ERR Pack::sub bad pos=%d, num=%d\n", (int)pos, (int)num); + throw Error(ERR_BAD_PARAMETER); + } + Pack pack; + pack.n_ = num; + for (size_t i = 0; i < num; i++) { + pack.tbl_[i] = tbl_[pos + i]; + } + return pack; + } + void put() const + { + for (size_t i = 0; i < n_; i++) { + printf("%s ", tbl_[i]->toString()); + } + printf("\n"); + } +}; + +class StackFrame { +#ifdef XBYAK64_WIN + static const int noSaveNum = 6; + static const int rcxPos = 0; + static const int rdxPos = 1; +#else + static const int noSaveNum = 8; + static const int rcxPos = 3; + static const int rdxPos = 2; +#endif + static const int maxRegNum = 14; // maxRegNum = 16 - rsp - rax + Xbyak::CodeGenerator *code_; + int pNum_; + int tNum_; + bool useRcx_; + bool useRdx_; + int saveNum_; + int P_; + bool makeEpilog_; + Xbyak::Reg64 pTbl_[4]; + Xbyak::Reg64 tTbl_[maxRegNum]; + Pack p_; + Pack t_; + StackFrame(const StackFrame&); + void operator=(const StackFrame&); +public: + const Pack& p; + const Pack& t; + /* + make stack frame + @param sf [in] this + @param pNum [in] num of function parameter(0 <= pNum <= 4) + @param tNum [in] num of temporary register(0 <= tNum, with UseRCX, UseRDX) #{pNum + tNum [+rcx] + [rdx]} <= 14 + @param stackSizeByte [in] local stack size + @param makeEpilog [in] automatically call close() if true + + you can use + rax + gp0, ..., gp(pNum - 1) + gt0, ..., gt(tNum-1) + rcx if tNum & UseRCX + rdx if tNum & UseRDX + rsp[0..stackSizeByte - 1] + */ + StackFrame(Xbyak::CodeGenerator *code, int pNum, int tNum = 0, int stackSizeByte = 0, bool makeEpilog = true) + : code_(code) + , pNum_(pNum) + , tNum_(tNum & ~(UseRCX | UseRDX)) + , useRcx_((tNum & UseRCX) != 0) + , useRdx_((tNum & UseRDX) != 0) + , saveNum_(0) + , P_(0) + , makeEpilog_(makeEpilog) + , p(p_) + , t(t_) + { + using namespace Xbyak; + if (pNum < 0 || pNum > 4) throw Error(ERR_BAD_PNUM); + const int allRegNum = pNum + tNum_ + (useRcx_ ? 1 : 0) + (useRdx_ ? 1 : 0); + if (tNum_ < 0 || allRegNum > maxRegNum) throw Error(ERR_BAD_TNUM); + const Reg64& _rsp = code->rsp; + saveNum_ = (std::max)(0, allRegNum - noSaveNum); + const int *tbl = getOrderTbl() + noSaveNum; + for (int i = 0; i < saveNum_; i++) { + code->push(Reg64(tbl[i])); + } + P_ = (stackSizeByte + 7) / 8; + if (P_ > 0 && (P_ & 1) == (saveNum_ & 1)) P_++; // (rsp % 16) == 8, then increment P_ for 16 byte alignment + P_ *= 8; + if (P_ > 0) code->sub(_rsp, P_); + int pos = 0; + for (int i = 0; i < pNum; i++) { + pTbl_[i] = Xbyak::Reg64(getRegIdx(pos)); + } + for (int i = 0; i < tNum_; i++) { + tTbl_[i] = Xbyak::Reg64(getRegIdx(pos)); + } + if (useRcx_ && rcxPos < pNum) code_->mov(code_->r10, code_->rcx); + if (useRdx_ && rdxPos < pNum) code_->mov(code_->r11, code_->rdx); + p_.init(pTbl_, pNum); + t_.init(tTbl_, tNum_); + } + /* + make epilog manually + @param callRet [in] call ret() if true + */ + void close(bool callRet = true) + { + using namespace Xbyak; + const Reg64& _rsp = code_->rsp; + const int *tbl = getOrderTbl() + noSaveNum; + if (P_ > 0) code_->add(_rsp, P_); + for (int i = 0; i < saveNum_; i++) { + code_->pop(Reg64(tbl[saveNum_ - 1 - i])); + } + + if (callRet) code_->ret(); + } + ~StackFrame() + { + if (!makeEpilog_) return; + try { + close(); + } catch (std::exception& e) { + printf("ERR:StackFrame %s\n", e.what()); + //exit(1); + } + } +private: + const int *getOrderTbl() const + { + using namespace Xbyak; + static const int tbl[] = { +#ifdef XBYAK64_WIN + Operand::RCX, Operand::RDX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, Operand::RDI, Operand::RSI, +#else + Operand::RDI, Operand::RSI, Operand::RDX, Operand::RCX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, +#endif + Operand::RBX, Operand::RBP, Operand::R12, Operand::R13, Operand::R14, Operand::R15 + }; + return &tbl[0]; + } + int getRegIdx(int& pos) const + { + assert(pos < maxRegNum); + using namespace Xbyak; + const int *tbl = getOrderTbl(); + int r = tbl[pos++]; + if (useRcx_) { + if (r == Operand::RCX) { return Operand::R10; } + if (r == Operand::R10) { r = tbl[pos++]; } + } + if (useRdx_) { + if (r == Operand::RDX) { return Operand::R11; } + if (r == Operand::R11) { return tbl[pos++]; } + } + return r; + } +}; +#endif + +} } // end of util +#endif diff --git a/thirdparty/oidn/weights/rtlightmap_hdr.tza b/thirdparty/oidn/weights/rtlightmap_hdr.tza new file mode 100644 index 0000000000..12459a33bc Binary files /dev/null and b/thirdparty/oidn/weights/rtlightmap_hdr.tza differ diff --git a/thirdparty/r128/r128.h b/thirdparty/r128/r128.h new file mode 100644 index 0000000000..58933d7638 --- /dev/null +++ b/thirdparty/r128/r128.h @@ -0,0 +1,2124 @@ +/* +r128.h: 128-bit (64.64) signed fixed-point arithmetic. Version 1.4.3 + +COMPILATION +----------- +Drop this header file somewhere in your project and include it wherever it is +needed. There is no separate .c file for this library. To get the code, in ONE +file in your project, put: + +#define R128_IMPLEMENTATION + +before you include this file. You may also provide a definition for R128_ASSERT +to force the library to use a custom assert macro. + +COMPILER/LIBRARY SUPPORT +------------------------ +This library requires a C89 compiler with support for 64-bit integers. If your +compiler does not support the long long data type, the R128_U64, etc. macros +must be set appropriately. On x86 and x64 targets, Intel intrinsics are used +for speed. If your compiler does not support these intrinsics, you can add +#define R128_STDC_ONLY +in your implementation file before including r128.h. + +The only C runtime library functionality used by this library is . +This can be avoided by defining an R128_ASSERT macro in your implementation +file. Since this library uses 64-bit arithmetic, this may implicitly add a +runtime library dependency on 32-bit platforms. + +C++ SUPPORT +----------- +Operator overloads are supplied for C++ files that include this file. Since all +C++ functions are declared inline (or static inline), the R128_IMPLEMENTATION +file can be either C++ or C. + +LICENSE +------- +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or +distribute this software, either in source code form or as a compiled +binary, for any purpose, commercial or non-commercial, and by any +means. + +In jurisdictions that recognize copyright laws, the author or authors +of this software dedicate any and all copyright interest in the +software to the public domain. We make this dedication for the benefit +of the public at large and to the detriment of our heirs and +successors. We intend this dedication to be an overt act of +relinquishment in perpetuity of all present and future rights to this +software under copyright law. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR +OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. +*/ + +#ifndef H_R128_H +#define H_R128_H + +#include + +// 64-bit integer support +// If your compiler does not have stdint.h, add appropriate defines for these macros. +#if defined(_MSC_VER) && (_MSC_VER < 1600) +# define R128_S32 __int32 +# define R128_U32 unsigned __int32 +# define R128_S64 __int64 +# define R128_U64 unsigned __int64 +# define R128_LIT_S64(x) x##i64 +# define R128_LIT_U64(x) x##ui64 +#else +# include +# define R128_S32 int32_t +# define R128_U32 uint32_t +# define R128_S64 int64_t +# define R128_U64 long long unsigned int +# define R128_LIT_S64(x) x##ll +# define R128_LIT_U64(x) x##ull +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct R128 { + R128_U64 lo; + R128_U64 hi; + +#ifdef __cplusplus + R128(); + R128(double); + R128(int); + R128(R128_S64); + R128(R128_U64 low, R128_U64 high); + + operator double() const; + operator R128_S64() const; + operator int() const; + operator bool() const; + + bool operator!() const; + R128 operator~() const; + R128 operator-() const; + R128 &operator|=(const R128 &rhs); + R128 &operator&=(const R128 &rhs); + R128 &operator^=(const R128 &rhs); + R128 &operator+=(const R128 &rhs); + R128 &operator-=(const R128 &rhs); + R128 &operator*=(const R128 &rhs); + R128 &operator/=(const R128 &rhs); + R128 &operator%=(const R128 &rhs); + R128 &operator<<=(int amount); + R128 &operator>>=(int amount); +#endif //__cplusplus +} R128; + +// Type conversion +extern void r128FromInt(R128 *dst, R128_S64 v); +extern void r128FromFloat(R128 *dst, double v); +extern R128_S64 r128ToInt(const R128 *v); +extern double r128ToFloat(const R128 *v); + +// Copy +extern void r128Copy(R128 *dst, const R128 *src); + +// Negate +extern void r128Neg(R128 *dst, const R128 *src); + +// Bitwise operations +extern void r128Not(R128 *dst, const R128 *src); // ~a +extern void r128Or(R128 *dst, const R128 *a, const R128 *b); // a | b +extern void r128And(R128 *dst, const R128 *a, const R128 *b); // a & b +extern void r128Xor(R128 *dst, const R128 *a, const R128 *b); // a ^ b +extern void r128Shl(R128 *dst, const R128 *src, int amount); // shift left by amount mod 128 +extern void r128Shr(R128 *dst, const R128 *src, int amount); // shift right logical by amount mod 128 +extern void r128Sar(R128 *dst, const R128 *src, int amount); // shift right arithmetic by amount mod 128 + +// Arithmetic +extern void r128Add(R128 *dst, const R128 *a, const R128 *b); // a + b +extern void r128Sub(R128 *dst, const R128 *a, const R128 *b); // a - b +extern void r128Mul(R128 *dst, const R128 *a, const R128 *b); // a * b +extern void r128Div(R128 *dst, const R128 *a, const R128 *b); // a / b +extern void r128Mod(R128 *dst, const R128 *a, const R128 *b); // a - toInt(a / b) * b + +extern void r128Sqrt(R128 *dst, const R128 *v); // sqrt(v) +extern void r128Rsqrt(R128 *dst, const R128 *v); // 1 / sqrt(v) + +// Comparison +extern int r128Cmp(const R128 *a, const R128 *b); // sign of a-b +extern void r128Min(R128 *dst, const R128 *a, const R128 *b); +extern void r128Max(R128 *dst, const R128 *a, const R128 *b); +extern void r128Floor(R128 *dst, const R128 *v); +extern void r128Ceil(R128 *dst, const R128 *v); +extern int r128IsNeg(const R128 *v); // quick check for < 0 + +// String conversion +// +typedef enum R128ToStringSign { + R128ToStringSign_Default, // no sign character for positive values + R128ToStringSign_Space, // leading space for positive values + R128ToStringSign_Plus, // leading '+' for positive values +} R128ToStringSign; + +// Formatting options for use with r128ToStringOpt. The "defaults" correspond +// to a format string of "%f". +// +typedef struct R128ToStringFormat { + // sign character for positive values. Default is R128ToStringSign_Default. + R128ToStringSign sign; + + // minimum number of characters to write. Default is 0. + int width; + + // place to the right of the decimal at which rounding is performed. If negative, + // a maximum of 20 decimal places will be written, with no trailing zeroes. + // (20 places is sufficient to ensure that r128FromString will convert back to the + // original value.) Default is -1. NOTE: This is not the same default that the C + // standard library uses for %f. + int precision; + + // If non-zero, pads the output string with leading zeroes if the final result is + // fewer than width characters. Otherwise, leading spaces are used. Default is 0. + int zeroPad; + + // Always print a decimal point, even if the value is an integer. Default is 0. + int decimal; + + // Left-align output if width specifier requires padding. + // Default is 0 (right align). + int leftAlign; +} R128ToStringFormat; + +// r128ToStringOpt: convert R128 to a decimal string, with formatting. +// +// dst and dstSize: specify the buffer to write into. At most dstSize bytes will be written +// (including null terminator). No additional rounding is performed if dstSize is not large +// enough to hold the entire string. +// +// opt: an R128ToStringFormat struct (q.v.) with formatting options. +// +// Uses the R128_decimal global as the decimal point character. +// Always writes a null terminator, even if the destination buffer is not large enough. +// +// Number of bytes that will be written (i.e. how big does dst need to be?): +// If width is specified: width + 1 bytes. +// If precision is specified: at most precision + 22 bytes. +// If neither is specified: at most 42 bytes. +// +// Returns the number of bytes that would have been written if dst was sufficiently large, +// not including the final null terminator. +// +extern int r128ToStringOpt(char *dst, size_t dstSize, const R128 *v, const R128ToStringFormat *opt); + +// r128ToStringf: convert R128 to a decimal string, with formatting. +// +// dst and dstSize: specify the buffer to write into. At most dstSize bytes will be written +// (including null terminator). +// +// format: a printf-style format specifier, as one would use with floating point types. +// e.g. "%+5.2f". (The leading % and trailing f are optional.) +// NOTE: This is NOT a full replacement for sprintf. Any characters in the format string +// that do not correspond to a format placeholder are ignored. +// +// Uses the R128_decimal global as the decimal point character. +// Always writes a null terminator, even if the destination buffer is not large enough. +// +// Number of bytes that will be written (i.e. how big does dst need to be?): +// If the precision field is specified: at most max(width, precision + 21) + 1 bytes +// Otherwise: at most max(width, 41) + 1 bytes. +// +// Returns the number of bytes that would have been written if dst was sufficiently large, +// not including the final null terminator. +// +extern int r128ToStringf(char *dst, size_t dstSize, const char *format, const R128 *v); + +// r128ToString: convert R128 to a decimal string, with default formatting. +// Equivalent to r128ToStringf(dst, dstSize, "%f", v). +// +// Uses the R128_decimal global as the decimal point character. +// Always writes a null terminator, even if the destination buffer is not large enough. +// +// Will write at most 42 bytes (including NUL) to dst. +// +// Returns the number of bytes that would have been written if dst was sufficiently large, +// not including the final null terminator. +// +extern int r128ToString(char *dst, size_t dstSize, const R128 *v); + +// r128FromString: Convert string to R128. +// +// The string can be formatted either as a decimal number with optional sign +// or as hexadecimal with a prefix of 0x or 0X. +// +// endptr, if not NULL, is set to the character following the last character +// used in the conversion. +// +extern void r128FromString(R128 *dst, const char *s, char **endptr); + +// Constants +extern const R128 R128_min; // minimum (most negative) value +extern const R128 R128_max; // maximum (most positive) value +extern const R128 R128_smallest; // smallest positive value +extern const R128 R128_zero; // zero +extern const R128 R128_one; // 1.0 + +extern char R128_decimal; // decimal point character used by r128From/ToString. defaults to '.' + +#ifdef __cplusplus +} + +#include +namespace std { +template<> +struct numeric_limits +{ + static const bool is_specialized = true; + + static R128 min() throw() { return R128_min; } + static R128 max() throw() { return R128_max; } + + static const int digits = 127; + static const int digits10 = 38; + static const bool is_signed = true; + static const bool is_integer = false; + static const bool is_exact = false; + static const int radix = 2; + static R128 epsilon() throw() { return R128_smallest; } + static R128 round_error() throw() { return R128_one; } + + static const int min_exponent = 0; + static const int min_exponent10 = 0; + static const int max_exponent = 0; + static const int max_exponent10 = 0; + + static const bool has_infinity = false; + static const bool has_quiet_NaN = false; + static const bool has_signaling_NaN = false; + static const float_denorm_style has_denorm = denorm_absent; + static const bool has_denorm_loss = false; + + static R128 infinity() throw() { return R128_zero; } + static R128 quiet_NaN() throw() { return R128_zero; } + static R128 signaling_NaN() throw() { return R128_zero; } + static R128 denorm_min() throw() { return R128_zero; } + + static const bool is_iec559 = false; + static const bool is_bounded = true; + static const bool is_modulo = true; + + static const bool traps = numeric_limits::traps; + static const bool tinyness_before = false; + static const float_round_style round_style = round_toward_zero; +}; +} //namespace std + +inline R128::R128() {} + +inline R128::R128(double v) +{ + r128FromFloat(this, v); +} + +inline R128::R128(int v) +{ + r128FromInt(this, v); +} + +inline R128::R128(R128_S64 v) +{ + r128FromInt(this, v); +} + +inline R128::R128(R128_U64 low, R128_U64 high) +{ + lo = low; + hi = high; +} + +inline R128::operator double() const +{ + return r128ToFloat(this); +} + +inline R128::operator R128_S64() const +{ + return r128ToInt(this); +} + +inline R128::operator int() const +{ + return (int) r128ToInt(this); +} + +inline R128::operator bool() const +{ + return lo || hi; +} + +inline bool R128::operator!() const +{ + return !lo && !hi; +} + +inline R128 R128::operator~() const +{ + R128 r; + r128Not(&r, this); + return r; +} + +inline R128 R128::operator-() const +{ + R128 r; + r128Neg(&r, this); + return r; +} + +inline R128 &R128::operator|=(const R128 &rhs) +{ + r128Or(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator&=(const R128 &rhs) +{ + r128And(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator^=(const R128 &rhs) +{ + r128Xor(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator+=(const R128 &rhs) +{ + r128Add(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator-=(const R128 &rhs) +{ + r128Sub(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator*=(const R128 &rhs) +{ + r128Mul(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator/=(const R128 &rhs) +{ + r128Div(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator%=(const R128 &rhs) +{ + r128Mod(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator<<=(int amount) +{ + r128Shl(this, this, amount); + return *this; +} + +inline R128 &R128::operator>>=(int amount) +{ + r128Sar(this, this, amount); + return *this; +} + +static inline R128 operator|(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r |= rhs; +} + +static inline R128 operator&(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r &= rhs; +} + +static inline R128 operator^(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r ^= rhs; +} + +static inline R128 operator+(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r += rhs; +} + +static inline R128 operator-(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r -= rhs; +} + +static inline R128 operator*(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r *= rhs; +} + +static inline R128 operator/(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r /= rhs; +} + +static inline R128 operator%(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r %= rhs; +} + +static inline R128 operator<<(const R128 &lhs, int amount) +{ + R128 r(lhs); + return r <<= amount; +} + +static inline R128 operator>>(const R128 &lhs, int amount) +{ + R128 r(lhs); + return r >>= amount; +} + +static inline bool operator<(const R128 &lhs, const R128 &rhs) +{ + return r128Cmp(&lhs, &rhs) < 0; +} + +static inline bool operator>(const R128 &lhs, const R128 &rhs) +{ + return r128Cmp(&lhs, &rhs) > 0; +} + +static inline bool operator<=(const R128 &lhs, const R128 &rhs) +{ + return r128Cmp(&lhs, &rhs) <= 0; +} + +static inline bool operator>=(const R128 &lhs, const R128 &rhs) +{ + return r128Cmp(&lhs, &rhs) >= 0; +} + +static inline bool operator==(const R128 &lhs, const R128 &rhs) +{ + return lhs.lo == rhs.lo && lhs.hi == rhs.hi; +} + +static inline bool operator!=(const R128 &lhs, const R128 &rhs) +{ + return lhs.lo != rhs.lo || lhs.hi != rhs.hi; +} + +#endif //__cplusplus +#endif //H_R128_H + +#ifdef R128_IMPLEMENTATION + +#ifdef R128_DEBUG_VIS +# define R128_DEBUG_SET(x) r128ToString(R128_last, sizeof(R128_last), x) +#else +# define R128_DEBUG_SET(x) +#endif + +#define R128_SET2(x, l, h) do { (x)->lo = (R128_U64)(l); (x)->hi = (R128_U64)(h); } while(0) +#define R128_R0(x) ((R128_U32)(x)->lo) +#define R128_R2(x) ((R128_U32)(x)->hi) +#if defined(_M_IX86) +// workaround: MSVC x86's handling of 64-bit values is not great +# define R128_SET4(x, r0, r1, r2, r3) do { \ + ((R128_U32*)&(x)->lo)[0] = (R128_U32)(r0); \ + ((R128_U32*)&(x)->lo)[1] = (R128_U32)(r1); \ + ((R128_U32*)&(x)->hi)[0] = (R128_U32)(r2); \ + ((R128_U32*)&(x)->hi)[1] = (R128_U32)(r3); \ + } while(0) +# define R128_R1(x) (((R128_U32*)&(x)->lo)[1]) +# define R128_R3(x) (((R128_U32*)&(x)->hi)[1]) +#else +# define R128_SET4(x, r0, r1, r2, r3) do { (x)->lo = (R128_U64)(r0) | ((R128_U64)(r1) << 32); \ + (x)->hi = (R128_U64)(r2) | ((R128_U64)(r3) << 32); } while(0) +# define R128_R1(x) ((R128_U32)((x)->lo >> 32)) +# define R128_R3(x) ((R128_U32)((x)->hi >> 32)) +#endif + +#if defined(_M_X64) +# define R128_INTEL 1 +# define R128_64BIT 1 +# ifndef R128_STDC_ONLY +# include +# endif +#elif defined(__x86_64__) +# define R128_INTEL 1 +# define R128_64BIT 1 +# ifndef R128_STDC_ONLY +# include +# endif +#elif defined(_M_IX86) +# define R128_INTEL 1 +# ifndef R128_STDC_ONLY +# include +# endif +#elif defined(__i386__) +# define R128_INTEL 1 +# ifndef R128_STDC_ONLY +# include +# endif +#elif defined(_M_ARM) +# ifndef R128_STDC_ONLY +# include +# endif +#elif defined(_M_ARM64) +# define R128_64BIT 1 +# ifndef R128_STDC_ONLY +# include +# endif +#elif defined(__aarch64__) +# define R128_64BIT 1 +#endif + +#ifndef R128_INTEL +# define R128_INTEL 0 +#endif + +#ifndef R128_64BIT +# define R128_64BIT 0 +#endif + +#ifndef R128_ASSERT +# include +# define R128_ASSERT(x) assert(x) +#endif + +#include // for NULL + +static const R128ToStringFormat R128__defaultFormat = { + R128ToStringSign_Default, + 0, + -1, + 0, + 0, + 0 +}; + +const R128 R128_min = { 0, R128_LIT_U64(0x8000000000000000) }; +const R128 R128_max = { R128_LIT_U64(0xffffffffffffffff), R128_LIT_U64(0x7fffffffffffffff) }; +const R128 R128_smallest = { 1, 0 }; +const R128 R128_zero = { 0, 0 }; +const R128 R128_one = { 0, 1 }; +char R128_decimal = '.'; +#ifdef R128_DEBUG_VIS +char R128_last[42]; +#endif + +static int r128__clz64(R128_U64 x) +{ +#if defined(R128_STDC_ONLY) + R128_U64 n = 64, y; + y = x >> 32; if (y) { n -= 32; x = y; } + y = x >> 16; if (y) { n -= 16; x = y; } + y = x >> 8; if (y) { n -= 8; x = y; } + y = x >> 4; if (y) { n -= 4; x = y; } + y = x >> 2; if (y) { n -= 2; x = y; } + y = x >> 1; if (y) { n -= 1; x = y; } + return (int)(n - x); +#elif defined(_M_X64) || defined(_M_ARM64) + unsigned long idx; + if (_BitScanReverse64(&idx, x)) { + return 63 - (int)idx; + } else { + return 64; + } +#elif defined(_MSC_VER) + unsigned long idx; + if (_BitScanReverse(&idx, (R128_U32)(x >> 32))) { + return 31 - (int)idx; + } else if (_BitScanReverse(&idx, (R128_U32)x)) { + return 63 - (int)idx; + } else { + return 64; + } +#else + return x ? __builtin_clzll(x) : 64; +#endif +} + +#if !R128_64BIT +// 32*32->64 +static R128_U64 r128__umul64(R128_U32 a, R128_U32 b) +{ +# if defined(_M_IX86) && !defined(R128_STDC_ONLY) + return __emulu(a, b); +# elif defined(_M_ARM) && !defined(R128_STDC_ONLY) + return _arm_umull(a, b); +# else + return a * (R128_U64)b; +# endif +} + +// 64/32->32 +static R128_U32 r128__udiv64(R128_U32 nlo, R128_U32 nhi, R128_U32 d, R128_U32 *rem) +{ +# if defined(_M_IX86) && (_MSC_VER >= 1920) && !defined(R128_STDC_ONLY) + unsigned __int64 n = ((unsigned __int64)nhi << 32) | nlo; + return _udiv64(n, d, rem); +# elif defined(_M_IX86) && !defined(R128_STDC_ONLY) + __asm { + mov eax, nlo + mov edx, nhi + div d + mov ecx, rem + mov dword ptr [ecx], edx + } +# elif defined(__i386__) && !defined(R128_STDC_ONLY) + R128_U32 q, r; + __asm("divl %4" + : "=a"(q), "=d"(r) + : "a"(nlo), "d"(nhi), "X"(d)); + *rem = r; + return q; +# else + R128_U64 n64 = ((R128_U64)nhi << 32) | nlo; + *rem = (R128_U32)(n64 % d); + return (R128_U32)(n64 / d); +# endif +} +#elif !defined(_M_X64) || defined(R128_STDC_ONLY) +#define r128__umul64(a, b) ((a) * (R128_U64)(b)) +/*static R128_U32 r128__udiv64(R128_U32 nlo, R128_U32 nhi, R128_U32 d, R128_U32 *rem) +{ + R128_U64 n64 = ((R128_U64)nhi << 32) | nlo; + *rem = (R128_U32)(n64 % d); + return (R128_U32)(n64 / d); +}*/ +#endif //!R128_64BIT + +static void r128__neg(R128 *dst, const R128 *src) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + +#if R128_INTEL && !defined(R128_STDC_ONLY) + { + unsigned char carry = 0; +# if R128_64BIT + carry = _addcarry_u64(carry, ~src->lo, 1, &dst->lo); + carry = _addcarry_u64(carry, ~src->hi, 0, &dst->hi); +# else + R128_U32 r0, r1, r2, r3; + carry = _addcarry_u32(carry, ~R128_R0(src), 1, &r0); + carry = _addcarry_u32(carry, ~R128_R1(src), 0, &r1); + carry = _addcarry_u32(carry, ~R128_R2(src), 0, &r2); + carry = _addcarry_u32(carry, ~R128_R3(src), 0, &r3); + R128_SET4(dst, r0, r1, r2, r3); +# endif //R128_64BIT + } +#else + if (src->lo) { + dst->lo = ~src->lo + 1; + dst->hi = ~src->hi; + } else { + dst->lo = 0; + dst->hi = ~src->hi + 1; + } +#endif //R128_INTEL +} + +// 64*64->128 +static void r128__umul128(R128 *dst, R128_U64 a, R128_U64 b) +{ +#if defined(_M_X64) && !defined(R128_STDC_ONLY) + dst->lo = _umul128(a, b, &dst->hi); +#elif R128_64BIT && !defined(_MSC_VER) && !defined(R128_STDC_ONLY) + unsigned __int128 p0 = a * (unsigned __int128)b; + dst->hi = (R128_U64)(p0 >> 64); + dst->lo = (R128_U64)p0; +#else + R128_U32 alo = (R128_U32)a; + R128_U32 ahi = (R128_U32)(a >> 32); + R128_U32 blo = (R128_U32)b; + R128_U32 bhi = (R128_U32)(b >> 32); + R128_U64 p0, p1, p2, p3; + + p0 = r128__umul64(alo, blo); + p1 = r128__umul64(alo, bhi); + p2 = r128__umul64(ahi, blo); + p3 = r128__umul64(ahi, bhi); + + { +#if R128_INTEL && !defined(R128_STDC_ONLY) + R128_U32 r0, r1, r2, r3; + unsigned char carry; + + r0 = (R128_U32)(p0); + r1 = (R128_U32)(p0 >> 32); + r2 = (R128_U32)(p1 >> 32); + r3 = (R128_U32)(p3 >> 32); + + carry = _addcarry_u32(0, r1, (R128_U32)p1, &r1); + carry = _addcarry_u32(carry, r2, (R128_U32)(p2 >> 32), &r2); + _addcarry_u32(carry, r3, 0, &r3); + carry = _addcarry_u32(0, r1, (R128_U32)p2, &r1); + carry = _addcarry_u32(carry, r2, (R128_U32)p3, &r2); + _addcarry_u32(carry, r3, 0, &r3); + + R128_SET4(dst, r0, r1, r2, r3); +#else + R128_U64 carry, lo, hi; + carry = ((R128_U64)(R128_U32)p1 + (R128_U64)(R128_U32)p2 + (p0 >> 32)) >> 32; + + lo = p0 + ((p1 + p2) << 32); + hi = p3 + ((R128_U32)(p1 >> 32) + (R128_U32)(p2 >> 32)) + carry; + + R128_SET2(dst, lo, hi); +#endif + } +#endif +} + +// 128/64->64 +#if defined(_M_X64) && (_MSC_VER < 1920) && !defined(R128_STDC_ONLY) +// MSVC x64 provides neither inline assembly nor (pre-2019) a div intrinsic, so we do fake +// "inline assembly" to avoid long division or outline assembly. +#pragma code_seg(".text") +__declspec(allocate(".text")) static const unsigned char r128__udiv128Code[] = { + 0x48, 0x8B, 0xC1, //mov rax, rcx + 0x49, 0xF7, 0xF0, //div rax, r8 + 0x49, 0x89, 0x11, //mov qword ptr [r9], rdx + 0xC3 //ret +}; +typedef R128_U64 (*r128__udiv128Proc)(R128_U64 nlo, R128_U64 nhi, R128_U64 d, R128_U64 *rem); +static const r128__udiv128Proc r128__udiv128 = (r128__udiv128Proc)(void*)r128__udiv128Code; +#else +static R128_U64 r128__udiv128(R128_U64 nlo, R128_U64 nhi, R128_U64 d, R128_U64 *rem) +{ +#if defined(_M_X64) && !defined(R128_STDC_ONLY) + return _udiv128(nhi, nlo, d, rem); +#elif defined(__x86_64__) && !defined(R128_STDC_ONLY) + R128_U64 q, r; + __asm("divq %4" + : "=a"(q), "=d"(r) + : "a"(nlo), "d"(nhi), "X"(d)); + *rem = r; + return q; +#else + R128_U64 tmp; + R128_U32 d0, d1; + R128_U32 n3, n2, n1, n0; + R128_U32 q0, q1; + R128_U32 r; + int shift; + + R128_ASSERT(d != 0); //division by zero + R128_ASSERT(nhi < d); //overflow + + // normalize + shift = r128__clz64(d); + + if (shift) { + R128 tmp128; + R128_SET2(&tmp128, nlo, nhi); + r128Shl(&tmp128, &tmp128, shift); + n3 = R128_R3(&tmp128); + n2 = R128_R2(&tmp128); + n1 = R128_R1(&tmp128); + n0 = R128_R0(&tmp128); + d <<= shift; + } else { + n3 = (R128_U32)(nhi >> 32); + n2 = (R128_U32)nhi; + n1 = (R128_U32)(nlo >> 32); + n0 = (R128_U32)nlo; + } + + d1 = (R128_U32)(d >> 32); + d0 = (R128_U32)d; + + // first digit + R128_ASSERT(n3 <= d1); + if (n3 < d1) { + q1 = r128__udiv64(n2, n3, d1, &r); + } else { + q1 = 0xffffffffu; + r = n2 + d1; + } +refine1: + if (r128__umul64(q1, d0) > ((R128_U64)r << 32) + n1) { + --q1; + if (r < ~d1 + 1) { + r += d1; + goto refine1; + } + } + + tmp = ((R128_U64)n2 << 32) + n1 - (r128__umul64(q1, d0) + (r128__umul64(q1, d1) << 32)); + n2 = (R128_U32)(tmp >> 32); + n1 = (R128_U32)tmp; + + // second digit + R128_ASSERT(n2 <= d1); + if (n2 < d1) { + q0 = r128__udiv64(n1, n2, d1, &r); + } else { + q0 = 0xffffffffu; + r = n1 + d1; + } +refine0: + if (r128__umul64(q0, d0) > ((R128_U64)r << 32) + n0) { + --q0; + if (r < ~d1 + 1) { + r += d1; + goto refine0; + } + } + + tmp = ((R128_U64)n1 << 32) + n0 - (r128__umul64(q0, d0) + (r128__umul64(q0, d1) << 32)); + n1 = (R128_U32)(tmp >> 32); + n0 = (R128_U32)tmp; + + *rem = (((R128_U64)n1 << 32) + n0) >> shift; + return ((R128_U64)q1 << 32) + q0; +#endif +} +#endif + +static int r128__ucmp(const R128 *a, const R128 *b) +{ + if (a->hi != b->hi) { + if (a->hi > b->hi) { + return 1; + } else { + return -1; + } + } else { + if (a->lo == b->lo) { + return 0; + } else if (a->lo > b->lo) { + return 1; + } else { + return -1; + } + } +} + +static void r128__umul(R128 *dst, const R128 *a, const R128 *b) +{ +#if defined(_M_X64) && !defined(R128_STDC_ONLY) + R128_U64 t0, t1; + R128_U64 lo, hi = 0; + unsigned char carry; + + t0 = _umul128(a->lo, b->lo, &t1); + carry = _addcarry_u64(0, t1, t0 >> 63, &lo); + _addcarry_u64(carry, hi, hi, &hi); + + t0 = _umul128(a->lo, b->hi, &t1); + carry = _addcarry_u64(0, lo, t0, &lo); + _addcarry_u64(carry, hi, t1, &hi); + + t0 = _umul128(a->hi, b->lo, &t1); + carry = _addcarry_u64(0, lo, t0, &lo); + _addcarry_u64(carry, hi, t1, &hi); + + t0 = _umul128(a->hi, b->hi, &t1); + hi += t0; + + R128_SET2(dst, lo, hi); +#elif defined(__x86_64__) && !defined(R128_STDC_ONLY) + unsigned __int128 p0, p1, p2, p3; + p0 = a->lo * (unsigned __int128)b->lo; + p1 = a->lo * (unsigned __int128)b->hi; + p2 = a->hi * (unsigned __int128)b->lo; + p3 = a->hi * (unsigned __int128)b->hi; + + p0 = (p3 << 64) + p2 + p1 + (p0 >> 64) + ((R128_U64)p0 >> 63); + dst->lo = (R128_U64)p0; + dst->hi = (R128_U64)(p0 >> 64); +#else + R128 p0, p1, p2, p3, round; + + r128__umul128(&p0, a->lo, b->lo); + round.hi = 0; round.lo = p0.lo >> 63; + p0.lo = p0.hi; p0.hi = 0; //r128Shr(&p0, &p0, 64); + r128Add(&p0, &p0, &round); + + r128__umul128(&p1, a->hi, b->lo); + r128Add(&p0, &p0, &p1); + + r128__umul128(&p2, a->lo, b->hi); + r128Add(&p0, &p0, &p2); + + r128__umul128(&p3, a->hi, b->hi); + p3.hi = p3.lo; p3.lo = 0; //r128Shl(&p3, &p3, 64); + r128Add(&p0, &p0, &p3); + + R128_SET2(dst, p0.lo, p0.hi); +#endif +} + +// Shift d left until the high bit is set, and shift n left by the same amount. +// returns non-zero on overflow. +static int r128__norm(R128 *n, R128 *d, R128_U64 *n2) +{ + R128_U64 d0, d1; + R128_U64 n0, n1; + int shift; + + d1 = d->hi; + d0 = d->lo; + n1 = n->hi; + n0 = n->lo; + + if (d1) { + shift = r128__clz64(d1); + if (shift) { + d1 = (d1 << shift) | (d0 >> (64 - shift)); + d0 = d0 << shift; + *n2 = n1 >> (64 - shift); + n1 = (n1 << shift) | (n0 >> (64 - shift)); + n0 = n0 << shift; + } else { + *n2 = 0; + } + } else { + shift = r128__clz64(d0); + if (r128__clz64(n1) <= shift) { + return 1; // overflow + } + + if (shift) { + d1 = d0 << shift; + d0 = 0; + *n2 = (n1 << shift) | (n0 >> (64 - shift)); + n1 = n0 << shift; + n0 = 0; + } else { + d1 = d0; + d0 = 0; + *n2 = n1; + n1 = n0; + n0 = 0; + } + } + + R128_SET2(n, n0, n1); + R128_SET2(d, d0, d1); + return 0; +} + +static void r128__udiv(R128 *quotient, const R128 *dividend, const R128 *divisor) +{ + R128 tmp; + R128_U64 d0, d1; + R128_U64 n1, n2, n3; + R128 q; + + R128_ASSERT(dividend != NULL); + R128_ASSERT(divisor != NULL); + R128_ASSERT(quotient != NULL); + R128_ASSERT(divisor->hi != 0 || divisor->lo != 0); // divide by zero + + // scale dividend and normalize + { + R128 n, d; + R128_SET2(&n, dividend->lo, dividend->hi); + R128_SET2(&d, divisor->lo, divisor->hi); + if (r128__norm(&n, &d, &n3)) { + R128_SET2(quotient, R128_max.lo, R128_max.hi); + return; + } + + d1 = d.hi; + d0 = d.lo; + n2 = n.hi; + n1 = n.lo; + } + + // first digit + R128_ASSERT(n3 <= d1); + { + R128 t0, t1; + t0.lo = n1; + if (n3 < d1) { + q.hi = r128__udiv128(n2, n3, d1, &t0.hi); + } else { + q.hi = R128_LIT_U64(0xffffffffffffffff); + t0.hi = n2 + d1; + } + +refine1: + r128__umul128(&t1, q.hi, d0); + if (r128__ucmp(&t1, &t0) > 0) { + --q.hi; + if (t0.hi < ~d1 + 1) { + t0.hi += d1; + goto refine1; + } + } + } + + { + R128 t0, t1, t2; + t0.hi = n2; + t0.lo = n1; + + r128__umul128(&t1, q.hi, d0); + r128__umul128(&t2, q.hi, d1); + + t2.hi = t2.lo; t2.lo = 0; //r128Shl(&t2, &t2, 64); + r128Add(&tmp, &t1, &t2); + r128Sub(&tmp, &t0, &tmp); + } + n2 = tmp.hi; + n1 = tmp.lo; + + // second digit + R128_ASSERT(n2 <= d1); + { + R128 t0, t1; + t0.lo = 0; + if (n2 < d1) { + q.lo = r128__udiv128(n1, n2, d1, &t0.hi); + } else { + q.lo = R128_LIT_U64(0xffffffffffffffff); + t0.hi = n1 + d1; + } + + refine0: + r128__umul128(&t1, q.lo, d0); + if (r128__ucmp(&t1, &t0) > 0) { + --q.lo; + if (t0.hi < ~d1 + 1) { + t0.hi += d1; + goto refine0; + } + } + } + + R128_SET2(quotient, q.lo, q.hi); +} + +static R128_U64 r128__umod(R128 *n, R128 *d) +{ + R128_U64 d0, d1; + R128_U64 n3, n2, n1; + R128_U64 q; + + R128_ASSERT(d != NULL); + R128_ASSERT(n != NULL); + R128_ASSERT(d->hi != 0 || d->lo != 0); // divide by zero + + if (r128__norm(n, d, &n3)) { + return R128_LIT_U64(0xffffffffffffffff); + } + + d1 = d->hi; + d0 = d->lo; + n2 = n->hi; + n1 = n->lo; + + R128_ASSERT(n3 < d1); + { + R128 t0, t1; + t0.lo = n1; + q = r128__udiv128(n2, n3, d1, &t0.hi); + + refine1: + r128__umul128(&t1, q, d0); + if (r128__ucmp(&t1, &t0) > 0) { + --q; + if (t0.hi < ~d1 + 1) { + t0.hi += d1; + goto refine1; + } + } + } + + return q; +} + +static int r128__format(char *dst, size_t dstSize, const R128 *v, const R128ToStringFormat *format) +{ + char buf[128]; + R128 tmp; + R128_U64 whole; + char *cursor, *decimal, *dstp = dst; + int sign = 0; + int fullPrecision = 1; + int width, precision; + int padCnt, trail = 0; + + R128_ASSERT(dst != NULL && dstSize > 0); + R128_ASSERT(v != NULL); + R128_ASSERT(format != NULL); + + --dstSize; + + R128_SET2(&tmp, v->lo, v->hi); + if (r128IsNeg(&tmp)) { + r128__neg(&tmp, &tmp); + sign = 1; + } + + width = format->width; + if (width < 0) { + width = 0; + } + + precision = format->precision; + if (precision < 0) { + // print a maximum of 20 digits + fullPrecision = 0; + precision = 20; + } else if (precision > (int)sizeof(buf) - 21) { + trail = precision - (sizeof(buf) - 21); + precision -= trail; + } + + whole = tmp.hi; + decimal = cursor = buf; + + // fractional part first in case a carry into the whole part is required + if (tmp.lo || format->decimal) { + while (tmp.lo || (fullPrecision && precision)) { + if ((int)(cursor - buf) == precision) { + if ((R128_S64)tmp.lo < 0) { + // round up, propagate carry backwards + char *c; + for (c = cursor - 1; c >= buf; --c) { + char d = ++*c; + if (d <= '9') { + goto endfrac; + } else { + *c = '0'; + } + } + + // carry out into the whole part + whole++; + } + + break; + } + + r128__umul128(&tmp, tmp.lo, 10); + *cursor++ = (char)tmp.hi + '0'; + } + + endfrac: + if (format->decimal || precision) { + decimal = cursor; + *cursor++ = R128_decimal; + } + } + + // whole part + do { + char digit = (char)(whole % 10); + whole /= 10; + *cursor++ = digit + '0'; + } while (whole); + +#define R128__WRITE(c) do { if (dstp < dst + dstSize) *dstp = c; ++dstp; } while(0) + + padCnt = width - (int)(cursor - buf) - 1; + + // left padding + if (!format->leftAlign) { + char padChar = format->zeroPad ? '0' : ' '; + if (format->zeroPad) { + if (sign) { + R128__WRITE('-'); + } else if (format->sign == R128ToStringSign_Plus) { + R128__WRITE('+'); + } else if (format->sign == R128ToStringSign_Space) { + R128__WRITE(' '); + } else { + ++padCnt; + } + } + + for (; padCnt > 0; --padCnt) { + R128__WRITE(padChar); + } + } + + if (format->leftAlign || !format->zeroPad) { + if (sign) { + R128__WRITE('-'); + } else if (format->sign == R128ToStringSign_Plus) { + R128__WRITE('+'); + } else if (format->sign == R128ToStringSign_Space) { + R128__WRITE(' '); + } else { + ++padCnt; + } + } + + { + char *i; + + // reverse the whole part + for (i = cursor - 1; i >= decimal; --i) { + R128__WRITE(*i); + } + + // copy the fractional part + for (i = buf; i < decimal; ++i) { + R128__WRITE(*i); + } + } + + // right padding + if (format->leftAlign) { + char padChar = format->zeroPad ? '0' : ' '; + for (; padCnt > 0; --padCnt) { + R128__WRITE(padChar); + } + } + + // trailing zeroes for very large precision + while (trail--) { + R128__WRITE('0'); + } + +#undef R128__WRITE + + if (dstp <= dst + dstSize) { + *dstp = '\0'; + } else { + dst[dstSize] = '\0'; + } + return (int)(dstp - dst); +} + +void r128FromInt(R128 *dst, R128_S64 v) +{ + R128_ASSERT(dst != NULL); + dst->lo = 0; + dst->hi = (R128_U64)v; + R128_DEBUG_SET(dst); +} + +void r128FromFloat(R128 *dst, double v) +{ + R128_ASSERT(dst != NULL); + + if (v < -9223372036854775808.0) { + r128Copy(dst, &R128_min); + } else if (v >= 9223372036854775808.0) { + r128Copy(dst, &R128_max); + } else { + R128 r; + int sign = 0; + + if (v < 0) { + v = -v; + sign = 1; + } + + r.hi = (R128_U64)(R128_S64)v; + v -= (R128_S64)v; + r.lo = (R128_U64)(v * 18446744073709551616.0); + + if (sign) { + r128__neg(&r, &r); + } + + r128Copy(dst, &r); + } +} + +void r128FromString(R128 *dst, const char *s, char **endptr) +{ + R128_U64 lo = 0, hi = 0; + R128_U64 base = 10; + + int sign = 0; + + R128_ASSERT(dst != NULL); + R128_ASSERT(s != NULL); + + R128_SET2(dst, 0, 0); + + // consume whitespace + for (;;) { + if (*s == ' ' || *s == '\t' || *s == '\r' || *s == '\n' || *s == '\v') { + ++s; + } else { + break; + } + } + + // sign + if (*s == '-') { + sign = 1; + ++s; + } else if (*s == '+') { + ++s; + } + + // parse base prefix + if (s[0] == '0' && (s[1] == 'x' || s[1] == 'X')) { + base = 16; + s += 2; + } + + // whole part + for (;; ++s) { + R128_U64 digit; + + if ('0' <= *s && *s <= '9') { + digit = *s - '0'; + } else if (base == 16 && 'a' <= *s && *s <= 'f') { + digit = *s - 'a' + 10; + } else if (base == 16 && 'A' <= *s && *s <= 'F') { + digit = *s - 'A' + 10; + } else { + break; + } + + hi = hi * base + digit; + } + + // fractional part + if (*s == R128_decimal) { + const char *exp = ++s; + + // find the last digit and work backwards + for (;; ++s) { + if ('0' <= *s && *s <= '9') { + } else if (base == 16 && ('a' <= *s && *s <= 'f')) { + } else if (base == 16 && ('A' <= *s && *s <= 'F')) { + } else { + break; + } + } + + for (--s; s >= exp; --s) { + R128_U64 digit, unused; + + if ('0' <= *s && *s <= '9') { + digit = *s - '0'; + } else if ('a' <= *s && *s <= 'f') { + digit = *s - 'a' + 10; + } else { + digit = *s - 'A' + 10; + } + + lo = r128__udiv128(lo, digit, base, &unused); + } + } + + R128_SET2(dst, lo, hi); + if (sign) { + r128__neg(dst, dst); + } + + if (endptr) { + *endptr = (char *) s; + } +} + +R128_S64 r128ToInt(const R128 *v) +{ + R128_ASSERT(v != NULL); + return (R128_S64)v->hi; +} + +double r128ToFloat(const R128 *v) +{ + R128 tmp; + int sign = 0; + double d; + + R128_ASSERT(v != NULL); + + R128_SET2(&tmp, v->lo, v->hi); + if (r128IsNeg(&tmp)) { + r128__neg(&tmp, &tmp); + sign = 1; + } + + d = tmp.hi + tmp.lo * (1 / 18446744073709551616.0); + if (sign) { + d = -d; + } + + return d; +} + +int r128ToStringOpt(char *dst, size_t dstSize, const R128 *v, const R128ToStringFormat *opt) +{ + return r128__format(dst, dstSize, v, opt); +} + +int r128ToStringf(char *dst, size_t dstSize, const char *format, const R128 *v) +{ + R128ToStringFormat opts; + + R128_ASSERT(dst != NULL && dstSize); + R128_ASSERT(format != NULL); + R128_ASSERT(v != NULL); + + opts.sign = R128__defaultFormat.sign; + opts.precision = R128__defaultFormat.precision; + opts.zeroPad = R128__defaultFormat.zeroPad; + opts.decimal = R128__defaultFormat.decimal; + opts.leftAlign = R128__defaultFormat.leftAlign; + + if (*format == '%') { + ++format; + } + + // flags field + for (;; ++format) { + if (*format == ' ' && opts.sign != R128ToStringSign_Plus) { + opts.sign = R128ToStringSign_Space; + } else if (*format == '+') { + opts.sign = R128ToStringSign_Plus; + } else if (*format == '0') { + opts.zeroPad = 1; + } else if (*format == '-') { + opts.leftAlign = 1; + } else if (*format == '#') { + opts.decimal = 1; + } else { + break; + } + } + + // width field + opts.width = 0; + for (;;) { + if ('0' <= *format && *format <= '9') { + opts.width = opts.width * 10 + *format++ - '0'; + } else { + break; + } + } + + // precision field + if (*format == '.') { + opts.precision = 0; + ++format; + for (;;) { + if ('0' <= *format && *format <= '9') { + opts.precision = opts.precision * 10 + *format++ - '0'; + } else { + break; + } + } + } + + return r128__format(dst, dstSize, v, &opts); +} + +int r128ToString(char *dst, size_t dstSize, const R128 *v) +{ + return r128__format(dst, dstSize, v, &R128__defaultFormat); +} + +void r128Copy(R128 *dst, const R128 *src) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + dst->lo = src->lo; + dst->hi = src->hi; + R128_DEBUG_SET(dst); +} + +void r128Neg(R128 *dst, const R128 *src) +{ + r128__neg(dst, src); + R128_DEBUG_SET(dst); +} + +void r128Not(R128 *dst, const R128 *src) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + + dst->lo = ~src->lo; + dst->hi = ~src->hi; + R128_DEBUG_SET(dst); +} + +void r128Or(R128 *dst, const R128 *a, const R128 *b) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + dst->lo = a->lo | b->lo; + dst->hi = a->hi | b->hi; + R128_DEBUG_SET(dst); +} + +void r128And(R128 *dst, const R128 *a, const R128 *b) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + dst->lo = a->lo & b->lo; + dst->hi = a->hi & b->hi; + R128_DEBUG_SET(dst); +} + +void r128Xor(R128 *dst, const R128 *a, const R128 *b) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + dst->lo = a->lo ^ b->lo; + dst->hi = a->hi ^ b->hi; + R128_DEBUG_SET(dst); +} + +void r128Shl(R128 *dst, const R128 *src, int amount) +{ + R128_U64 r[4]; + + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + +#if defined(_M_IX86) && !defined(R128_STDC_ONLY) + __asm { + // load src + mov edx, dword ptr[src] + mov ecx, amount + + mov edi, dword ptr[edx] + mov esi, dword ptr[edx + 4] + mov ebx, dword ptr[edx + 8] + mov eax, dword ptr[edx + 12] + + // shift mod 32 + shld eax, ebx, cl + shld ebx, esi, cl + shld esi, edi, cl + shl edi, cl + + // clear out low 12 bytes of stack + xor edx, edx + mov dword ptr[r], edx + mov dword ptr[r + 4], edx + mov dword ptr[r + 8], edx + + // store shifted amount offset by count/32 bits + shr ecx, 5 + and ecx, 3 + mov dword ptr[r + ecx * 4 + 0], edi + mov dword ptr[r + ecx * 4 + 4], esi + mov dword ptr[r + ecx * 4 + 8], ebx + mov dword ptr[r + ecx * 4 + 12], eax + } +#else + + r[0] = src->lo; + r[1] = src->hi; + + amount &= 127; + if (amount >= 64) { + r[1] = r[0] << (amount - 64); + r[0] = 0; + } else if (amount) { +# ifdef _M_X64 + r[1] = __shiftleft128(r[0], r[1], (char) amount); +# else + r[1] = (r[1] << amount) | (r[0] >> (64 - amount)); +# endif + r[0] = r[0] << amount; + } +#endif //_M_IX86 + + dst->lo = r[0]; + dst->hi = r[1]; + R128_DEBUG_SET(dst); +} + +void r128Shr(R128 *dst, const R128 *src, int amount) +{ + R128_U64 r[4]; + + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + +#if defined(_M_IX86) && !defined(R128_STDC_ONLY) + __asm { + // load src + mov edx, dword ptr[src] + mov ecx, amount + + mov edi, dword ptr[edx] + mov esi, dword ptr[edx + 4] + mov ebx, dword ptr[edx + 8] + mov eax, dword ptr[edx + 12] + + // shift mod 32 + shrd edi, esi, cl + shrd esi, ebx, cl + shrd ebx, eax, cl + shr eax, cl + + // clear out high 12 bytes of stack + xor edx, edx + mov dword ptr[r + 20], edx + mov dword ptr[r + 24], edx + mov dword ptr[r + 28], edx + + // store shifted amount offset by -count/32 bits + shr ecx, 5 + and ecx, 3 + neg ecx + mov dword ptr[r + ecx * 4 + 16], edi + mov dword ptr[r + ecx * 4 + 20], esi + mov dword ptr[r + ecx * 4 + 24], ebx + mov dword ptr[r + ecx * 4 + 28], eax + } +#else + r[2] = src->lo; + r[3] = src->hi; + + amount &= 127; + if (amount >= 64) { + r[2] = r[3] >> (amount - 64); + r[3] = 0; + } else if (amount) { +#ifdef _M_X64 + r[2] = __shiftright128(r[2], r[3], (char) amount); +#else + r[2] = (r[2] >> amount) | (r[3] << (64 - amount)); +#endif + r[3] = r[3] >> amount; + } +#endif + + dst->lo = r[2]; + dst->hi = r[3]; + R128_DEBUG_SET(dst); +} + +void r128Sar(R128 *dst, const R128 *src, int amount) +{ + R128_U64 r[4]; + + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + +#if defined(_M_IX86) && !defined(R128_STDC_ONLY) + __asm { + // load src + mov edx, dword ptr[src] + mov ecx, amount + + mov edi, dword ptr[edx] + mov esi, dword ptr[edx + 4] + mov ebx, dword ptr[edx + 8] + mov eax, dword ptr[edx + 12] + + // shift mod 32 + shrd edi, esi, cl + shrd esi, ebx, cl + shrd ebx, eax, cl + sar eax, cl + + // copy sign to high 12 bytes of stack + cdq + mov dword ptr[r + 20], edx + mov dword ptr[r + 24], edx + mov dword ptr[r + 28], edx + + // store shifted amount offset by -count/32 bits + shr ecx, 5 + and ecx, 3 + neg ecx + mov dword ptr[r + ecx * 4 + 16], edi + mov dword ptr[r + ecx * 4 + 20], esi + mov dword ptr[r + ecx * 4 + 24], ebx + mov dword ptr[r + ecx * 4 + 28], eax + } +#else + r[2] = src->lo; + r[3] = src->hi; + + amount &= 127; + if (amount >= 64) { + r[2] = (R128_U64)((R128_S64)r[3] >> (amount - 64)); + r[3] = (R128_U64)((R128_S64)r[3] >> 63); + } else if (amount) { + r[2] = (r[2] >> amount) | (R128_U64)((R128_S64)r[3] << (64 - amount)); + r[3] = (R128_U64)((R128_S64)r[3] >> amount); + } +#endif + + dst->lo = r[2]; + dst->hi = r[3]; + R128_DEBUG_SET(dst); +} + +void r128Add(R128 *dst, const R128 *a, const R128 *b) +{ + unsigned char carry = 0; + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + +#if R128_INTEL && !defined(R128_STDC_ONLY) +# if R128_64BIT + carry = _addcarry_u64(carry, a->lo, b->lo, &dst->lo); + carry = _addcarry_u64(carry, a->hi, b->hi, &dst->hi); +# else + R128_U32 r0, r1, r2, r3; + carry = _addcarry_u32(carry, R128_R0(a), R128_R0(b), &r0); + carry = _addcarry_u32(carry, R128_R1(a), R128_R1(b), &r1); + carry = _addcarry_u32(carry, R128_R2(a), R128_R2(b), &r2); + carry = _addcarry_u32(carry, R128_R3(a), R128_R3(b), &r3); + R128_SET4(dst, r0, r1, r2, r3); +# endif //R128_64BIT +#else + { + R128_U64 r = a->lo + b->lo; + carry = r < a->lo; + dst->lo = r; + dst->hi = a->hi + b->hi + carry; + } +#endif //R128_INTEL + + R128_DEBUG_SET(dst); +} + +void r128Sub(R128 *dst, const R128 *a, const R128 *b) +{ + unsigned char borrow = 0; + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + +#if R128_INTEL && !defined(R128_STDC_ONLY) +# if R128_64BIT + borrow = _subborrow_u64(borrow, a->lo, b->lo, &dst->lo); + borrow = _subborrow_u64(borrow, a->hi, b->hi, &dst->hi); +# else + R128_U32 r0, r1, r2, r3; + borrow = _subborrow_u32(borrow, R128_R0(a), R128_R0(b), &r0); + borrow = _subborrow_u32(borrow, R128_R1(a), R128_R1(b), &r1); + borrow = _subborrow_u32(borrow, R128_R2(a), R128_R2(b), &r2); + borrow = _subborrow_u32(borrow, R128_R3(a), R128_R3(b), &r3); + R128_SET4(dst, r0, r1, r2, r3); +# endif //R128_64BIT +#else + { + R128_U64 r = a->lo - b->lo; + borrow = r > a->lo; + dst->lo = r; + dst->hi = a->hi - b->hi - borrow; + } +#endif //R128_INTEL + + R128_DEBUG_SET(dst); +} + +void r128Mul(R128 *dst, const R128 *a, const R128 *b) +{ + int sign = 0; + R128 ta, tb, tc; + + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + R128_SET2(&ta, a->lo, a->hi); + R128_SET2(&tb, b->lo, b->hi); + + if (r128IsNeg(&ta)) { + r128__neg(&ta, &ta); + sign = !sign; + } + if (r128IsNeg(&tb)) { + r128__neg(&tb, &tb); + sign = !sign; + } + + r128__umul(&tc, &ta, &tb); + if (sign) { + r128__neg(&tc, &tc); + } + + r128Copy(dst, &tc); +} + +void r128Div(R128 *dst, const R128 *a, const R128 *b) +{ + int sign = 0; + R128 tn, td, tq; + + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + R128_SET2(&tn, a->lo, a->hi); + R128_SET2(&td, b->lo, b->hi); + + if (r128IsNeg(&tn)) { + r128__neg(&tn, &tn); + sign = !sign; + } + + if (td.lo == 0 && td.hi == 0) { + // divide by zero + if (sign) { + r128Copy(dst, &R128_min); + } else { + r128Copy(dst, &R128_max); + } + return; + } else if (r128IsNeg(&td)) { + r128__neg(&td, &td); + sign = !sign; + } + + r128__udiv(&tq, &tn, &td); + + if (sign) { + r128__neg(&tq, &tq); + } + + r128Copy(dst, &tq); +} + +void r128Mod(R128 *dst, const R128 *a, const R128 *b) +{ + int sign = 0; + R128 tn, td, tq; + + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + R128_SET2(&tn, a->lo, a->hi); + R128_SET2(&td, b->lo, b->hi); + + if (r128IsNeg(&tn)) { + r128__neg(&tn, &tn); + sign = !sign; + } + + if (td.lo == 0 && td.hi == 0) { + // divide by zero + if (sign) { + r128Copy(dst, &R128_min); + } else { + r128Copy(dst, &R128_max); + } + return; + } else if (r128IsNeg(&td)) { + r128__neg(&td, &td); + sign = !sign; + } + + tq.hi = r128__umod(&tn, &td); + tq.lo = 0; + + if (sign) { + tq.hi = ~tq.hi + 1; + } + + r128Mul(&tq, &tq, b); + r128Sub(dst, a, &tq); +} + +void r128Rsqrt(R128 *dst, const R128 *v) +{ + static const R128 threeHalves = { R128_LIT_U64(0x8000000000000000), 1 }; + R128 x, est; + int i; + + if ((R128_S64)v->hi < 0) { + r128Copy(dst, &R128_min); + return; + } + + R128_SET2(&x, v->lo, v->hi); + + // get initial estimate + if (x.hi) { + int shift = (64 + r128__clz64(x.hi)) >> 1; + est.lo = R128_LIT_U64(1) << shift; + est.hi = 0; + } else if (x.lo) { + int shift = r128__clz64(x.lo) >> 1; + est.hi = R128_LIT_U64(1) << shift; + est.lo = 0; + } else { + R128_SET2(dst, 0, 0); + return; + } + + // x /= 2 + r128Shr(&x, &x, 1); + + // Newton-Raphson iterate + for (i = 0; i < 7; ++i) { + R128 newEst; + + // newEst = est * (threeHalves - (x / 2) * est * est); + r128__umul(&newEst, &est, &est); + r128__umul(&newEst, &newEst, &x); + r128Sub(&newEst, &threeHalves, &newEst); + r128__umul(&newEst, &est, &newEst); + + if (newEst.lo == est.lo && newEst.hi == est.hi) { + break; + } + R128_SET2(&est, newEst.lo, newEst.hi); + } + + r128Copy(dst, &est); +} + +void r128Sqrt(R128 *dst, const R128 *v) +{ + R128 x, est; + int i; + + if ((R128_S64)v->hi < 0) { + r128Copy(dst, &R128_min); + return; + } + + R128_SET2(&x, v->lo, v->hi); + + // get initial estimate + if (x.hi) { + int shift = (63 - r128__clz64(x.hi)) >> 1; + r128Shr(&est, &x, shift); + } else if (x.lo) { + int shift = (1 + r128__clz64(x.lo)) >> 1; + r128Shl(&est, &x, shift); + } else { + R128_SET2(dst, 0, 0); + return; + } + + // Newton-Raphson iterate + for (i = 0; i < 7; ++i) { + R128 newEst; + + // newEst = (est + x / est) / 2 + r128__udiv(&newEst, &x, &est); + r128Add(&newEst, &newEst, &est); + r128Shr(&newEst, &newEst, 1); + + if (newEst.lo == est.lo && newEst.hi == est.hi) { + break; + } + R128_SET2(&est, newEst.lo, newEst.hi); + } + + r128Copy(dst, &est); +} + +int r128Cmp(const R128 *a, const R128 *b) +{ + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + if (a->hi == b->hi) { + if (a->lo == b->lo) { + return 0; + } else if (a->lo > b->lo) { + return 1; + } else { + return -1; + } + } else if ((R128_S64)a->hi > (R128_S64)b->hi) { + return 1; + } else { + return -1; + } +} + +int r128IsNeg(const R128 *v) +{ + R128_ASSERT(v != NULL); + + return (R128_S64)v->hi < 0; +} + +void r128Min(R128 *dst, const R128 *a, const R128 *b) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + if (r128Cmp(a, b) < 0) { + r128Copy(dst, a); + } else { + r128Copy(dst, b); + } +} + +void r128Max(R128 *dst, const R128 *a, const R128 *b) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + if (r128Cmp(a, b) > 0) { + r128Copy(dst, a); + } else { + r128Copy(dst, b); + } +} + +void r128Floor(R128 *dst, const R128 *v) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(v != NULL); + + if ((R128_S64)v->hi < 0) { + dst->hi = v->hi - (v->lo != 0); + } else { + dst->hi = v->hi; + } + dst->lo = 0; + R128_DEBUG_SET(dst); +} + +void r128Ceil(R128 *dst, const R128 *v) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(v != NULL); + + if ((R128_S64)v->hi > 0) { + dst->hi = v->hi + (v->lo != 0); + } else { + dst->hi = v->hi; + } + dst->lo = 0; + R128_DEBUG_SET(dst); +} + +#endif //R128_IMPLEMENTATION + diff --git a/thirdparty/stb_rect_pack/stb_rect_pack.h b/thirdparty/stb_rect_pack/stb_rect_pack.h new file mode 100644 index 0000000000..3336fe7395 --- /dev/null +++ b/thirdparty/stb_rect_pack/stb_rect_pack.h @@ -0,0 +1,629 @@ +// stb_rect_pack.h - v1.00 - public domain - rectangle packing +// Sean Barrett 2014 +// +// Useful for e.g. packing rectangular textures into an atlas. +// Does not do rotation. +// +// Not necessarily the awesomest packing method, but better than +// the totally naive one in stb_truetype (which is primarily what +// this is meant to replace). +// +// Has only had a few tests run, may have issues. +// +// More docs to come. +// +// No memory allocations; uses qsort() and assert() from stdlib. +// Can override those by defining STBRP_SORT and STBRP_ASSERT. +// +// This library currently uses the Skyline Bottom-Left algorithm. +// +// Please note: better rectangle packers are welcome! Please +// implement them to the same API, but with a different init +// function. +// +// Credits +// +// Library +// Sean Barrett +// Minor features +// Martins Mozeiko +// github:IntellectualKitty +// +// Bugfixes / warning fixes +// Jeremy Jaussaud +// Fabian Giesen +// +// Version history: +// +// 1.00 (2019-02-25) avoid small space waste; gracefully fail too-wide rectangles +// 0.99 (2019-02-07) warning fixes +// 0.11 (2017-03-03) return packing success/fail result +// 0.10 (2016-10-25) remove cast-away-const to avoid warnings +// 0.09 (2016-08-27) fix compiler warnings +// 0.08 (2015-09-13) really fix bug with empty rects (w=0 or h=0) +// 0.07 (2015-09-13) fix bug with empty rects (w=0 or h=0) +// 0.06 (2015-04-15) added STBRP_SORT to allow replacing qsort +// 0.05: added STBRP_ASSERT to allow replacing assert +// 0.04: fixed minor bug in STBRP_LARGE_RECTS support +// 0.01: initial release +// +// LICENSE +// +// See end of file for license information. + +////////////////////////////////////////////////////////////////////////////// +// +// INCLUDE SECTION +// + +#ifndef STB_INCLUDE_STB_RECT_PACK_H +#define STB_INCLUDE_STB_RECT_PACK_H + +#define STB_RECT_PACK_VERSION 1 + +#ifdef STBRP_STATIC +#define STBRP_DEF static +#else +#define STBRP_DEF extern +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct stbrp_context stbrp_context; +typedef struct stbrp_node stbrp_node; +typedef struct stbrp_rect stbrp_rect; + +#ifdef STBRP_LARGE_RECTS +typedef int stbrp_coord; +#else +typedef unsigned short stbrp_coord; +#endif + +STBRP_DEF int stbrp_pack_rects (stbrp_context *context, stbrp_rect *rects, int num_rects); +// Assign packed locations to rectangles. The rectangles are of type +// 'stbrp_rect' defined below, stored in the array 'rects', and there +// are 'num_rects' many of them. +// +// Rectangles which are successfully packed have the 'was_packed' flag +// set to a non-zero value and 'x' and 'y' store the minimum location +// on each axis (i.e. bottom-left in cartesian coordinates, top-left +// if you imagine y increasing downwards). Rectangles which do not fit +// have the 'was_packed' flag set to 0. +// +// You should not try to access the 'rects' array from another thread +// while this function is running, as the function temporarily reorders +// the array while it executes. +// +// To pack into another rectangle, you need to call stbrp_init_target +// again. To continue packing into the same rectangle, you can call +// this function again. Calling this multiple times with multiple rect +// arrays will probably produce worse packing results than calling it +// a single time with the full rectangle array, but the option is +// available. +// +// The function returns 1 if all of the rectangles were successfully +// packed and 0 otherwise. + +struct stbrp_rect +{ + // reserved for your use: + int id; + + // input: + stbrp_coord w, h; + + // output: + stbrp_coord x, y; + int was_packed; // non-zero if valid packing + +}; // 16 bytes, nominally + + +STBRP_DEF void stbrp_init_target (stbrp_context *context, int width, int height, stbrp_node *nodes, int num_nodes); +// Initialize a rectangle packer to: +// pack a rectangle that is 'width' by 'height' in dimensions +// using temporary storage provided by the array 'nodes', which is 'num_nodes' long +// +// You must call this function every time you start packing into a new target. +// +// There is no "shutdown" function. The 'nodes' memory must stay valid for +// the following stbrp_pack_rects() call (or calls), but can be freed after +// the call (or calls) finish. +// +// Note: to guarantee best results, either: +// 1. make sure 'num_nodes' >= 'width' +// or 2. call stbrp_allow_out_of_mem() defined below with 'allow_out_of_mem = 1' +// +// If you don't do either of the above things, widths will be quantized to multiples +// of small integers to guarantee the algorithm doesn't run out of temporary storage. +// +// If you do #2, then the non-quantized algorithm will be used, but the algorithm +// may run out of temporary storage and be unable to pack some rectangles. + +STBRP_DEF void stbrp_setup_allow_out_of_mem (stbrp_context *context, int allow_out_of_mem); +// Optionally call this function after init but before doing any packing to +// change the handling of the out-of-temp-memory scenario, described above. +// If you call init again, this will be reset to the default (false). + + +STBRP_DEF void stbrp_setup_heuristic (stbrp_context *context, int heuristic); +// Optionally select which packing heuristic the library should use. Different +// heuristics will produce better/worse results for different data sets. +// If you call init again, this will be reset to the default. + +enum +{ + STBRP_HEURISTIC_Skyline_default=0, + STBRP_HEURISTIC_Skyline_BL_sortHeight = STBRP_HEURISTIC_Skyline_default, + STBRP_HEURISTIC_Skyline_BF_sortHeight +}; + + +////////////////////////////////////////////////////////////////////////////// +// +// the details of the following structures don't matter to you, but they must +// be visible so you can handle the memory allocations for them + +struct stbrp_node +{ + stbrp_coord x,y; + stbrp_node *next; +}; + +struct stbrp_context +{ + int width; + int height; + int align; + int init_mode; + int heuristic; + int num_nodes; + stbrp_node *active_head; + stbrp_node *free_head; + stbrp_node extra[2]; // we allocate two extra nodes so optimal user-node-count is 'width' not 'width+2' +}; + +#ifdef __cplusplus +} +#endif + +#endif + +////////////////////////////////////////////////////////////////////////////// +// +// IMPLEMENTATION SECTION +// + +#ifdef STB_RECT_PACK_IMPLEMENTATION +#ifndef STBRP_SORT +#include +#define STBRP_SORT qsort +#endif + +#ifndef STBRP_ASSERT +#include +#define STBRP_ASSERT assert +#endif + +#ifdef _MSC_VER +#define STBRP__NOTUSED(v) (void)(v) +#else +#define STBRP__NOTUSED(v) (void)sizeof(v) +#endif + +enum +{ + STBRP__INIT_skyline = 1 +}; + +STBRP_DEF void stbrp_setup_heuristic(stbrp_context *context, int heuristic) +{ + switch (context->init_mode) { + case STBRP__INIT_skyline: + STBRP_ASSERT(heuristic == STBRP_HEURISTIC_Skyline_BL_sortHeight || heuristic == STBRP_HEURISTIC_Skyline_BF_sortHeight); + context->heuristic = heuristic; + break; + default: + STBRP_ASSERT(0); + } +} + +STBRP_DEF void stbrp_setup_allow_out_of_mem(stbrp_context *context, int allow_out_of_mem) +{ + if (allow_out_of_mem) + // if it's ok to run out of memory, then don't bother aligning them; + // this gives better packing, but may fail due to OOM (even though + // the rectangles easily fit). @TODO a smarter approach would be to only + // quantize once we've hit OOM, then we could get rid of this parameter. + context->align = 1; + else { + // if it's not ok to run out of memory, then quantize the widths + // so that num_nodes is always enough nodes. + // + // I.e. num_nodes * align >= width + // align >= width / num_nodes + // align = ceil(width/num_nodes) + + context->align = (context->width + context->num_nodes-1) / context->num_nodes; + } +} + +STBRP_DEF void stbrp_init_target(stbrp_context *context, int width, int height, stbrp_node *nodes, int num_nodes) +{ + int i; +#ifndef STBRP_LARGE_RECTS + STBRP_ASSERT(width <= 0xffff && height <= 0xffff); +#endif + + for (i=0; i < num_nodes-1; ++i) + nodes[i].next = &nodes[i+1]; + nodes[i].next = NULL; + context->init_mode = STBRP__INIT_skyline; + context->heuristic = STBRP_HEURISTIC_Skyline_default; + context->free_head = &nodes[0]; + context->active_head = &context->extra[0]; + context->width = width; + context->height = height; + context->num_nodes = num_nodes; + stbrp_setup_allow_out_of_mem(context, 0); + + // node 0 is the full width, node 1 is the sentinel (lets us not store width explicitly) + context->extra[0].x = 0; + context->extra[0].y = 0; + context->extra[0].next = &context->extra[1]; + context->extra[1].x = (stbrp_coord) width; +#ifdef STBRP_LARGE_RECTS + context->extra[1].y = (1<<30); +#else + context->extra[1].y = 65535; +#endif + context->extra[1].next = NULL; +} + +// find minimum y position if it starts at x1 +static int stbrp__skyline_find_min_y(stbrp_context *c, stbrp_node *first, int x0, int width, int *pwaste) +{ + stbrp_node *node = first; + int x1 = x0 + width; + int min_y, visited_width, waste_area; + + STBRP__NOTUSED(c); + + STBRP_ASSERT(first->x <= x0); + + #if 0 + // skip in case we're past the node + while (node->next->x <= x0) + ++node; + #else + STBRP_ASSERT(node->next->x > x0); // we ended up handling this in the caller for efficiency + #endif + + STBRP_ASSERT(node->x <= x0); + + min_y = 0; + waste_area = 0; + visited_width = 0; + while (node->x < x1) { + if (node->y > min_y) { + // raise min_y higher. + // we've accounted for all waste up to min_y, + // but we'll now add more waste for everything we've visted + waste_area += visited_width * (node->y - min_y); + min_y = node->y; + // the first time through, visited_width might be reduced + if (node->x < x0) + visited_width += node->next->x - x0; + else + visited_width += node->next->x - node->x; + } else { + // add waste area + int under_width = node->next->x - node->x; + if (under_width + visited_width > width) + under_width = width - visited_width; + waste_area += under_width * (min_y - node->y); + visited_width += under_width; + } + node = node->next; + } + + *pwaste = waste_area; + return min_y; +} + +typedef struct +{ + int x,y; + stbrp_node **prev_link; +} stbrp__findresult; + +static stbrp__findresult stbrp__skyline_find_best_pos(stbrp_context *c, int width, int height) +{ + int best_waste = (1<<30), best_x, best_y = (1 << 30); + stbrp__findresult fr; + stbrp_node **prev, *node, *tail, **best = NULL; + + // align to multiple of c->align + width = (width + c->align - 1); + width -= width % c->align; + STBRP_ASSERT(width % c->align == 0); + + // if it can't possibly fit, bail immediately + if (width > c->width || height > c->height) { + fr.prev_link = NULL; + fr.x = fr.y = 0; + return fr; + } + + node = c->active_head; + prev = &c->active_head; + while (node->x + width <= c->width) { + int y,waste; + y = stbrp__skyline_find_min_y(c, node, node->x, width, &waste); + if (c->heuristic == STBRP_HEURISTIC_Skyline_BL_sortHeight) { // actually just want to test BL + // bottom left + if (y < best_y) { + best_y = y; + best = prev; + } + } else { + // best-fit + if (y + height <= c->height) { + // can only use it if it first vertically + if (y < best_y || (y == best_y && waste < best_waste)) { + best_y = y; + best_waste = waste; + best = prev; + } + } + } + prev = &node->next; + node = node->next; + } + + best_x = (best == NULL) ? 0 : (*best)->x; + + // if doing best-fit (BF), we also have to try aligning right edge to each node position + // + // e.g, if fitting + // + // ____________________ + // |____________________| + // + // into + // + // | | + // | ____________| + // |____________| + // + // then right-aligned reduces waste, but bottom-left BL is always chooses left-aligned + // + // This makes BF take about 2x the time + + if (c->heuristic == STBRP_HEURISTIC_Skyline_BF_sortHeight) { + tail = c->active_head; + node = c->active_head; + prev = &c->active_head; + // find first node that's admissible + while (tail->x < width) + tail = tail->next; + while (tail) { + int xpos = tail->x - width; + int y,waste; + STBRP_ASSERT(xpos >= 0); + // find the left position that matches this + while (node->next->x <= xpos) { + prev = &node->next; + node = node->next; + } + STBRP_ASSERT(node->next->x > xpos && node->x <= xpos); + y = stbrp__skyline_find_min_y(c, node, xpos, width, &waste); + if (y + height <= c->height) { + if (y <= best_y) { + if (y < best_y || waste < best_waste || (waste==best_waste && xpos < best_x)) { + best_x = xpos; + STBRP_ASSERT(y <= best_y); + best_y = y; + best_waste = waste; + best = prev; + } + } + } + tail = tail->next; + } + } + + fr.prev_link = best; + fr.x = best_x; + fr.y = best_y; + return fr; +} + +static stbrp__findresult stbrp__skyline_pack_rectangle(stbrp_context *context, int width, int height) +{ + // find best position according to heuristic + stbrp__findresult res = stbrp__skyline_find_best_pos(context, width, height); + stbrp_node *node, *cur; + + // bail if: + // 1. it failed + // 2. the best node doesn't fit (we don't always check this) + // 3. we're out of memory + if (res.prev_link == NULL || res.y + height > context->height || context->free_head == NULL) { + res.prev_link = NULL; + return res; + } + + // on success, create new node + node = context->free_head; + node->x = (stbrp_coord) res.x; + node->y = (stbrp_coord) (res.y + height); + + context->free_head = node->next; + + // insert the new node into the right starting point, and + // let 'cur' point to the remaining nodes needing to be + // stiched back in + + cur = *res.prev_link; + if (cur->x < res.x) { + // preserve the existing one, so start testing with the next one + stbrp_node *next = cur->next; + cur->next = node; + cur = next; + } else { + *res.prev_link = node; + } + + // from here, traverse cur and free the nodes, until we get to one + // that shouldn't be freed + while (cur->next && cur->next->x <= res.x + width) { + stbrp_node *next = cur->next; + // move the current node to the free list + cur->next = context->free_head; + context->free_head = cur; + cur = next; + } + + // stitch the list back in + node->next = cur; + + if (cur->x < res.x + width) + cur->x = (stbrp_coord) (res.x + width); + +#ifdef _DEBUG + cur = context->active_head; + while (cur->x < context->width) { + STBRP_ASSERT(cur->x < cur->next->x); + cur = cur->next; + } + STBRP_ASSERT(cur->next == NULL); + + { + int count=0; + cur = context->active_head; + while (cur) { + cur = cur->next; + ++count; + } + cur = context->free_head; + while (cur) { + cur = cur->next; + ++count; + } + STBRP_ASSERT(count == context->num_nodes+2); + } +#endif + + return res; +} + +static int rect_height_compare(const void *a, const void *b) +{ + const stbrp_rect *p = (const stbrp_rect *) a; + const stbrp_rect *q = (const stbrp_rect *) b; + if (p->h > q->h) + return -1; + if (p->h < q->h) + return 1; + return (p->w > q->w) ? -1 : (p->w < q->w); +} + +static int rect_original_order(const void *a, const void *b) +{ + const stbrp_rect *p = (const stbrp_rect *) a; + const stbrp_rect *q = (const stbrp_rect *) b; + return (p->was_packed < q->was_packed) ? -1 : (p->was_packed > q->was_packed); +} + +#ifdef STBRP_LARGE_RECTS +#define STBRP__MAXVAL 0xffffffff +#else +#define STBRP__MAXVAL 0xffff +#endif + +STBRP_DEF int stbrp_pack_rects(stbrp_context *context, stbrp_rect *rects, int num_rects) +{ + int i, all_rects_packed = 1; + + // we use the 'was_packed' field internally to allow sorting/unsorting + for (i=0; i < num_rects; ++i) { + rects[i].was_packed = i; + } + + // sort according to heuristic + STBRP_SORT(rects, num_rects, sizeof(rects[0]), rect_height_compare); + + for (i=0; i < num_rects; ++i) { + if (rects[i].w == 0 || rects[i].h == 0) { + rects[i].x = rects[i].y = 0; // empty rect needs no space + } else { + stbrp__findresult fr = stbrp__skyline_pack_rectangle(context, rects[i].w, rects[i].h); + if (fr.prev_link) { + rects[i].x = (stbrp_coord) fr.x; + rects[i].y = (stbrp_coord) fr.y; + } else { + rects[i].x = rects[i].y = STBRP__MAXVAL; + } + } + } + + // unsort + STBRP_SORT(rects, num_rects, sizeof(rects[0]), rect_original_order); + + // set was_packed flags and all_rects_packed status + for (i=0; i < num_rects; ++i) { + rects[i].was_packed = !(rects[i].x == STBRP__MAXVAL && rects[i].y == STBRP__MAXVAL); + if (!rects[i].was_packed) + all_rects_packed = 0; + } + + // return the all_rects_packed status + return all_rects_packed; +} +#endif + +/* +------------------------------------------------------------------------------ +This software is available under 2 licenses -- choose whichever you prefer. +------------------------------------------------------------------------------ +ALTERNATIVE A - MIT License +Copyright (c) 2017 Sean Barrett +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +------------------------------------------------------------------------------ +ALTERNATIVE B - Public Domain (www.unlicense.org) +This is free and unencumbered software released into the public domain. +Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +software, either in source code form or as a compiled binary, for any purpose, +commercial or non-commercial, and by any means. +In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +------------------------------------------------------------------------------ +*/ + diff --git a/thirdparty/xatlas/xatlas.cpp b/thirdparty/xatlas/xatlas.cpp index 80cacf9746..b1cbeb980f 100644 --- a/thirdparty/xatlas/xatlas.cpp +++ b/thirdparty/xatlas/xatlas.cpp @@ -33,19 +33,19 @@ https://github.com/brandonpelfrey/Fast-BVH MIT License Copyright (c) 2012 Brandon Pelfrey */ -#include -#include -#include -#include #include #include // FLT_MAX #include #include +#include +#include +#include +#include #define __STDC_LIMIT_MACROS +#include "xatlas.h" #include #include #include -#include "xatlas.h" #ifndef XA_DEBUG #ifdef NDEBUG @@ -70,7 +70,10 @@ Copyright (c) 2012 Brandon Pelfrey #define XA_XSTR(x) XA_STR(x) #ifndef XA_ASSERT -#define XA_ASSERT(exp) if (!(exp)) { XA_PRINT_WARNING("\rASSERT: %s %s %d\n", XA_XSTR(exp), __FILE__, __LINE__); } +#define XA_ASSERT(exp) \ + if (!(exp)) { \ + XA_PRINT_WARNING("\rASSERT: %s %s %d\n", XA_XSTR(exp), __FILE__, __LINE__); \ + } #endif #ifndef XA_DEBUG_ASSERT @@ -78,13 +81,13 @@ Copyright (c) 2012 Brandon Pelfrey #endif #ifndef XA_PRINT -#define XA_PRINT(...) \ +#define XA_PRINT(...) \ if (xatlas::internal::s_print && xatlas::internal::s_printVerbose) \ xatlas::internal::s_print(__VA_ARGS__); #endif #ifndef XA_PRINT_WARNING -#define XA_PRINT_WARNING(...) \ +#define XA_PRINT_WARNING(...) \ if (xatlas::internal::s_print) \ xatlas::internal::s_print(__VA_ARGS__); #endif @@ -136,19 +139,13 @@ Copyright (c) 2012 Brandon Pelfrey #define XA_DEBUG_EXPORT_OBJ_INVALID_PARAMETERIZATION 0 #define XA_DEBUG_EXPORT_OBJ_RECOMPUTED_CHARTS 0 -#define XA_DEBUG_EXPORT_OBJ (0 \ - || XA_DEBUG_EXPORT_OBJ_SOURCE_MESHES \ - || XA_DEBUG_EXPORT_OBJ_CHART_GROUPS \ - || XA_DEBUG_EXPORT_OBJ_PLANAR_REGIONS \ - || XA_DEBUG_EXPORT_OBJ_CHARTS \ - || XA_DEBUG_EXPORT_OBJ_BEFORE_FIX_TJUNCTION \ - || XA_DEBUG_EXPORT_OBJ_CLOSE_HOLES_ERROR \ - || XA_DEBUG_EXPORT_OBJ_CHARTS_AFTER_PARAMETERIZATION \ - || XA_DEBUG_EXPORT_OBJ_INVALID_PARAMETERIZATION \ - || XA_DEBUG_EXPORT_OBJ_RECOMPUTED_CHARTS) +#define XA_DEBUG_EXPORT_OBJ (0 || XA_DEBUG_EXPORT_OBJ_SOURCE_MESHES || XA_DEBUG_EXPORT_OBJ_CHART_GROUPS || XA_DEBUG_EXPORT_OBJ_PLANAR_REGIONS || XA_DEBUG_EXPORT_OBJ_CHARTS || XA_DEBUG_EXPORT_OBJ_BEFORE_FIX_TJUNCTION || XA_DEBUG_EXPORT_OBJ_CLOSE_HOLES_ERROR || XA_DEBUG_EXPORT_OBJ_CHARTS_AFTER_PARAMETERIZATION || XA_DEBUG_EXPORT_OBJ_INVALID_PARAMETERIZATION || XA_DEBUG_EXPORT_OBJ_RECOMPUTED_CHARTS) #ifdef _MSC_VER -#define XA_FOPEN(_file, _filename, _mode) { if (fopen_s(&_file, _filename, _mode) != 0) _file = NULL; } +#define XA_FOPEN(_file, _filename, _mode) \ + { \ + if (fopen_s(&_file, _filename, _mode) != 0) _file = NULL; \ + } #define XA_SPRINTF(_buffer, _size, _format, ...) sprintf_s(_buffer, _size, _format, __VA_ARGS__) #else #define XA_FOPEN(_file, _filename, _mode) _file = fopen(_filename, _mode) @@ -163,10 +160,8 @@ static FreeFunc s_free = free; static PrintFunc s_print = printf; static bool s_printVerbose = false; -struct MemTag -{ - enum - { +struct MemTag { + enum { Default, BitImage, BVH, @@ -189,8 +184,7 @@ struct MemTag }; #if XA_DEBUG_HEAP -struct AllocHeader -{ +struct AllocHeader { size_t size; const char *file; int line; @@ -203,11 +197,10 @@ struct AllocHeader static std::mutex s_allocMutex; static AllocHeader *s_allocRoot = nullptr; static size_t s_allocTotalCount = 0, s_allocTotalSize = 0, s_allocPeakSize = 0, s_allocCount[MemTag::Count] = { 0 }, s_allocTotalTagSize[MemTag::Count] = { 0 }, s_allocPeakTagSize[MemTag::Count] = { 0 }; -static uint32_t s_allocId =0 ; +static uint32_t s_allocId = 0; static constexpr uint32_t kAllocRedzone = 0x12345678; -static void *Realloc(void *ptr, size_t size, int tag, const char *file, int line) -{ +static void *Realloc(void *ptr, size_t size, int tag, const char *file, int line) { std::unique_lock lock(s_allocMutex); if (!size && !ptr) return nullptr; @@ -268,8 +261,7 @@ static void *Realloc(void *ptr, size_t size, int tag, const char *file, int line return newPtr + sizeof(AllocHeader); } -static void ReportLeaks() -{ +static void ReportLeaks() { printf("Checking for memory leaks...\n"); bool anyLeaks = false; AllocHeader *header = s_allocRoot; @@ -297,8 +289,7 @@ static void ReportLeaks() s_allocTotalTagSize[i] = s_allocPeakTagSize[i] = 0; } -static void PrintMemoryUsage() -{ +static void PrintMemoryUsage() { XA_PRINT("Total allocations: %zu\n", s_allocTotalCount); XA_PRINT("Memory usage: %0.2fMB current, %0.2fMB peak\n", internal::s_allocTotalSize / 1024.0f / 1024.0f, internal::s_allocPeakSize / 1024.0f / 1024.0f); static const char *labels[] = { // Sync with MemTag @@ -327,8 +318,7 @@ static void PrintMemoryUsage() #define XA_PRINT_MEM_USAGE internal::PrintMemoryUsage(); #else -static void *Realloc(void *ptr, size_t size, int /*tag*/, const char * /*file*/, int /*line*/) -{ +static void *Realloc(void *ptr, size_t size, int /*tag*/, const char * /*file*/, int /*line*/) { if (size == 0 && !ptr) return nullptr; if (size == 0 && s_free) { @@ -347,10 +337,11 @@ static void *Realloc(void *ptr, size_t size, int /*tag*/, const char * /*file*/, #if XA_PROFILE #define XA_PROFILE_START(var) const clock_t var##Start = clock(); #define XA_PROFILE_END(var) internal::s_profile.var += clock() - var##Start; -#define XA_PROFILE_PRINT_AND_RESET(label, var) XA_PRINT("%s%.2f seconds (%g ms)\n", label, internal::clockToSeconds(internal::s_profile.var), internal::clockToMs(internal::s_profile.var)); internal::s_profile.var = 0; +#define XA_PROFILE_PRINT_AND_RESET(label, var) \ + XA_PRINT("%s%.2f seconds (%g ms)\n", label, internal::clockToSeconds(internal::s_profile.var), internal::clockToMs(internal::s_profile.var)); \ + internal::s_profile.var = 0; -struct ProfileData -{ +struct ProfileData { clock_t addMeshReal; clock_t addMeshCopyData; std::atomic addMeshThread; @@ -390,13 +381,11 @@ struct ProfileData static ProfileData s_profile; -static double clockToMs(clock_t c) -{ +static double clockToMs(clock_t c) { return c * 1000.0 / CLOCKS_PER_SEC; } -static double clockToSeconds(clock_t c) -{ +static double clockToSeconds(clock_t c) { return c / (double)CLOCKS_PER_SEC; } #else @@ -412,89 +401,75 @@ static constexpr float kEpsilon = 0.0001f; static constexpr float kAreaEpsilon = FLT_EPSILON; static constexpr float kNormalEpsilon = 0.001f; -static int align(int x, int a) -{ +static int align(int x, int a) { return (x + a - 1) & ~(a - 1); } template -static T max(const T &a, const T &b) -{ +static T max(const T &a, const T &b) { return a > b ? a : b; } template -static T min(const T &a, const T &b) -{ +static T min(const T &a, const T &b) { return a < b ? a : b; } template -static T max3(const T &a, const T &b, const T &c) -{ +static T max3(const T &a, const T &b, const T &c) { return max(a, max(b, c)); } /// Return the maximum of the three arguments. template -static T min3(const T &a, const T &b, const T &c) -{ +static T min3(const T &a, const T &b, const T &c) { return min(a, min(b, c)); } /// Clamp between two values. template -static T clamp(const T &x, const T &a, const T &b) -{ +static T clamp(const T &x, const T &a, const T &b) { return min(max(x, a), b); } template -static void swap(T &a, T &b) -{ +static void swap(T &a, T &b) { T temp = a; a = b; b = temp; } -union FloatUint32 -{ +union FloatUint32 { float f; uint32_t u; }; -static bool isFinite(float f) -{ +static bool isFinite(float f) { FloatUint32 fu; fu.f = f; return fu.u != 0x7F800000u && fu.u != 0x7F800001u; } -static bool isNan(float f) -{ +static bool isNan(float f) { return f != f; } // Robust floating point comparisons: // http://realtimecollisiondetection.net/blog/?p=89 -static bool equal(const float f0, const float f1, const float epsilon) -{ +static bool equal(const float f0, const float f1, const float epsilon) { //return fabs(f0-f1) <= epsilon; return fabs(f0 - f1) <= epsilon * max3(1.0f, fabsf(f0), fabsf(f1)); } -static int ftoi_ceil(float val) -{ +static int ftoi_ceil(float val) { return (int)ceilf(val); } -static bool isZero(const float f, const float epsilon) -{ +static bool isZero(const float f, const float epsilon) { return fabs(f) <= epsilon; } -static float square(float f) -{ +static float square(float f) { return f * f; } @@ -504,9 +479,8 @@ static float square(float f) * @note isPowerOfTwo(x) == true -> nextPowerOfTwo(x) == x * @note nextPowerOfTwo(x) = 2 << log2(x-1) */ -static uint32_t nextPowerOfTwo(uint32_t x) -{ - XA_DEBUG_ASSERT( x != 0 ); +static uint32_t nextPowerOfTwo(uint32_t x) { + XA_DEBUG_ASSERT(x != 0); // On modern CPUs this is supposed to be as fast as using the bsr instruction. x--; x |= x >> 1; @@ -517,65 +491,59 @@ static uint32_t nextPowerOfTwo(uint32_t x) return x + 1; } -static uint32_t sdbmHash(const void *data_in, uint32_t size, uint32_t h = 5381) -{ - const uint8_t *data = (const uint8_t *) data_in; +static uint32_t sdbmHash(const void *data_in, uint32_t size, uint32_t h = 5381) { + const uint8_t *data = (const uint8_t *)data_in; uint32_t i = 0; while (i < size) { - h = (h << 16) + (h << 6) - h + (uint32_t ) data[i++]; + h = (h << 16) + (h << 6) - h + (uint32_t)data[i++]; } return h; } template -static uint32_t hash(const T &t, uint32_t h = 5381) -{ +static uint32_t hash(const T &t, uint32_t h = 5381) { return sdbmHash(&t, sizeof(T), h); } // Functors for hash table: -template struct Hash -{ +template +struct Hash { uint32_t operator()(const Key &k) const { return hash(k); } }; -template struct Equal -{ +template +struct Equal { bool operator()(const Key &k0, const Key &k1) const { return k0 == k1; } }; -class Vector2 -{ +class Vector2 { public: Vector2() {} - explicit Vector2(float f) : x(f), y(f) {} - Vector2(float x, float y): x(x), y(y) {} + explicit Vector2(float f) : + x(f), y(f) {} + Vector2(float x, float y) : + x(x), y(y) {} - Vector2 operator-() const - { + Vector2 operator-() const { return Vector2(-x, -y); } - void operator+=(const Vector2 &v) - { + void operator+=(const Vector2 &v) { x += v.x; y += v.y; } - void operator-=(const Vector2 &v) - { + void operator-=(const Vector2 &v) { x -= v.x; y -= v.y; } - void operator*=(float s) - { + void operator*=(float s) { x *= s; y *= s; } - void operator*=(const Vector2 &v) - { + void operator*=(const Vector2 &v) { x *= v.x; y *= v.y; } @@ -583,13 +551,11 @@ public: float x, y; }; -static bool operator==(const Vector2 &a, const Vector2 &b) -{ +static bool operator==(const Vector2 &a, const Vector2 &b) { return a.x == b.x && a.y == b.y; } -static bool operator!=(const Vector2 &a, const Vector2 &b) -{ +static bool operator!=(const Vector2 &a, const Vector2 &b) { return a.x != b.x || a.y != b.y; } @@ -598,40 +564,33 @@ static bool operator!=(const Vector2 &a, const Vector2 &b) return Vector2(a.x + b.x, a.y + b.y); }*/ -static Vector2 operator-(const Vector2 &a, const Vector2 &b) -{ +static Vector2 operator-(const Vector2 &a, const Vector2 &b) { return Vector2(a.x - b.x, a.y - b.y); } -static Vector2 operator*(const Vector2 &v, float s) -{ +static Vector2 operator*(const Vector2 &v, float s) { return Vector2(v.x * s, v.y * s); } -static float dot(const Vector2 &a, const Vector2 &b) -{ +static float dot(const Vector2 &a, const Vector2 &b) { return a.x * b.x + a.y * b.y; } -static float lengthSquared(const Vector2 &v) -{ +static float lengthSquared(const Vector2 &v) { return v.x * v.x + v.y * v.y; } -static float length(const Vector2 &v) -{ +static float length(const Vector2 &v) { return sqrtf(lengthSquared(v)); } #if XA_DEBUG -static bool isNormalized(const Vector2 &v, float epsilon = kNormalEpsilon) -{ +static bool isNormalized(const Vector2 &v, float epsilon = kNormalEpsilon) { return equal(length(v), 1, epsilon); } #endif -static Vector2 normalize(const Vector2 &v, float epsilon) -{ +static Vector2 normalize(const Vector2 &v, float epsilon) { float l = length(v); XA_DEBUG_ASSERT(!isZero(l, epsilon)); XA_UNUSED(epsilon); @@ -640,36 +599,30 @@ static Vector2 normalize(const Vector2 &v, float epsilon) return n; } -static Vector2 normalizeSafe(const Vector2 &v, const Vector2 &fallback, float epsilon) -{ +static Vector2 normalizeSafe(const Vector2 &v, const Vector2 &fallback, float epsilon) { float l = length(v); if (isZero(l, epsilon)) return fallback; return v * (1.0f / l); } -static bool equal(const Vector2 &v1, const Vector2 &v2, float epsilon) -{ +static bool equal(const Vector2 &v1, const Vector2 &v2, float epsilon) { return equal(v1.x, v2.x, epsilon) && equal(v1.y, v2.y, epsilon); } -static Vector2 min(const Vector2 &a, const Vector2 &b) -{ +static Vector2 min(const Vector2 &a, const Vector2 &b) { return Vector2(min(a.x, b.x), min(a.y, b.y)); } -static Vector2 max(const Vector2 &a, const Vector2 &b) -{ +static Vector2 max(const Vector2 &a, const Vector2 &b) { return Vector2(max(a.x, b.x), max(a.y, b.y)); } -static bool isFinite(const Vector2 &v) -{ +static bool isFinite(const Vector2 &v) { return isFinite(v.x) && isFinite(v.y); } -static float triangleArea(const Vector2 &a, const Vector2 &b, const Vector2 &c) -{ +static float triangleArea(const Vector2 &a, const Vector2 &b, const Vector2 &c) { // IC: While it may be appealing to use the following expression: //return (c.x * a.y + a.x * b.y + b.x * c.y - b.x * a.y - c.x * b.y - a.x * c.y) * 0.5f; // That's actually a terrible idea. Small triangles far from the origin can end up producing fairly large floating point @@ -683,8 +636,7 @@ static float triangleArea(const Vector2 &a, const Vector2 &b, const Vector2 &c) return (v0.x * v1.y - v0.y * v1.x) * 0.5f; } -static bool linesIntersect(const Vector2 &a1, const Vector2 &a2, const Vector2 &b1, const Vector2 &b2, float epsilon) -{ +static bool linesIntersect(const Vector2 &a1, const Vector2 &a2, const Vector2 &b1, const Vector2 &b2, float epsilon) { const Vector2 v0 = a2 - a1; const Vector2 v1 = b2 - b1; const float denom = -v1.x * v0.y + v0.x * v1.y; @@ -692,76 +644,70 @@ static bool linesIntersect(const Vector2 &a1, const Vector2 &a2, const Vector2 & return false; const float s = (-v0.y * (a1.x - b1.x) + v0.x * (a1.y - b1.y)) / denom; if (s > epsilon && s < 1.0f - epsilon) { - const float t = ( v1.x * (a1.y - b1.y) - v1.y * (a1.x - b1.x)) / denom; + const float t = (v1.x * (a1.y - b1.y) - v1.y * (a1.x - b1.x)) / denom; return t > epsilon && t < 1.0f - epsilon; } return false; } -struct Vector2i -{ +struct Vector2i { Vector2i() {} - Vector2i(int32_t x, int32_t y) : x(x), y(y) {} + Vector2i(int32_t x, int32_t y) : + x(x), y(y) {} int32_t x, y; }; -class Vector3 -{ +class Vector3 { public: Vector3() {} - explicit Vector3(float f) : x(f), y(f), z(f) {} - Vector3(float x, float y, float z) : x(x), y(y), z(z) {} - Vector3(const Vector2 &v, float z) : x(v.x), y(v.y), z(z) {} + explicit Vector3(float f) : + x(f), y(f), z(f) {} + Vector3(float x, float y, float z) : + x(x), y(y), z(z) {} + Vector3(const Vector2 &v, float z) : + x(v.x), y(v.y), z(z) {} - Vector2 xy() const - { + Vector2 xy() const { return Vector2(x, y); } - Vector3 operator-() const - { + Vector3 operator-() const { return Vector3(-x, -y, -z); } - void operator+=(const Vector3 &v) - { + void operator+=(const Vector3 &v) { x += v.x; y += v.y; z += v.z; } - void operator-=(const Vector3 &v) - { + void operator-=(const Vector3 &v) { x -= v.x; y -= v.y; z -= v.z; } - void operator*=(float s) - { + void operator*=(float s) { x *= s; y *= s; z *= s; } - void operator/=(float s) - { + void operator/=(float s) { float is = 1.0f / s; x *= is; y *= is; z *= is; } - void operator*=(const Vector3 &v) - { + void operator*=(const Vector3 &v) { x *= v.x; y *= v.y; z *= v.z; } - void operator/=(const Vector3 &v) - { + void operator/=(const Vector3 &v) { x /= v.x; y /= v.y; z /= v.z; @@ -770,53 +716,43 @@ public: float x, y, z; }; -static Vector3 operator+(const Vector3 &a, const Vector3 &b) -{ +static Vector3 operator+(const Vector3 &a, const Vector3 &b) { return Vector3(a.x + b.x, a.y + b.y, a.z + b.z); } -static Vector3 operator-(const Vector3 &a, const Vector3 &b) -{ +static Vector3 operator-(const Vector3 &a, const Vector3 &b) { return Vector3(a.x - b.x, a.y - b.y, a.z - b.z); } -static Vector3 cross(const Vector3 &a, const Vector3 &b) -{ +static Vector3 cross(const Vector3 &a, const Vector3 &b) { return Vector3(a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x); } -static Vector3 operator*(const Vector3 &v, float s) -{ +static Vector3 operator*(const Vector3 &v, float s) { return Vector3(v.x * s, v.y * s, v.z * s); } -static Vector3 operator/(const Vector3 &v, float s) -{ +static Vector3 operator/(const Vector3 &v, float s) { return v * (1.0f / s); } -static float dot(const Vector3 &a, const Vector3 &b) -{ +static float dot(const Vector3 &a, const Vector3 &b) { return a.x * b.x + a.y * b.y + a.z * b.z; } -static float lengthSquared(const Vector3 &v) -{ +static float lengthSquared(const Vector3 &v) { return v.x * v.x + v.y * v.y + v.z * v.z; } -static float length(const Vector3 &v) -{ +static float length(const Vector3 &v) { return sqrtf(lengthSquared(v)); } -static bool isNormalized(const Vector3 &v, float epsilon = kNormalEpsilon) -{ +static bool isNormalized(const Vector3 &v, float epsilon = kNormalEpsilon) { return equal(length(v), 1, epsilon); } -static Vector3 normalize(const Vector3 &v, float epsilon) -{ +static Vector3 normalize(const Vector3 &v, float epsilon) { float l = length(v); XA_DEBUG_ASSERT(!isZero(l, epsilon)); XA_UNUSED(epsilon); @@ -825,8 +761,7 @@ static Vector3 normalize(const Vector3 &v, float epsilon) return n; } -static Vector3 normalizeSafe(const Vector3 &v, const Vector3 &fallback, float epsilon) -{ +static Vector3 normalizeSafe(const Vector3 &v, const Vector3 &fallback, float epsilon) { float l = length(v); if (isZero(l, epsilon)) { return fallback; @@ -834,72 +769,59 @@ static Vector3 normalizeSafe(const Vector3 &v, const Vector3 &fallback, float ep return v * (1.0f / l); } -static bool equal(const Vector3 &v0, const Vector3 &v1, float epsilon) -{ +static bool equal(const Vector3 &v0, const Vector3 &v1, float epsilon) { return fabs(v0.x - v1.x) <= epsilon && fabs(v0.y - v1.y) <= epsilon && fabs(v0.z - v1.z) <= epsilon; } -static Vector3 min(const Vector3 &a, const Vector3 &b) -{ +static Vector3 min(const Vector3 &a, const Vector3 &b) { return Vector3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z)); } -static Vector3 max(const Vector3 &a, const Vector3 &b) -{ +static Vector3 max(const Vector3 &a, const Vector3 &b) { return Vector3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z)); } #if XA_DEBUG -bool isFinite(const Vector3 &v) -{ +bool isFinite(const Vector3 &v) { return isFinite(v.x) && isFinite(v.y) && isFinite(v.z); } #endif -struct Extents2 -{ +struct Extents2 { Vector2 min, max; - void reset() - { + void reset() { min.x = min.y = FLT_MAX; max.x = max.y = -FLT_MAX; } - void add(Vector2 p) - { + void add(Vector2 p) { min = xatlas::internal::min(min, p); max = xatlas::internal::max(max, p); } - Vector2 midpoint() const - { + Vector2 midpoint() const { return Vector2(min.x + (max.x - min.x) * 0.5f, min.y + (max.y - min.y) * 0.5f); } - static bool intersect(Extents2 e1, Extents2 e2) - { + static bool intersect(Extents2 e1, Extents2 e2) { return e1.min.x <= e2.max.x && e1.max.x >= e2.min.x && e1.min.y <= e2.max.y && e1.max.y >= e2.min.y; } }; -struct Plane -{ +struct Plane { Plane() = default; - - Plane(const Vector3 &p1, const Vector3 &p2, const Vector3 &p3) - { + + Plane(const Vector3 &p1, const Vector3 &p2, const Vector3 &p3) { normal = cross(p2 - p1, p3 - p1); dist = dot(normal, p1); } - float distance(const Vector3 &p) const - { + float distance(const Vector3 &p) const { return dot(normal, p) - dist; } - void normalize() - { + void normalize() { const float len = length(normal); if (len > 0.0f) { const float il = 1.0f / len; @@ -912,8 +834,7 @@ struct Plane float dist; }; -static bool lineIntersectsPoint(const Vector3 &point, const Vector3 &lineStart, const Vector3 &lineEnd, float *t, float epsilon) -{ +static bool lineIntersectsPoint(const Vector3 &point, const Vector3 &lineStart, const Vector3 &lineEnd, float *t, float epsilon) { float tt; if (!t) t = &tt; @@ -930,22 +851,19 @@ static bool lineIntersectsPoint(const Vector3 &point, const Vector3 &lineStart, return *t > kEpsilon && *t < 1.0f - kEpsilon; } -static bool sameSide(const Vector3 &p1, const Vector3 &p2, const Vector3 &a, const Vector3 &b) -{ +static bool sameSide(const Vector3 &p1, const Vector3 &p2, const Vector3 &a, const Vector3 &b) { const Vector3 &ab = b - a; return dot(cross(ab, p1 - a), cross(ab, p2 - a)) >= 0.0f; } // http://blackpawn.com/texts/pointinpoly/default.html -static bool pointInTriangle(const Vector3 &p, const Vector3 &a, const Vector3 &b, const Vector3 &c) -{ +static bool pointInTriangle(const Vector3 &p, const Vector3 &a, const Vector3 &b, const Vector3 &c) { return sameSide(p, a, b, c) && sameSide(p, b, a, c) && sameSide(p, c, a, b); } #if XA_CLOSE_HOLES_CHECK_EDGE_INTERSECTION // https://en.wikipedia.org/wiki/M%C3%B6ller%E2%80%93Trumbore_intersection_algorithm -static bool rayIntersectsTriangle(const Vector3 &rayOrigin, const Vector3 &rayDir, const Vector3 *tri, float *t) -{ +static bool rayIntersectsTriangle(const Vector3 &rayOrigin, const Vector3 &rayDir, const Vector3 *tri, float *t) { *t = 0.0f; const Vector3 &edge1 = tri[1] - tri[0]; const Vector3 &edge2 = tri[2] - tri[0]; @@ -972,50 +890,47 @@ static bool rayIntersectsTriangle(const Vector3 &rayOrigin, const Vector3 &rayDi #endif // From Fast-BVH -struct AABB -{ - AABB() : min(FLT_MAX, FLT_MAX, FLT_MAX), max(-FLT_MAX, -FLT_MAX, -FLT_MAX) {} - AABB(const Vector3 &min, const Vector3 &max) : min(min), max(max) { } - AABB(const Vector3 &p, float radius = 0.0f) : min(p), max(p) { if (radius > 0.0f) expand(radius); } +struct AABB { + AABB() : + min(FLT_MAX, FLT_MAX, FLT_MAX), max(-FLT_MAX, -FLT_MAX, -FLT_MAX) {} + AABB(const Vector3 &min, const Vector3 &max) : + min(min), max(max) {} + AABB(const Vector3 &p, float radius = 0.0f) : + min(p), max(p) { + if (radius > 0.0f) expand(radius); + } - bool intersect(const AABB &other) const - { + bool intersect(const AABB &other) const { return min.x <= other.max.x && max.x >= other.min.x && min.y <= other.max.y && max.y >= other.min.y && min.z <= other.max.z && max.z >= other.min.z; } - void expandToInclude(const Vector3 &p) - { + void expandToInclude(const Vector3 &p) { min = internal::min(min, p); max = internal::max(max, p); } - void expandToInclude(const AABB &aabb) - { + void expandToInclude(const AABB &aabb) { min = internal::min(min, aabb.min); max = internal::max(max, aabb.max); } - void expand(float amount) - { + void expand(float amount) { min -= Vector3(amount); max += Vector3(amount); } - Vector3 centroid() const - { + Vector3 centroid() const { return min + (max - min) * 0.5f; } - uint32_t maxDimension() const - { + uint32_t maxDimension() const { const Vector3 extent = max - min; uint32_t result = 0; if (extent.y > extent.x) { result = 1; if (extent.z > extent.y) result = 2; - } - else if(extent.z > extent.x) + } else if (extent.z > extent.x) result = 2; return result; } @@ -1023,10 +938,9 @@ struct AABB Vector3 min, max; }; -struct ArrayBase -{ - ArrayBase(uint32_t elementSize, int memTag = MemTag::Default) : buffer(nullptr), elementSize(elementSize), size(0), capacity(0) - { +struct ArrayBase { + ArrayBase(uint32_t elementSize, int memTag = MemTag::Default) : + buffer(nullptr), elementSize(elementSize), size(0), capacity(0) { #if XA_DEBUG_HEAP this->memTag = memTag; #else @@ -1034,31 +948,26 @@ struct ArrayBase #endif } - ~ArrayBase() - { + ~ArrayBase() { XA_FREE(buffer); } - XA_INLINE void clear() - { + XA_INLINE void clear() { size = 0; } - void copyFrom(const uint8_t *data, uint32_t length) - { + void copyFrom(const uint8_t *data, uint32_t length) { resize(length, true); memcpy(buffer, data, length * elementSize); } - void copyTo(ArrayBase &other) const - { + void copyTo(ArrayBase &other) const { XA_DEBUG_ASSERT(elementSize == other.elementSize); other.resize(size, true); memcpy(other.buffer, buffer, size * elementSize); } - void destroy() - { + void destroy() { size = 0; XA_FREE(buffer); buffer = nullptr; @@ -1067,8 +976,7 @@ struct ArrayBase } // Insert the given element at the given index shifting all the elements up. - void insertAt(uint32_t index, const uint8_t *value) - { + void insertAt(uint32_t index, const uint8_t *value) { XA_DEBUG_ASSERT(index >= 0 && index <= size); resize(size + 1, false); if (index < size - 1) @@ -1076,8 +984,7 @@ struct ArrayBase memcpy(&buffer[index * elementSize], value, elementSize); } - void moveTo(ArrayBase &other) - { + void moveTo(ArrayBase &other) { XA_DEBUG_ASSERT(elementSize == other.elementSize); other.destroy(); other.buffer = buffer; @@ -1091,21 +998,18 @@ struct ArrayBase elementSize = size = capacity = 0; } - void pop_back() - { + void pop_back() { XA_DEBUG_ASSERT(size > 0); resize(size - 1, false); } - void push_back(const uint8_t *value) - { + void push_back(const uint8_t *value) { XA_DEBUG_ASSERT(value < buffer || value >= buffer + size); resize(size + 1, false); memcpy(&buffer[(size - 1) * elementSize], value, elementSize); } - void push_back(const ArrayBase &other) - { + void push_back(const ArrayBase &other) { XA_DEBUG_ASSERT(elementSize == other.elementSize); if (other.size == 0) return; @@ -1115,22 +1019,19 @@ struct ArrayBase } // Remove the element at the given index. This is an expensive operation! - void removeAt(uint32_t index) - { + void removeAt(uint32_t index) { XA_DEBUG_ASSERT(index >= 0 && index < size); if (size != 1) memmove(buffer + elementSize * index, buffer + elementSize * (index + 1), elementSize * (size - 1 - index)); size--; } - void reserve(uint32_t desiredSize) - { + void reserve(uint32_t desiredSize) { if (desiredSize > capacity) setArrayCapacity(desiredSize); } - void resize(uint32_t newSize, bool exact) - { + void resize(uint32_t newSize, bool exact) { size = newSize; if (size > capacity) { // First allocation is always exact. Otherwise, following allocations grow array to 150% of desired size. @@ -1143,8 +1044,7 @@ struct ArrayBase } } - void setArrayCapacity(uint32_t newCapacity) - { + void setArrayCapacity(uint32_t newCapacity) { XA_DEBUG_ASSERT(newCapacity >= size); if (newCapacity == 0) { // free the buffer. @@ -1164,8 +1064,7 @@ struct ArrayBase } #if XA_DEBUG_HEAP - void setMemTag(int memTag) - { + void setMemTag(int memTag) { this->memTag = memTag; } #endif @@ -1179,28 +1078,25 @@ struct ArrayBase #endif }; -template -class Array -{ +template +class Array { public: - Array(int memTag = MemTag::Default) : m_base(sizeof(T), memTag) {} - Array(const Array&) = delete; + Array(int memTag = MemTag::Default) : + m_base(sizeof(T), memTag) {} + Array(const Array &) = delete; Array &operator=(const Array &) = delete; - XA_INLINE const T &operator[](uint32_t index) const - { + XA_INLINE const T &operator[](uint32_t index) const { XA_DEBUG_ASSERT(index < m_base.size); return ((const T *)m_base.buffer)[index]; } - XA_INLINE T &operator[](uint32_t index) - { + XA_INLINE T &operator[](uint32_t index) { XA_DEBUG_ASSERT(index < m_base.size); return ((T *)m_base.buffer)[index]; } - XA_INLINE const T &back() const - { + XA_INLINE const T &back() const { XA_DEBUG_ASSERT(!isEmpty()); return ((const T *)m_base.buffer)[m_base.size - 1]; } @@ -1208,8 +1104,7 @@ public: XA_INLINE T *begin() { return (T *)m_base.buffer; } XA_INLINE void clear() { m_base.clear(); } - bool contains(const T &value) const - { + bool contains(const T &value) const { for (uint32_t i = 0; i < m_base.size; i++) { if (((const T *)m_base.buffer)[i] == value) return true; @@ -1232,20 +1127,17 @@ public: void reserve(uint32_t desiredSize) { m_base.reserve(desiredSize); } void resize(uint32_t newSize) { m_base.resize(newSize, true); } - void runCtors() - { + void runCtors() { for (uint32_t i = 0; i < m_base.size; i++) new (&((T *)m_base.buffer)[i]) T; } - void runDtors() - { + void runDtors() { for (uint32_t i = 0; i < m_base.size; i++) ((T *)m_base.buffer)[i].~T(); } - void setAll(const T &value) - { + void setAll(const T &value) { auto buffer = (T *)m_base.buffer; for (uint32_t i = 0; i < m_base.size; i++) buffer[i] = value; @@ -1262,33 +1154,47 @@ private: ArrayBase m_base; }; -template -struct ArrayView -{ - ArrayView(Array &a) : data(a.data()), length(a.size()) {} - ArrayView(T *data, uint32_t length) : data(data), length(length) {} - ArrayView &operator=(Array &a) { data = a.data(); length = a.size(); return *this; } - XA_INLINE const T &operator[](uint32_t index) const { XA_DEBUG_ASSERT(index < length); return data[index]; } +template +struct ArrayView { + ArrayView(Array &a) : + data(a.data()), length(a.size()) {} + ArrayView(T *data, uint32_t length) : + data(data), length(length) {} + ArrayView &operator=(Array &a) { + data = a.data(); + length = a.size(); + return *this; + } + XA_INLINE const T &operator[](uint32_t index) const { + XA_DEBUG_ASSERT(index < length); + return data[index]; + } T *data; uint32_t length; }; -template -struct ConstArrayView -{ - ConstArrayView(const Array &a) : data(a.data()), length(a.size()) {} - ConstArrayView(const T *data, uint32_t length) : data(data), length(length) {} - ConstArrayView &operator=(const Array &a) { data = a.data(); length = a.size(); return *this; } - XA_INLINE const T &operator[](uint32_t index) const { XA_DEBUG_ASSERT(index < length); return data[index]; } +template +struct ConstArrayView { + ConstArrayView(const Array &a) : + data(a.data()), length(a.size()) {} + ConstArrayView(const T *data, uint32_t length) : + data(data), length(length) {} + ConstArrayView &operator=(const Array &a) { + data = a.data(); + length = a.size(); + return *this; + } + XA_INLINE const T &operator[](uint32_t index) const { + XA_DEBUG_ASSERT(index < length); + return data[index]; + } const T *data; uint32_t length; }; /// Basis class to compute tangent space basis, ortogonalizations and to transform vectors from one space to another. -struct Basis -{ - XA_NODISCARD static Vector3 computeTangent(const Vector3 &normal) - { +struct Basis { + XA_NODISCARD static Vector3 computeTangent(const Vector3 &normal) { XA_ASSERT(isNormalized(normal)); // Choose minimum axis. Vector3 tangent; @@ -1304,8 +1210,7 @@ struct Basis return tangent; } - XA_NODISCARD static Vector3 computeBitangent(const Vector3 &normal, const Vector3 &tangent) - { + XA_NODISCARD static Vector3 computeBitangent(const Vector3 &normal, const Vector3 &tangent) { return cross(normal, tangent); } @@ -1315,36 +1220,31 @@ struct Basis }; // Simple bit array. -class BitArray -{ +class BitArray { public: - BitArray() : m_size(0) {} + BitArray() : + m_size(0) {} - BitArray(uint32_t sz) - { + BitArray(uint32_t sz) { resize(sz); } - void resize(uint32_t new_size) - { + void resize(uint32_t new_size) { m_size = new_size; m_wordArray.resize((m_size + 31) >> 5); } - bool get(uint32_t index) const - { + bool get(uint32_t index) const { XA_DEBUG_ASSERT(index < m_size); return (m_wordArray[index >> 5] & (1 << (index & 31))) != 0; } - void set(uint32_t index) - { + void set(uint32_t index) { XA_DEBUG_ASSERT(index < m_size); - m_wordArray[index >> 5] |= (1 << (index & 31)); + m_wordArray[index >> 5] |= (1 << (index & 31)); } - void zeroOutMemory() - { + void zeroOutMemory() { m_wordArray.zeroOutMemory(); } @@ -1353,13 +1253,13 @@ private: Array m_wordArray; }; -class BitImage -{ +class BitImage { public: - BitImage() : m_width(0), m_height(0), m_rowStride(0), m_data(MemTag::BitImage) {} + BitImage() : + m_width(0), m_height(0), m_rowStride(0), m_data(MemTag::BitImage) {} - BitImage(uint32_t w, uint32_t h) : m_width(w), m_height(h), m_data(MemTag::BitImage) - { + BitImage(uint32_t w, uint32_t h) : + m_width(w), m_height(h), m_data(MemTag::BitImage) { m_rowStride = (m_width + 63) >> 6; m_data.resize(m_rowStride * m_height); m_data.zeroOutMemory(); @@ -1370,16 +1270,14 @@ public: uint32_t width() const { return m_width; } uint32_t height() const { return m_height; } - void copyTo(BitImage &other) - { + void copyTo(BitImage &other) { other.m_width = m_width; other.m_height = m_height; other.m_rowStride = m_rowStride; m_data.copyTo(other.m_data); } - void resize(uint32_t w, uint32_t h, bool discard) - { + void resize(uint32_t w, uint32_t h, bool discard) { const uint32_t rowStride = (w + 63) >> 6; if (discard) { m_data.resize(rowStride * h); @@ -1403,28 +1301,24 @@ public: m_rowStride = rowStride; } - bool get(uint32_t x, uint32_t y) const - { + bool get(uint32_t x, uint32_t y) const { XA_DEBUG_ASSERT(x < m_width && y < m_height); const uint32_t index = (x >> 6) + y * m_rowStride; return (m_data[index] & (UINT64_C(1) << (uint64_t(x) & UINT64_C(63)))) != 0; } - void set(uint32_t x, uint32_t y) - { + void set(uint32_t x, uint32_t y) { XA_DEBUG_ASSERT(x < m_width && y < m_height); const uint32_t index = (x >> 6) + y * m_rowStride; m_data[index] |= UINT64_C(1) << (uint64_t(x) & UINT64_C(63)); XA_DEBUG_ASSERT(get(x, y)); } - void zeroOutMemory() - { + void zeroOutMemory() { m_data.zeroOutMemory(); } - bool canBlit(const BitImage &image, uint32_t offsetX, uint32_t offsetY) const - { + bool canBlit(const BitImage &image, uint32_t offsetX, uint32_t offsetY) const { for (uint32_t y = 0; y < image.m_height; y++) { const uint32_t thisY = y + offsetY; if (thisY >= m_height) @@ -1448,8 +1342,7 @@ public: return true; } - void dilate(uint32_t padding) - { + void dilate(uint32_t padding) { BitImage tmp(m_width, m_height); for (uint32_t p = 0; p < padding; p++) { tmp.zeroOutMemory(); @@ -1486,11 +1379,10 @@ private: }; // From Fast-BVH -class BVH -{ +class BVH { public: - BVH(const Array &objectAabbs, uint32_t leafSize = 4) : m_objectIds(MemTag::BVH), m_nodes(MemTag::BVH) - { + BVH(const Array &objectAabbs, uint32_t leafSize = 4) : + m_objectIds(MemTag::BVH), m_nodes(MemTag::BVH) { m_objectAabbs = &objectAabbs; if (m_objectAabbs->isEmpty()) return; @@ -1510,7 +1402,7 @@ public: Node node; m_nodes.reserve(objectAabbs.size() * 2); uint32_t nNodes = 0; - while(stackptr > 0) { + while (stackptr > 0) { // Pop the next item off of the stack const BuildEntry &bnode = todo[--stackptr]; const uint32_t start = bnode.start; @@ -1523,7 +1415,7 @@ public: // Calculate the bounding box for this node AABB bb(objectAabbs[m_objectIds[start]]); AABB bc(objectAabbs[m_objectIds[start]].centroid()); - for(uint32_t p = start + 1; p < end; ++p) { + for (uint32_t p = start + 1; p < end; ++p) { bb.expandToInclude(objectAabbs[m_objectIds[p]]); bc.expandToInclude(objectAabbs[m_objectIds[p]].centroid()); } @@ -1539,7 +1431,7 @@ public: m_nodes[bnode.parent].rightOffset--; // When this is the second touch, this is the right child. // The right child sets up the offset for the flat tree. - if (m_nodes[bnode.parent].rightOffset == kTouchedTwice ) + if (m_nodes[bnode.parent].rightOffset == kTouchedTwice) m_nodes[bnode.parent].rightOffset = nNodes - 1 - bnode.parent; } // If this is a leaf, no need to subdivide. @@ -1574,21 +1466,20 @@ public: } } - void query(const AABB &queryAabb, Array &result) const - { + void query(const AABB &queryAabb, Array &result) const { result.clear(); // Working set uint32_t todo[64]; int32_t stackptr = 0; // "Push" on the root node to the working set todo[stackptr] = 0; - while(stackptr >= 0) { + while (stackptr >= 0) { // Pop off the next node to work on. const int ni = todo[stackptr--]; const Node &node = m_nodes[ni]; // Is leaf -> Intersect if (node.rightOffset == 0) { - for(uint32_t o = 0; o < node.nPrims; ++o) { + for (uint32_t o = 0; o < node.nPrims; ++o) { const uint32_t obj = node.start + o; if (queryAabb.intersect((*m_objectAabbs)[m_objectIds[obj]])) result.push_back(m_objectIds[obj]); @@ -1605,14 +1496,12 @@ public: } private: - struct BuildEntry - { + struct BuildEntry { uint32_t parent; // If non-zero then this is the index of the parent. (used in offsets) uint32_t start, end; // The range of objects in the object list covered by this node. }; - struct Node - { + struct Node { AABB aabb; uint32_t start, nPrims, rightOffset; }; @@ -1622,10 +1511,8 @@ private: Array m_nodes; }; -struct Fit -{ - static bool computeBasis(const Vector3 *points, uint32_t pointsCount, Basis *basis) - { +struct Fit { + static bool computeBasis(const Vector3 *points, uint32_t pointsCount, Basis *basis) { if (computeLeastSquaresNormal(points, pointsCount, &basis->normal)) { basis->tangent = Basis::computeTangent(basis->normal); basis->bitangent = Basis::computeBitangent(basis->normal, basis->tangent); @@ -1639,8 +1526,7 @@ private: // Fast, and accurate to within a few degrees. // Returns None if the points do not span a plane. // https://www.ilikebigbits.com/2015_03_04_plane_from_points.html - static bool computeLeastSquaresNormal(const Vector3 *points, uint32_t pointsCount, Vector3 *normal) - { + static bool computeLeastSquaresNormal(const Vector3 *points, uint32_t pointsCount, Vector3 *normal) { XA_DEBUG_ASSERT(pointsCount >= 3); if (pointsCount == 3) { *normal = normalize(cross(points[2] - points[0], points[1] - points[0]), kEpsilon); @@ -1705,7 +1591,7 @@ private: // Pick path with best conditioning: Vector3 dir(0.0f); if (det_max == det_x) - dir = Vector3(det_x,xz * yz - xy * zz,xy * yz - xz * yy); + dir = Vector3(det_x, xz * yz - xy * zz, xy * yz - xz * yy); else if (det_max == det_y) dir = Vector3(xz * yz - xy * zz, det_y, xy * xz - yz * xx); else if (det_max == det_z) @@ -1718,8 +1604,7 @@ private: return isNormalized(*normal); } - static bool computeEigen(const Vector3 *points, uint32_t pointsCount, Basis *basis) - { + static bool computeEigen(const Vector3 *points, uint32_t pointsCount, Basis *basis) { float matrix[6]; computeCovariance(pointsCount, points, matrix); if (matrix[0] == 0 && matrix[3] == 0 && matrix[5] == 0) @@ -1734,8 +1619,7 @@ private: return true; } - static Vector3 computeCentroid(int n, const Vector3 * points) - { + static Vector3 computeCentroid(int n, const Vector3 *points) { Vector3 centroid(0.0f); for (int i = 0; i < n; i++) { centroid += points[i]; @@ -1744,8 +1628,7 @@ private: return centroid; } - static Vector3 computeCovariance(int n, const Vector3 * points, float * covariance) - { + static Vector3 computeCovariance(int n, const Vector3 *points, float *covariance) { // compute the centroid Vector3 centroid = computeCentroid(n, points); // compute covariance matrix @@ -1767,8 +1650,7 @@ private: // Tridiagonal solver from Charles Bloom. // Householder transforms followed by QL decomposition. // Seems to be based on the code from Numerical Recipes in C. - static bool eigenSolveSymmetric3(const float matrix[6], float eigenValues[3], Vector3 eigenVectors[3]) - { + static bool eigenSolveSymmetric3(const float matrix[6], float eigenValues[3], Vector3 eigenVectors[3]) { XA_DEBUG_ASSERT(matrix != nullptr && eigenValues != nullptr && eigenVectors != nullptr); float subd[3]; float diag[3]; @@ -1793,7 +1675,7 @@ private: // eigenvectors are the columns; make them the rows : for (int i = 0; i < 3; i++) { for (int j = 0; j < 3; j++) { - (&eigenVectors[j].x)[i] = (float) work[i][j]; + (&eigenVectors[j].x)[i] = (float)work[i][j]; } } // shuffle to sort by singular value : @@ -1815,8 +1697,7 @@ private: } private: - static void EigenSolver3_Tridiagonal(float mat[3][3], float *diag, float *subd) - { + static void EigenSolver3_Tridiagonal(float mat[3][3], float *diag, float *subd) { // Householder reduction T = Q^t M Q // Input: // mat, symmetric 3x3 matrix M @@ -1868,8 +1749,7 @@ private: } } - static bool EigenSolver3_QLAlgorithm(float mat[3][3], float *diag, float *subd) - { + static bool EigenSolver3_QLAlgorithm(float mat[3][3], float *diag, float *subd) { // QL iteration with implicit shifting to reduce matrix from tridiagonal // to diagonal const int maxiter = 32; @@ -1879,21 +1759,21 @@ private: int m; for (m = ell; m <= 1; m++) { float dd = fabsf(diag[m]) + fabsf(diag[m + 1]); - if ( fabsf(subd[m]) + dd == dd ) + if (fabsf(subd[m]) + dd == dd) break; } - if ( m == ell ) + if (m == ell) break; float g = (diag[ell + 1] - diag[ell]) / (2 * subd[ell]); float r = sqrtf(g * g + 1); - if ( g < 0 ) + if (g < 0) g = diag[m] - diag[ell] + subd[ell] / (g - r); else g = diag[m] - diag[ell] + subd[ell] / (g + r); float s = 1, c = 1, p = 0; for (int i = m - 1; i >= ell; i--) { float f = s * subd[i], b = c * subd[i]; - if ( fabsf(f) >= fabsf(g) ) { + if (fabsf(f) >= fabsf(g)) { c = g / f; r = sqrtf(c * c + 1); subd[i + 1] = f * r; @@ -1919,7 +1799,7 @@ private: subd[ell] = g; subd[m] = 0; } - if ( iter == maxiter ) + if (iter == maxiter) // should not get here under normal circumstances return false; } @@ -1928,18 +1808,18 @@ private: }; /// Fixed size vector class. -class FullVector -{ +class FullVector { public: - FullVector(uint32_t dim) : m_array(MemTag::FullVector) { m_array.resize(dim); } - FullVector(const FullVector &v) : m_array(MemTag::FullVector) { v.m_array.copyTo(m_array); } + FullVector(uint32_t dim) : + m_array(MemTag::FullVector) { m_array.resize(dim); } + FullVector(const FullVector &v) : + m_array(MemTag::FullVector) { v.m_array.copyTo(m_array); } FullVector &operator=(const FullVector &v) = delete; XA_INLINE uint32_t dimension() const { return m_array.size(); } XA_INLINE const float &operator[](uint32_t index) const { return m_array[index]; } XA_INLINE float &operator[](uint32_t index) { return m_array[index]; } - void fill(float f) - { + void fill(float f) { const uint32_t dim = dimension(); for (uint32_t i = 0; i < dim; i++) m_array[i] = f; @@ -1949,22 +1829,19 @@ private: Array m_array; }; -template, typename E = Equal > -class HashMap -{ +template , typename E = Equal> +class HashMap { public: - HashMap(int memTag, uint32_t size) : m_memTag(memTag), m_size(size), m_numSlots(0), m_slots(nullptr), m_keys(memTag), m_next(memTag) - { + HashMap(int memTag, uint32_t size) : + m_memTag(memTag), m_size(size), m_numSlots(0), m_slots(nullptr), m_keys(memTag), m_next(memTag) { } - ~HashMap() - { + ~HashMap() { if (m_slots) XA_FREE(m_slots); } - void add(const Key &key) - { + void add(const Key &key) { if (!m_slots) alloc(); const uint32_t hash = computeHash(key); @@ -1973,8 +1850,7 @@ public: m_slots[hash] = m_next.size() - 1; } - uint32_t get(const Key &key) const - { + uint32_t get(const Key &key) const { if (!m_slots) return UINT32_MAX; const uint32_t hash = computeHash(key); @@ -1988,8 +1864,7 @@ public: return UINT32_MAX; } - uint32_t getNext(uint32_t current) const - { + uint32_t getNext(uint32_t current) const { uint32_t i = m_next[current]; E equal; while (i != UINT32_MAX) { @@ -2001,8 +1876,7 @@ public: } private: - void alloc() - { + void alloc() { XA_DEBUG_ASSERT(m_size > 0); m_numSlots = nextPowerOfTwo(m_size); auto minNumSlots = uint32_t(m_size * 1.3); @@ -2015,8 +1889,7 @@ private: m_next.reserve(m_size); } - uint32_t computeHash(const Key &key) const - { + uint32_t computeHash(const Key &key) const { H hash; return hash(key) & (m_numSlots - 1); } @@ -2029,9 +1902,8 @@ private: Array m_next; }; -template -static void insertionSort(T *data, uint32_t length) -{ +template +static void insertionSort(T *data, uint32_t length) { for (int32_t i = 1; i < (int32_t)length; i++) { T x = data[i]; int32_t j = i - 1; @@ -2043,21 +1915,18 @@ static void insertionSort(T *data, uint32_t length) } } -class KISSRng -{ +class KISSRng { public: KISSRng() { reset(); } - void reset() - { + void reset() { x = 123456789; y = 362436000; z = 521288629; c = 7654321; } - uint32_t getRange(uint32_t range) - { + uint32_t getRange(uint32_t range) { if (range == 0) return 0; x = 69069 * x + 12345; @@ -2076,20 +1945,18 @@ private: // Based on Pierre Terdiman's and Michael Herf's source code. // http://www.codercorner.com/RadixSortRevisited.htm // http://www.stereopsis.com/radix.html -class RadixSort -{ +class RadixSort { public: - RadixSort() : m_size(0), m_ranks(nullptr), m_ranks2(nullptr), m_validRanks(false) {} + RadixSort() : + m_size(0), m_ranks(nullptr), m_ranks2(nullptr), m_validRanks(false) {} - ~RadixSort() - { + ~RadixSort() { // Release everything XA_FREE(m_ranks2); XA_FREE(m_ranks); } - RadixSort &sort(const float *input, uint32_t count) - { + RadixSort &sort(const float *input, uint32_t count) { if (input == nullptr || count == 0) return *this; // Resize lists if needed if (count != m_size) { @@ -2115,20 +1982,17 @@ public: return *this; } - RadixSort &sort(const Array &input) - { + RadixSort &sort(const Array &input) { return sort(input.data(), input.size()); } // Access to results. m_ranks is a list of indices in sorted order, i.e. in the order you may further process your data - const uint32_t *ranks() const - { + const uint32_t *ranks() const { XA_DEBUG_ASSERT(m_validRanks); return m_ranks; } - uint32_t *ranks() - { + uint32_t *ranks() { XA_DEBUG_ASSERT(m_validRanks); return m_ranks; } @@ -2139,21 +2003,18 @@ private: uint32_t *m_ranks2; bool m_validRanks; - void FloatFlip(uint32_t &f) - { + void FloatFlip(uint32_t &f) { int32_t mask = (int32_t(f) >> 31) | 0x80000000; // Warren Hunt, Manchor Ko. f ^= mask; } - void IFloatFlip(uint32_t &f) - { + void IFloatFlip(uint32_t &f) { uint32_t mask = ((f >> 31) - 1) | 0x80000000; // Michael Herf. f ^= mask; } - template - void createHistograms(const T *buffer, uint32_t count, uint32_t *histogram) - { + template + void createHistograms(const T *buffer, uint32_t count, uint32_t *histogram) { const uint32_t bucketCount = sizeof(T); // (8 * sizeof(T)) / log2(radix) // Init bucket pointers. uint32_t *h[bucketCount]; @@ -2161,10 +2022,10 @@ private: h[i] = histogram + 256 * i; } // Clear histograms. - memset(histogram, 0, 256 * bucketCount * sizeof(uint32_t )); + memset(histogram, 0, 256 * bucketCount * sizeof(uint32_t)); // @@ Add support for signed integers. // Build histograms. - const uint8_t *p = (const uint8_t *)buffer; // @@ Does this break aliasing rules? + const uint8_t *p = (const uint8_t *)buffer; // @@ Does this break aliasing rules? const uint8_t *pe = p + count * sizeof(T); while (p != pe) { h[0][*p++]++, h[1][*p++]++, h[2][*p++]++, h[3][*p++]++; @@ -2179,8 +2040,8 @@ private: } } - template void insertionSort(const T *input, uint32_t count) - { + template + void insertionSort(const T *input, uint32_t count) { if (!m_validRanks) { m_ranks[0] = 0; for (uint32_t i = 1; i != count; ++i) { @@ -2210,8 +2071,8 @@ private: } } - template void radixSort(const T *input, uint32_t count) - { + template + void radixSort(const T *input, uint32_t count) { const uint32_t P = sizeof(T); // pass count // Allocate histograms & offsets on the stack uint32_t histogram[256 * P]; @@ -2229,7 +2090,8 @@ private: } // Create offsets link[0] = m_ranks2; - for (uint32_t i = 1; i < 256; i++) link[i] = link[i - 1] + h[i - 1]; + for (uint32_t i = 1; i < 256; i++) + link[i] = link[i - 1] + h[i - 1]; // Perform Radix Sort if (!m_validRanks) { for (uint32_t i = 0; i < count; i++) { @@ -2256,25 +2118,21 @@ private: }; // Wrapping this in a class allows temporary arrays to be re-used. -class BoundingBox2D -{ +class BoundingBox2D { public: Vector2 majorAxis, minorAxis, minCorner, maxCorner; - void clear() - { + void clear() { m_boundaryVertices.clear(); } - void appendBoundaryVertex(Vector2 v) - { + void appendBoundaryVertex(Vector2 v) { m_boundaryVertices.push_back(v); } // This should compute convex hull and use rotating calipers to find the best box. Currently it uses a brute force method. // If vertices is null or vertexCount is 0, the boundary vertices are used. - void compute(const Vector2 *vertices = nullptr, uint32_t vertexCount = 0) - { + void compute(const Vector2 *vertices = nullptr, uint32_t vertexCount = 0) { if (!vertices || vertexCount == 0) { vertices = m_boundaryVertices.data(); vertexCount = m_boundaryVertices.size(); @@ -2322,8 +2180,7 @@ public: private: // Compute the convex hull using Graham Scan. - void convexHull(const Vector2 *input, uint32_t inputCount, Array &output, float epsilon) - { + void convexHull(const Vector2 *input, uint32_t inputCount, Array &output, float epsilon) { m_coords.resize(inputCount); for (uint32_t i = 0; i < inputCount; i++) m_coords[i] = input[i].x; @@ -2353,7 +2210,7 @@ private: XA_DEBUG_ASSERT(m_top.size() >= 2); output.push_back(m_top[0]); output.push_back(m_top[1]); - for (uint32_t i = 2; i < m_top.size(); ) { + for (uint32_t i = 2; i < m_top.size();) { Vector2 a = output[output.size() - 2]; Vector2 b = output[output.size() - 1]; Vector2 c = m_top[i]; @@ -2369,7 +2226,7 @@ private: XA_DEBUG_ASSERT(m_bottom.size() >= 2); output.push_back(m_bottom[1]); // Filter bottom list. - for (uint32_t i = 2; i < m_bottom.size(); ) { + for (uint32_t i = 2; i < m_bottom.size();) { Vector2 a = output[output.size() - 2]; Vector2 b = output[output.size() - 1]; Vector2 c = m_bottom[i]; @@ -2391,33 +2248,33 @@ private: Array m_top, m_bottom, m_hull; }; -static uint32_t meshEdgeFace(uint32_t edge) { return edge / 3; } -static uint32_t meshEdgeIndex0(uint32_t edge) { return edge; } +static uint32_t meshEdgeFace(uint32_t edge) { + return edge / 3; +} +static uint32_t meshEdgeIndex0(uint32_t edge) { + return edge; +} -static uint32_t meshEdgeIndex1(uint32_t edge) -{ +static uint32_t meshEdgeIndex1(uint32_t edge) { const uint32_t faceFirstEdge = edge / 3 * 3; return faceFirstEdge + (edge - faceFirstEdge + 1) % 3; } -struct MeshFlags -{ - enum - { - HasFaceGroups = 1<<0, - HasIgnoredFaces = 1<<1, - HasNormals = 1<<2 +struct MeshFlags { + enum { + HasFaceGroups = 1 << 0, + HasIgnoredFaces = 1 << 1, + HasNormals = 1 << 2 }; }; class Mesh; static void meshGetBoundaryLoops(const Mesh &mesh, Array &boundaryLoops); -class Mesh -{ +class Mesh { public: - Mesh(float epsilon, uint32_t approxVertexCount, uint32_t approxFaceCount, uint32_t flags = 0, uint32_t id = UINT32_MAX) : m_epsilon(epsilon), m_flags(flags), m_id(id), m_faceIgnore(MemTag::Mesh), m_ignoredFaceCount(0), m_indices(MemTag::MeshIndices), m_positions(MemTag::MeshPositions), m_normals(MemTag::MeshNormals), m_texcoords(MemTag::MeshTexcoords), m_faceGroups(MemTag::Mesh), m_faceGroupFirstFace(MemTag::Mesh), m_faceGroupNextFace(MemTag::Mesh), m_faceGroupFaceCounts(MemTag::Mesh), m_colocalVertexCount(0), m_nextColocalVertex(MemTag::MeshColocals), m_boundaryEdges(MemTag::MeshBoundaries), m_oppositeEdges(MemTag::MeshBoundaries), m_nextBoundaryEdges(MemTag::MeshBoundaries), m_edgeMap(MemTag::MeshEdgeMap, approxFaceCount * 3) - { + Mesh(float epsilon, uint32_t approxVertexCount, uint32_t approxFaceCount, uint32_t flags = 0, uint32_t id = UINT32_MAX) : + m_epsilon(epsilon), m_flags(flags), m_id(id), m_faceIgnore(MemTag::Mesh), m_ignoredFaceCount(0), m_indices(MemTag::MeshIndices), m_positions(MemTag::MeshPositions), m_normals(MemTag::MeshNormals), m_texcoords(MemTag::MeshTexcoords), m_faceGroups(MemTag::Mesh), m_faceGroupFirstFace(MemTag::Mesh), m_faceGroupNextFace(MemTag::Mesh), m_faceGroupFaceCounts(MemTag::Mesh), m_colocalVertexCount(0), m_nextColocalVertex(MemTag::MeshColocals), m_boundaryEdges(MemTag::MeshBoundaries), m_oppositeEdges(MemTag::MeshBoundaries), m_nextBoundaryEdges(MemTag::MeshBoundaries), m_edgeMap(MemTag::MeshEdgeMap, approxFaceCount * 3) { m_indices.reserve(approxFaceCount * 3); m_positions.reserve(approxVertexCount); m_texcoords.reserve(approxVertexCount); @@ -2433,8 +2290,7 @@ public: uint32_t flags() const { return m_flags; } uint32_t id() const { return m_id; } - void addVertex(const Vector3 &pos, const Vector3 &normal = Vector3(0.0f), const Vector2 &texcoord = Vector2(0.0f)) - { + void addVertex(const Vector3 &pos, const Vector3 &normal = Vector3(0.0f), const Vector2 &texcoord = Vector2(0.0f)) { XA_DEBUG_ASSERT(isFinite(pos)); m_positions.push_back(pos); if (m_flags & MeshFlags::HasNormals) @@ -2442,17 +2298,14 @@ public: m_texcoords.push_back(texcoord); } - struct AddFaceResult - { - enum Enum - { + struct AddFaceResult { + enum Enum { OK, DuplicateEdge = 1 }; }; - AddFaceResult::Enum addFace(uint32_t v0, uint32_t v1, uint32_t v2, bool ignore = false, bool hashEdge = true) - { + AddFaceResult::Enum addFace(uint32_t v0, uint32_t v1, uint32_t v2, bool ignore = false, bool hashEdge = true) { uint32_t indexArray[3]; indexArray[0] = v0; indexArray[1] = v1; @@ -2460,8 +2313,7 @@ public: return addFace(indexArray, ignore, hashEdge); } - AddFaceResult::Enum addFace(const uint32_t *indices, bool ignore = false, bool hashEdge = true) - { + AddFaceResult::Enum addFace(const uint32_t *indices, bool ignore = false, bool hashEdge = true) { AddFaceResult::Enum result = AddFaceResult::OK; if (m_flags & MeshFlags::HasFaceGroups) m_faceGroups.push_back(kInvalidFaceGroup); @@ -2486,8 +2338,7 @@ public: return result; } - void createColocals() - { + void createColocals() { const uint32_t vertexCount = m_positions.size(); Array aabbs(MemTag::BVH); aabbs.resize(vertexCount); @@ -2515,7 +2366,7 @@ public: if (colocals.size() == 1) { // No colocals for this vertex. m_nextColocalVertex[i] = i; - continue; + continue; } m_colocalVertexCount += colocals.size(); // Link in ascending order. @@ -2527,8 +2378,7 @@ public: } // Check if the face duplicates any edges of any face already in the group. - bool faceDuplicatesGroupEdge(uint16_t group, uint32_t face) const - { + bool faceDuplicatesGroupEdge(uint16_t group, uint32_t face) const { for (FaceEdgeIterator edgeIt(this, face); !edgeIt.isDone(); edgeIt.advance()) { for (ColocalEdgeIterator colocalEdgeIt(this, edgeIt.vertex0(), edgeIt.vertex1()); !colocalEdgeIt.isDone(); colocalEdgeIt.advance()) { if (m_faceGroups[meshEdgeFace(colocalEdgeIt.edge())] == group) @@ -2538,8 +2388,7 @@ public: return false; } - void createFaceGroups() - { + void createFaceGroups() { uint32_t firstUnassignedFace = 0; uint16_t group = 0; Array growFaces; @@ -2619,8 +2468,7 @@ public: } } - void createBoundaries() - { + void createBoundaries() { const uint32_t edgeCount = m_indices.size(); const uint32_t vertexCount = m_positions.size(); m_oppositeEdges.resize(edgeCount); @@ -2650,8 +2498,7 @@ public: } } - void linkBoundaries() - { + void linkBoundaries() { const uint32_t edgeCount = m_indices.size(); HashMap vertexToEdgeMap(MemTag::Mesh, edgeCount); // Edge is index / 2 for (uint32_t i = 0; i < edgeCount; i++) { @@ -2744,8 +2591,7 @@ public: } /// Find edge, test all colocals. - uint32_t findEdge(uint32_t vertex0, uint32_t vertex1) const - { + uint32_t findEdge(uint32_t vertex0, uint32_t vertex1) const { uint32_t result = UINT32_MAX; if (m_nextColocalVertex.isEmpty()) { EdgeKey key(vertex0, vertex1); @@ -2784,8 +2630,7 @@ public: } #if XA_DEBUG_EXPORT_OBJ - void writeObjVertices(FILE *file) const - { + void writeObjVertices(FILE *file) const { for (uint32_t i = 0; i < m_positions.size(); i++) fprintf(file, "v %g %g %g\n", m_positions[i].x, m_positions[i].y, m_positions[i].z); if (m_flags & MeshFlags::HasNormals) { @@ -2796,8 +2641,7 @@ public: fprintf(file, "vt %g %g\n", m_texcoords[i].x, m_texcoords[i].y); } - void writeObjFace(FILE *file, uint32_t face) const - { + void writeObjFace(FILE *file, uint32_t face) const { fprintf(file, "f "); for (uint32_t j = 0; j < 3; j++) { const uint32_t index = m_indices[face * 3 + j] + 1; // 1-indexed @@ -2805,8 +2649,7 @@ public: } } - void writeObjBoundaryEges(FILE *file) const - { + void writeObjBoundaryEges(FILE *file) const { if (m_oppositeEdges.isEmpty()) return; // Boundaries haven't been created. fprintf(file, "o boundary_edges\n"); @@ -2817,8 +2660,7 @@ public: } } - void writeObjLinkedBoundaries(FILE *file) const - { + void writeObjLinkedBoundaries(FILE *file) const { if (m_oppositeEdges.isEmpty() || m_nextBoundaryEdges.isEmpty()) return; // Boundaries haven't been created and/or linked. Array boundaryLoops; @@ -2840,8 +2682,7 @@ public: } } - void writeObjFile(const char *filename) const - { + void writeObjFile(const char *filename) const { FILE *file; XA_FOPEN(file, filename, "w"); if (!file) @@ -2857,8 +2698,7 @@ public: } #endif - float computeSurfaceArea() const - { + float computeSurfaceArea() const { float area = 0; for (uint32_t f = 0; f < faceCount(); f++) area += computeFaceArea(f); @@ -2866,24 +2706,21 @@ public: return area; } - float computeParametricArea() const - { + float computeParametricArea() const { float area = 0; for (uint32_t f = 0; f < faceCount(); f++) area += computeFaceParametricArea(f); return fabsf(area); // May be negative, depends on texcoord winding. } - float computeFaceArea(uint32_t face) const - { + float computeFaceArea(uint32_t face) const { const Vector3 &p0 = m_positions[m_indices[face * 3 + 0]]; const Vector3 &p1 = m_positions[m_indices[face * 3 + 1]]; const Vector3 &p2 = m_positions[m_indices[face * 3 + 2]]; return length(cross(p1 - p0, p2 - p0)) * 0.5f; } - Vector3 computeFaceCentroid(uint32_t face) const - { + Vector3 computeFaceCentroid(uint32_t face) const { Vector3 sum(0.0f); for (uint32_t i = 0; i < 3; i++) sum += m_positions[m_indices[face * 3 + i]]; @@ -2892,8 +2729,7 @@ public: // Average of the edge midpoints weighted by the edge length. // I want a point inside the triangle, but closer to the cirumcenter. - Vector3 computeFaceCenter(uint32_t face) const - { + Vector3 computeFaceCenter(uint32_t face) const { const Vector3 &p0 = m_positions[m_indices[face * 3 + 0]]; const Vector3 &p1 = m_positions[m_indices[face * 3 + 1]]; const Vector3 &p2 = m_positions[m_indices[face * 3 + 2]]; @@ -2906,8 +2742,7 @@ public: return m0 + m1 + m2; } - Vector3 computeFaceNormal(uint32_t face) const - { + Vector3 computeFaceNormal(uint32_t face) const { const Vector3 &p0 = m_positions[m_indices[face * 3 + 0]]; const Vector3 &p1 = m_positions[m_indices[face * 3 + 1]]; const Vector3 &p2 = m_positions[m_indices[face * 3 + 2]]; @@ -2917,17 +2752,15 @@ public: return normalizeSafe(normalAreaScaled, Vector3(0, 0, 1), 0.0f); } - float computeFaceParametricArea(uint32_t face) const - { + float computeFaceParametricArea(uint32_t face) const { const Vector2 &t0 = m_texcoords[m_indices[face * 3 + 0]]; const Vector2 &t1 = m_texcoords[m_indices[face * 3 + 1]]; const Vector2 &t2 = m_texcoords[m_indices[face * 3 + 2]]; return triangleArea(t0, t1, t2); } - + // @@ This is not exactly accurate, we should compare the texture coordinates... - bool isSeam(uint32_t edge) const - { + bool isSeam(uint32_t edge) const { const uint32_t oppositeEdge = m_oppositeEdges[edge]; if (oppositeEdge == UINT32_MAX) return false; // boundary edge @@ -2938,8 +2771,7 @@ public: return m_indices[e0] != m_indices[oe1] || m_indices[e1] != m_indices[oe0]; } - bool isTextureSeam(uint32_t edge) const - { + bool isTextureSeam(uint32_t edge) const { const uint32_t oppositeEdge = m_oppositeEdges[edge]; if (oppositeEdge == UINT32_MAX) return false; // boundary edge @@ -2950,8 +2782,7 @@ public: return m_texcoords[m_indices[e0]] != m_texcoords[m_indices[oe1]] || m_texcoords[m_indices[e1]] != m_texcoords[m_indices[oe0]]; } - uint32_t firstColocal(uint32_t vertex) const - { + uint32_t firstColocal(uint32_t vertex) const { for (ColocalVertexIterator it(this, vertex); !it.isDone(); it.advance()) { if (it.vertex() < vertex) vertex = it.vertex(); @@ -2959,8 +2790,7 @@ public: return vertex; } - bool areColocal(uint32_t vertex0, uint32_t vertex1) const - { + bool areColocal(uint32_t vertex0, uint32_t vertex1) const { if (vertex0 == vertex1) return true; if (m_nextColocalVertex.isEmpty()) @@ -2982,17 +2812,32 @@ public: XA_INLINE uint32_t vertexCount() const { return m_positions.size(); } XA_INLINE uint32_t vertexAt(uint32_t i) const { return m_indices[i]; } XA_INLINE const Vector3 &position(uint32_t vertex) const { return m_positions[vertex]; } - XA_INLINE const Vector3 &normal(uint32_t vertex) const { XA_DEBUG_ASSERT(m_flags & MeshFlags::HasNormals); return m_normals[vertex]; } + XA_INLINE const Vector3 &normal(uint32_t vertex) const { + XA_DEBUG_ASSERT(m_flags & MeshFlags::HasNormals); + return m_normals[vertex]; + } XA_INLINE const Vector2 &texcoord(uint32_t vertex) const { return m_texcoords[vertex]; } XA_INLINE Vector2 &texcoord(uint32_t vertex) { return m_texcoords[vertex]; } XA_INLINE const Vector2 *texcoords() const { return m_texcoords.data(); } XA_INLINE Vector2 *texcoords() { return m_texcoords.data(); } XA_INLINE uint32_t ignoredFaceCount() const { return m_ignoredFaceCount; } XA_INLINE uint32_t faceCount() const { return m_indices.size() / 3; } - XA_INLINE uint16_t faceGroupAt(uint32_t face) const { XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); return m_faceGroups[face]; } - XA_INLINE uint32_t faceGroupCount() const { XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); return m_faceGroupFaceCounts.size(); } - XA_INLINE uint32_t faceGroupNextFace(uint32_t face) const { XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); return m_faceGroupNextFace[face]; } - XA_INLINE uint32_t faceGroupFaceCount(uint32_t group) const { XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); return m_faceGroupFaceCounts[group]; } + XA_INLINE uint16_t faceGroupAt(uint32_t face) const { + XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); + return m_faceGroups[face]; + } + XA_INLINE uint32_t faceGroupCount() const { + XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); + return m_faceGroupFaceCounts.size(); + } + XA_INLINE uint32_t faceGroupNextFace(uint32_t face) const { + XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); + return m_faceGroupNextFace[face]; + } + XA_INLINE uint32_t faceGroupFaceCount(uint32_t group) const { + XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); + return m_faceGroupFaceCounts[group]; + } XA_INLINE const uint32_t *indices() const { return m_indices.data(); } XA_INLINE uint32_t indexCount() const { return m_indices.size(); } @@ -3027,49 +2872,45 @@ private: // Populated by linkBoundaries Array m_nextBoundaryEdges; // The index of the next boundary edge. UINT32_MAX if the edge is not a boundary edge. - struct EdgeKey - { + struct EdgeKey { EdgeKey() {} - EdgeKey(const EdgeKey &k) : v0(k.v0), v1(k.v1) {} - EdgeKey(uint32_t v0, uint32_t v1) : v0(v0), v1(v1) {} + EdgeKey(const EdgeKey &k) : + v0(k.v0), v1(k.v1) {} + EdgeKey(uint32_t v0, uint32_t v1) : + v0(v0), v1(v1) {} bool operator==(const EdgeKey &k) const { return v0 == k.v0 && v1 == k.v1; } uint32_t v0; uint32_t v1; }; - struct EdgeHash - { + struct EdgeHash { uint32_t operator()(const EdgeKey &k) const { return k.v0 * 32768u + k.v1; } }; HashMap m_edgeMap; public: - class BoundaryLoopEdgeIterator - { + class BoundaryLoopEdgeIterator { public: - BoundaryLoopEdgeIterator(const Mesh *mesh, uint32_t edge) : m_mesh(mesh), m_first(UINT32_MAX), m_current(edge) {} + BoundaryLoopEdgeIterator(const Mesh *mesh, uint32_t edge) : + m_mesh(mesh), m_first(UINT32_MAX), m_current(edge) {} - void advance() - { + void advance() { if (m_first == UINT32_MAX) m_first = m_current; m_current = m_mesh->m_nextBoundaryEdges[m_current]; } - bool isDone() const - { + bool isDone() const { return m_first == m_current || m_current == UINT32_MAX; } - uint32_t edge() const - { + uint32_t edge() const { return m_current; } - uint32_t nextEdge() const - { + uint32_t nextEdge() const { return m_mesh->m_nextBoundaryEdges[m_current]; } @@ -3079,31 +2920,27 @@ public: uint32_t m_current; }; - class ColocalVertexIterator - { + class ColocalVertexIterator { public: - ColocalVertexIterator(const Mesh *mesh, uint32_t v) : m_mesh(mesh), m_first(UINT32_MAX), m_current(v) {} + ColocalVertexIterator(const Mesh *mesh, uint32_t v) : + m_mesh(mesh), m_first(UINT32_MAX), m_current(v) {} - void advance() - { + void advance() { if (m_first == UINT32_MAX) m_first = m_current; if (!m_mesh->m_nextColocalVertex.isEmpty()) m_current = m_mesh->m_nextColocalVertex[m_current]; } - bool isDone() const - { + bool isDone() const { return m_first == m_current; } - uint32_t vertex() const - { + uint32_t vertex() const { return m_current; } - const Vector3 *pos() const - { + const Vector3 *pos() const { return &m_mesh->m_positions[m_current]; } @@ -3113,39 +2950,33 @@ public: uint32_t m_current; }; - class ColocalEdgeIterator - { + class ColocalEdgeIterator { public: - ColocalEdgeIterator(const Mesh *mesh, uint32_t vertex0, uint32_t vertex1) : m_mesh(mesh), m_vertex0It(mesh, vertex0), m_vertex1It(mesh, vertex1), m_vertex1(vertex1) - { + ColocalEdgeIterator(const Mesh *mesh, uint32_t vertex0, uint32_t vertex1) : + m_mesh(mesh), m_vertex0It(mesh, vertex0), m_vertex1It(mesh, vertex1), m_vertex1(vertex1) { do { if (!resetElement()) { advanceVertex1(); - } - else { + } else { break; } } while (!isDone()); } - void advance() - { + void advance() { advanceElement(); } - bool isDone() const - { + bool isDone() const { return m_vertex0It.isDone() && m_vertex1It.isDone() && m_edge == UINT32_MAX; } - uint32_t edge() const - { + uint32_t edge() const { return m_edge; } private: - bool resetElement() - { + bool resetElement() { m_edge = m_mesh->m_edgeMap.get(Mesh::EdgeKey(m_vertex0It.vertex(), m_vertex1It.vertex())); while (m_edge != UINT32_MAX) { if (!isIgnoredFace()) @@ -3158,8 +2989,7 @@ public: return true; } - void advanceElement() - { + void advanceElement() { for (;;) { m_edge = m_mesh->m_edgeMap.getNext(m_edge); if (m_edge == UINT32_MAX) @@ -3171,17 +3001,15 @@ public: advanceVertex1(); } - void advanceVertex1() - { + void advanceVertex1() { auto successful = false; - while (!successful) { + while (!successful) { m_vertex1It.advance(); if (m_vertex1It.isDone()) { if (!m_vertex0It.isDone()) { m_vertex0It.advance(); m_vertex1It = ColocalVertexIterator(m_mesh, m_vertex1); - } - else { + } else { return; } } @@ -3189,8 +3017,7 @@ public: } } - bool isIgnoredFace() const - { + bool isIgnoredFace() const { return m_mesh->m_faceIgnore[meshEdgeFace(m_edge)]; } @@ -3200,24 +3027,21 @@ public: uint32_t m_edge; }; - class FaceEdgeIterator - { + class FaceEdgeIterator { public: - FaceEdgeIterator (const Mesh *mesh, uint32_t face) : m_mesh(mesh), m_face(face), m_relativeEdge(0) - { + FaceEdgeIterator(const Mesh *mesh, uint32_t face) : + m_mesh(mesh), m_face(face), m_relativeEdge(0) { m_edge = m_face * 3; } - void advance() - { + void advance() { if (m_relativeEdge < 3) { m_edge++; m_relativeEdge++; } } - bool isDone() const - { + bool isDone() const { return m_relativeEdge == 3; } @@ -3228,9 +3052,8 @@ public: uint32_t relativeEdge() const { return m_relativeEdge; } uint32_t face() const { return m_face; } uint32_t oppositeEdge() const { return m_mesh->m_oppositeEdges[m_edge]; } - - uint32_t oppositeFace() const - { + + uint32_t oppositeFace() const { const uint32_t oedge = m_mesh->m_oppositeEdges[m_edge]; if (oedge == UINT32_MAX) return UINT32_MAX; @@ -3253,27 +3076,23 @@ public: uint32_t m_relativeEdge; }; - class GroupFaceIterator - { + class GroupFaceIterator { public: - GroupFaceIterator(const Mesh *mesh, uint32_t group) : m_mesh(mesh) - { + GroupFaceIterator(const Mesh *mesh, uint32_t group) : + m_mesh(mesh) { XA_DEBUG_ASSERT(group != UINT32_MAX); m_current = mesh->m_faceGroupFirstFace[group]; } - void advance() - { + void advance() { m_current = m_mesh->m_faceGroupNextFace[m_current]; } - bool isDone() const - { + bool isDone() const { return m_current == UINT32_MAX; } - uint32_t face() const - { + uint32_t face() const { return m_current; } @@ -3285,8 +3104,7 @@ public: constexpr uint16_t Mesh::kInvalidFaceGroup; -static bool meshCloseHole(Mesh *mesh, const Array &holeVertices, const Vector3 &normal) -{ +static bool meshCloseHole(Mesh *mesh, const Array &holeVertices, const Vector3 &normal) { #if XA_CLOSE_HOLES_CHECK_EDGE_INTERSECTION const uint32_t faceCount = mesh->faceCount(); #endif @@ -3412,8 +3230,7 @@ static bool meshCloseHole(Mesh *mesh, const Array &holeVertices, const return true; } -static bool meshCloseHoles(Mesh *mesh, const Array &boundaryLoops, const Vector3 &normal, uint32_t *holeCount, Array *holeFaceCounts) -{ +static bool meshCloseHoles(Mesh *mesh, const Array &boundaryLoops, const Vector3 &normal, uint32_t *holeCount, Array *holeFaceCounts) { if (holeFaceCounts) holeFaceCounts->clear(); // Compute lengths. @@ -3469,8 +3286,7 @@ static bool meshCloseHoles(Mesh *mesh, const Array &boundaryLoops, con return result; } -static bool meshIsPlanar(const Mesh &mesh) -{ +static bool meshIsPlanar(const Mesh &mesh) { const Vector3 p1 = mesh.position(mesh.vertexAt(0)); const Vector3 p2 = mesh.position(mesh.vertexAt(1)); const Vector3 p3 = mesh.position(mesh.vertexAt(2)); @@ -3496,14 +3312,12 @@ Fixing T-junctions. - Split edge. */ -struct SplitEdge -{ +struct SplitEdge { uint32_t edge; float t; uint32_t vertex; - bool operator<(const SplitEdge &other) const - { + bool operator<(const SplitEdge &other) const { if (edge < other.edge) return true; else if (edge == other.edge) { @@ -3515,8 +3329,7 @@ struct SplitEdge }; // Returns nullptr if there were no t-junctions to fix. -static Mesh *meshFixTJunctions(const Mesh &inputMesh, bool *duplicatedEdge, bool *failed, uint32_t *fixedTJunctionsCount) -{ +static Mesh *meshFixTJunctions(const Mesh &inputMesh, bool *duplicatedEdge, bool *failed, uint32_t *fixedTJunctionsCount) { if (duplicatedEdge) *duplicatedEdge = false; if (failed) @@ -3591,8 +3404,7 @@ static Mesh *meshFixTJunctions(const Mesh &inputMesh, bool *duplicatedEdge, bool } // boundaryLoops are the first edges for each boundary loop. -static void meshGetBoundaryLoops(const Mesh &mesh, Array &boundaryLoops) -{ +static void meshGetBoundaryLoops(const Mesh &mesh, Array &boundaryLoops) { const uint32_t edgeCount = mesh.edgeCount(); BitArray bitFlags(edgeCount); bitFlags.zeroOutMemory(); @@ -3607,26 +3419,23 @@ static void meshGetBoundaryLoops(const Mesh &mesh, Array &boundaryLoop } } -struct Progress -{ - Progress(ProgressCategory::Enum category, ProgressFunc func, void *userData, uint32_t maxValue) : value(0), cancel(false), m_category(category), m_func(func), m_userData(userData), m_maxValue(maxValue), m_progress(0) - { +struct Progress { + Progress(ProgressCategory::Enum category, ProgressFunc func, void *userData, uint32_t maxValue) : + value(0), cancel(false), m_category(category), m_func(func), m_userData(userData), m_maxValue(maxValue), m_progress(0) { if (m_func) { if (!m_func(category, 0, userData)) cancel = true; } } - ~Progress() - { + ~Progress() { if (m_func) { if (!m_func(m_category, 100, m_userData)) cancel = true; } } - void update() - { + void update() { if (!m_func) return; m_mutex.lock(); @@ -3639,8 +3448,7 @@ struct Progress m_mutex.unlock(); } - void setMaxValue(uint32_t maxValue) - { + void setMaxValue(uint32_t maxValue) { m_mutex.lock(); m_maxValue = maxValue; m_mutex.unlock(); @@ -3658,32 +3466,31 @@ private: std::mutex m_mutex; }; -struct Spinlock -{ - void lock() { while(m_lock.test_and_set(std::memory_order_acquire)) {} } +struct Spinlock { + void lock() { + while (m_lock.test_and_set(std::memory_order_acquire)) { + } + } void unlock() { m_lock.clear(std::memory_order_release); } private: std::atomic_flag m_lock = ATOMIC_FLAG_INIT; }; -struct TaskGroupHandle -{ +struct TaskGroupHandle { uint32_t value = UINT32_MAX; }; -struct Task -{ +struct Task { void (*func)(void *userData); void *userData; }; #if XA_MULTITHREADED -class TaskScheduler -{ +class TaskScheduler { public: - TaskScheduler() : m_shutdown(false) - { + TaskScheduler() : + m_shutdown(false) { m_threadIndex = 0; // Max with current task scheduler usage is 1 per thread + 1 deep nesting, but allow for some slop. m_maxGroups = std::thread::hardware_concurrency() * 4; @@ -3701,8 +3508,7 @@ public: } } - ~TaskScheduler() - { + ~TaskScheduler() { m_shutdown = true; for (uint32_t i = 0; i < m_workers.size(); i++) { Worker &worker = m_workers[i]; @@ -3720,13 +3526,11 @@ public: XA_FREE(m_groups); } - uint32_t threadCount() const - { + uint32_t threadCount() const { return max(1u, std::thread::hardware_concurrency()); // Including the main thread. } - TaskGroupHandle createTaskGroup(uint32_t reserveSize = 0) - { + TaskGroupHandle createTaskGroup(uint32_t reserveSize = 0) { // Claim the first free group. for (uint32_t i = 0; i < m_maxGroups; i++) { TaskGroup &group = m_groups[i]; @@ -3748,8 +3552,7 @@ public: return handle; } - void run(TaskGroupHandle handle, Task task) - { + void run(TaskGroupHandle handle, Task task) { XA_DEBUG_ASSERT(handle.value != UINT32_MAX); TaskGroup &group = m_groups[handle.value]; group.queueLock.lock(); @@ -3763,8 +3566,7 @@ public: } } - void wait(TaskGroupHandle *handle) - { + void wait(TaskGroupHandle *handle) { if (handle->value == UINT32_MAX) { XA_DEBUG_ASSERT(false); return; @@ -3792,8 +3594,7 @@ public: static uint32_t currentThreadIndex() { return m_threadIndex; } private: - struct TaskGroup - { + struct TaskGroup { std::atomic free; Array queue; // Items are never removed. queueHead is incremented to pop items. uint32_t queueHead = 0; @@ -3801,8 +3602,7 @@ private: std::atomic ref; // Increment when a task is enqueued, decrement when a task finishes. }; - struct Worker - { + struct Worker { std::thread *thread = nullptr; std::mutex mutex; std::condition_variable cv; @@ -3815,12 +3615,11 @@ private: std::atomic m_shutdown; static thread_local uint32_t m_threadIndex; - static void workerThread(TaskScheduler *scheduler, Worker *worker, uint32_t threadIndex) - { + static void workerThread(TaskScheduler *scheduler, Worker *worker, uint32_t threadIndex) { m_threadIndex = threadIndex; std::unique_lock lock(worker->mutex); for (;;) { - worker->cv.wait(lock, [=]{ return worker->wakeup.load(); }); + worker->cv.wait(lock, [=] { return worker->wakeup.load(); }); worker->wakeup = false; for (;;) { if (scheduler->m_shutdown) @@ -3851,22 +3650,18 @@ private: thread_local uint32_t TaskScheduler::m_threadIndex; #else -class TaskScheduler -{ +class TaskScheduler { public: - ~TaskScheduler() - { + ~TaskScheduler() { for (uint32_t i = 0; i < m_groups.size(); i++) destroyGroup({ i }); } - uint32_t threadCount() const - { + uint32_t threadCount() const { return 1; } - TaskGroupHandle createTaskGroup(uint32_t reserveSize = 0) - { + TaskGroupHandle createTaskGroup(uint32_t reserveSize = 0) { TaskGroup *group = XA_NEW(MemTag::Default, TaskGroup); group->queue.reserve(reserveSize); m_groups.push_back(group); @@ -3875,13 +3670,11 @@ public: return handle; } - void run(TaskGroupHandle handle, Task task) - { + void run(TaskGroupHandle handle, Task task) { m_groups[handle.value]->queue.push_back(task); } - void wait(TaskGroupHandle *handle) - { + void wait(TaskGroupHandle *handle) { if (handle->value == UINT32_MAX) { XA_DEBUG_ASSERT(false); return; @@ -3897,8 +3690,7 @@ public: static uint32_t currentThreadIndex() { return 0; } private: - void destroyGroup(TaskGroupHandle handle) - { + void destroyGroup(TaskGroupHandle handle) { TaskGroup *group = m_groups[handle.value]; if (group) { group->~TaskGroup(); @@ -3907,8 +3699,7 @@ private: } } - struct TaskGroup - { + struct TaskGroup { Array queue; }; @@ -3921,8 +3712,7 @@ const uint8_t TGA_TYPE_RGB = 2; const uint8_t TGA_ORIGIN_UPPER = 0x20; #pragma pack(push, 1) -struct TgaHeader -{ +struct TgaHeader { uint8_t id_length; uint8_t colormap_type; uint8_t image_type; @@ -3939,8 +3729,7 @@ struct TgaHeader }; #pragma pack(pop) -static void WriteTga(const char *filename, const uint8_t *data, uint32_t width, uint32_t height) -{ +static void WriteTga(const char *filename, const uint8_t *data, uint32_t width, uint32_t height) { XA_DEBUG_ASSERT(sizeof(TgaHeader) == TgaHeader::Size); FILE *f; XA_FOPEN(f, filename, "wb"); @@ -3965,12 +3754,10 @@ static void WriteTga(const char *filename, const uint8_t *data, uint32_t width, } #endif -template -class ThreadLocal -{ +template +class ThreadLocal { public: - ThreadLocal() - { + ThreadLocal() { #if XA_MULTITHREADED const uint32_t n = std::thread::hardware_concurrency(); #else @@ -3981,8 +3768,7 @@ public: new (&m_array[i]) T; } - ~ThreadLocal() - { + ~ThreadLocal() { #if XA_MULTITHREADED const uint32_t n = std::thread::hardware_concurrency(); #else @@ -3993,8 +3779,7 @@ public: XA_FREE(m_array); } - T &get() const - { + T &get() const { return m_array[TaskScheduler::currentThreadIndex()]; } @@ -4002,11 +3787,9 @@ private: T *m_array; }; -class UniformGrid2 -{ +class UniformGrid2 { public: - void reset(const Vector2 *positions, const uint32_t *indices = nullptr, uint32_t reserveEdgeCount = 0) - { + void reset(const Vector2 *positions, const uint32_t *indices = nullptr, uint32_t reserveEdgeCount = 0) { m_edges.clear(); if (reserveEdgeCount > 0) m_edges.reserve(reserveEdgeCount); @@ -4015,14 +3798,12 @@ public: m_cellDataOffsets.clear(); } - void append(uint32_t edge) - { + void append(uint32_t edge) { XA_DEBUG_ASSERT(m_cellDataOffsets.isEmpty()); m_edges.push_back(edge); } - bool intersect(Vector2 v1, Vector2 v2, float epsilon) - { + bool intersect(Vector2 v1, Vector2 v2, float epsilon) { const uint32_t edgeCount = m_edges.size(); bool bruteForce = edgeCount <= 64; if (!bruteForce && m_cellDataOffsets.isEmpty()) @@ -4048,8 +3829,7 @@ public: return false; } - bool intersectSelf(float epsilon) - { + bool intersectSelf(float epsilon) { const uint32_t edgeCount = m_edges.size(); bool bruteForce = edgeCount <= 64; if (!bruteForce && m_cellDataOffsets.isEmpty()) @@ -4079,8 +3859,7 @@ public: } #if XA_DEBUG_EXPORT_BOUNDARY_GRID - void debugExport(const char *filename) - { + void debugExport(const char *filename) { Array image; image.resize(m_gridWidth * m_gridHeight * 3); for (uint32_t y = 0; y < m_gridHeight; y++) { @@ -4102,8 +3881,7 @@ public: #endif private: - bool createGrid() - { + bool createGrid() { // Compute edge extents. Min will be the grid origin. const uint32_t edgeCount = m_edges.size(); Extents2 edgeExtents; @@ -4155,8 +3933,7 @@ private: return true; } - void computePotentialEdges(Vector2 p1, Vector2 p2) - { + void computePotentialEdges(Vector2 p1, Vector2 p2) { m_potentialEdges.clear(); traverse(p1, p2); for (uint32_t j = 0; j < m_traversedCellOffsets.size(); j++) { @@ -4174,8 +3951,7 @@ private: } // "A Fast Voxel Traversal Algorithm for Ray Tracing" - void traverse(Vector2 p1, Vector2 p2) - { + void traverse(Vector2 p1, Vector2 p2) { const Vector2 dir = p2 - p1; const Vector2 normal = normalizeSafe(dir, Vector2(0.0f), kEpsilon); const int stepX = dir.x >= 0 ? 1 : -1; @@ -4196,14 +3972,12 @@ private: if (normal.x > kEpsilon || normal.x < -kEpsilon) { tMaxX = (distToNextCellX * stepX) / normal.x; tDeltaX = (m_cellSize * stepX) / normal.x; - } - else + } else tMaxX = tDeltaX = FLT_MAX; if (normal.y > kEpsilon || normal.y < -kEpsilon) { tMaxY = (distToNextCellY * stepY) / normal.y; tDeltaY = (m_cellSize * stepY) / normal.y; - } - else + } else tMaxY = tDeltaY = FLT_MAX; m_traversedCellOffsets.clear(); m_traversedCellOffsets.push_back(firstCell[0] + firstCell[1] * m_gridWidth); @@ -4230,8 +4004,7 @@ private: } } - bool edgesIntersect(uint32_t edge1, uint32_t edge2, float epsilon) const - { + bool edgesIntersect(uint32_t edge1, uint32_t edge2, float epsilon) const { if (edge1 == edge2) return false; const uint32_t ai[2] = { vertexAt(meshEdgeIndex0(edge1)), vertexAt(meshEdgeIndex1(edge1)) }; @@ -4242,28 +4015,23 @@ private: return linesIntersect(m_positions[ai[0]], m_positions[ai[1]], m_positions[bi[0]], m_positions[bi[1]], epsilon); } - uint32_t cellX(float x) const - { + uint32_t cellX(float x) const { return min((uint32_t)max(0.0f, (x - m_gridOrigin.x) / m_cellSize), m_gridWidth - 1u); } - uint32_t cellY(float y) const - { + uint32_t cellY(float y) const { return min((uint32_t)max(0.0f, (y - m_gridOrigin.y) / m_cellSize), m_gridHeight - 1u); } - Vector2 edgePosition0(uint32_t edge) const - { + Vector2 edgePosition0(uint32_t edge) const { return m_positions[vertexAt(meshEdgeIndex0(edge))]; } - Vector2 edgePosition1(uint32_t edge) const - { + Vector2 edgePosition1(uint32_t edge) const { return m_positions[vertexAt(meshEdgeIndex1(edge))]; } - uint32_t vertexAt(uint32_t index) const - { + uint32_t vertexAt(uint32_t index) const { return m_indices ? m_indices[index] : index; } @@ -4279,34 +4047,29 @@ private: Array m_traversedCellOffsets; }; -struct UvMeshChart -{ +struct UvMeshChart { Array faces; Array indices; uint32_t material; }; -struct UvMesh -{ +struct UvMesh { UvMeshDecl decl; Array indices; Array charts; Array vertexToChartMap; }; -struct UvMeshInstance -{ +struct UvMeshInstance { UvMesh *mesh; Array texcoords; bool rotateCharts; }; namespace raster { -class ClippedTriangle -{ +class ClippedTriangle { public: - ClippedTriangle(const Vector2 &a, const Vector2 &b, const Vector2 &c) - { + ClippedTriangle(const Vector2 &a, const Vector2 &b, const Vector2 &c) { m_numVertices = 3; m_activeVertexBuffer = 0; m_verticesA[0] = a; @@ -4316,20 +4079,19 @@ public: m_vertexBuffers[1] = m_verticesB; } - void clipHorizontalPlane(float offset, float clipdirection) - { - Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; + void clipHorizontalPlane(float offset, float clipdirection) { + Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; m_activeVertexBuffer ^= 1; Vector2 *v2 = m_vertexBuffers[m_activeVertexBuffer]; v[m_numVertices] = v[0]; - float dy2, dy1 = offset - v[0].y; - int dy2in, dy1in = clipdirection * dy1 >= 0; - uint32_t p = 0; + float dy2, dy1 = offset - v[0].y; + int dy2in, dy1in = clipdirection * dy1 >= 0; + uint32_t p = 0; for (uint32_t k = 0; k < m_numVertices; k++) { - dy2 = offset - v[k + 1].y; + dy2 = offset - v[k + 1].y; dy2in = clipdirection * dy2 >= 0; if (dy1in) v2[p++] = v[k]; - if ( dy1in + dy2in == 1 ) { // not both in/out + if (dy1in + dy2in == 1) { // not both in/out float dx = v[k + 1].x - v[k].x; float dy = v[k + 1].y - v[k].y; v2[p++] = Vector2(v[k].x + dy1 * (dx / dy), offset); @@ -4340,20 +4102,19 @@ public: m_numVertices = p; } - void clipVerticalPlane(float offset, float clipdirection) - { - Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; + void clipVerticalPlane(float offset, float clipdirection) { + Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; m_activeVertexBuffer ^= 1; Vector2 *v2 = m_vertexBuffers[m_activeVertexBuffer]; v[m_numVertices] = v[0]; - float dx2, dx1 = offset - v[0].x; - int dx2in, dx1in = clipdirection * dx1 >= 0; - uint32_t p = 0; + float dx2, dx1 = offset - v[0].x; + int dx2in, dx1in = clipdirection * dx1 >= 0; + uint32_t p = 0; for (uint32_t k = 0; k < m_numVertices; k++) { dx2 = offset - v[k + 1].x; dx2in = clipdirection * dx2 >= 0; if (dx1in) v2[p++] = v[k]; - if ( dx1in + dx2in == 1 ) { // not both in/out + if (dx1in + dx2in == 1) { // not both in/out float dx = v[k + 1].x - v[k].x; float dy = v[k + 1].y - v[k].y; v2[p++] = Vector2(offset, v[k].y + dx1 * (dy / dx)); @@ -4364,9 +4125,8 @@ public: m_numVertices = p; } - void computeArea() - { - Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; + void computeArea() { + Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; v[m_numVertices] = v[0]; m_area = 0; float centroidx = 0, centroidy = 0; @@ -4380,8 +4140,7 @@ public: m_area = 0.5f * fabsf(m_area); } - void clipAABox(float x0, float y0, float x1, float y1) - { + void clipAABox(float x0, float y0, float x1, float y1) { clipVerticalPlane(x0, -1); clipHorizontalPlane(y0, -1); clipVerticalPlane(x1, 1); @@ -4389,8 +4148,7 @@ public: computeArea(); } - float area() const - { + float area() const { return m_area; } @@ -4407,10 +4165,8 @@ private: typedef bool (*SamplingCallback)(void *param, int x, int y); /// A triangle for rasterization. -struct Triangle -{ - Triangle(const Vector2 &v0, const Vector2 &v1, const Vector2 &v2) - { +struct Triangle { + Triangle(const Vector2 &v0, const Vector2 &v1, const Vector2 &v2) { // Init vertices. this->v1 = v0; this->v2 = v2; @@ -4422,8 +4178,7 @@ struct Triangle computeUnitInwardNormals(); } - bool isValid() - { + bool isValid() { const Vector2 e0 = v3 - v1; const Vector2 e1 = v2 - v1; const float area = e0.y * e1.x - e1.y * e0.x; @@ -4431,18 +4186,17 @@ struct Triangle } // extents has to be multiple of BK_SIZE!! - bool drawAA(const Vector2 &extents, SamplingCallback cb, void *param) - { - const float PX_INSIDE = 1.0f/sqrtf(2.0f); - const float PX_OUTSIDE = -1.0f/sqrtf(2.0f); + bool drawAA(const Vector2 &extents, SamplingCallback cb, void *param) { + const float PX_INSIDE = 1.0f / sqrtf(2.0f); + const float PX_OUTSIDE = -1.0f / sqrtf(2.0f); const float BK_SIZE = 8; - const float BK_INSIDE = sqrtf(BK_SIZE*BK_SIZE/2.0f); - const float BK_OUTSIDE = -sqrtf(BK_SIZE*BK_SIZE/2.0f); + const float BK_INSIDE = sqrtf(BK_SIZE * BK_SIZE / 2.0f); + const float BK_OUTSIDE = -sqrtf(BK_SIZE * BK_SIZE / 2.0f); // Bounding rectangle float minx = floorf(max(min3(v1.x, v2.x, v3.x), 0.0f)); float miny = floorf(max(min3(v1.y, v2.y, v3.y), 0.0f)); - float maxx = ceilf( min(max3(v1.x, v2.x, v3.x), extents.x - 1.0f)); - float maxy = ceilf( min(max3(v1.y, v2.y, v3.y), extents.y - 1.0f)); + float maxx = ceilf(min(max3(v1.x, v2.x, v3.x), extents.x - 1.0f)); + float maxy = ceilf(min(max3(v1.y, v2.y, v3.y), extents.y - 1.0f)); // There's no reason to align the blocks to the viewport, instead we align them to the origin of the triangle bounds. minx = floorf(minx); miny = floorf(miny); @@ -4467,9 +4221,9 @@ struct Triangle float bC = C2 + n2.x * xc + n2.y * yc; float cC = C3 + n3.x * xc + n3.y * yc; // Skip block when outside an edge - if ( (aC <= BK_OUTSIDE) || (bC <= BK_OUTSIDE) || (cC <= BK_OUTSIDE) ) continue; + if ((aC <= BK_OUTSIDE) || (bC <= BK_OUTSIDE) || (cC <= BK_OUTSIDE)) continue; // Accept whole block when totally covered - if ( (aC >= BK_INSIDE) && (bC >= BK_INSIDE) && (cC >= BK_INSIDE) ) { + if ((aC >= BK_INSIDE) && (bC >= BK_INSIDE) && (cC >= BK_INSIDE)) { for (float y = y0; y < y0 + BK_SIZE; y++) { for (float x = x0; x < x0 + BK_SIZE; x++) { if (!cb(param, (int)x, (int)y)) @@ -4512,10 +4266,9 @@ struct Triangle } private: - void flipBackface() - { + void flipBackface() { // check if triangle is backfacing, if so, swap two vertices - if ( ((v3.x - v1.x) * (v2.y - v1.y) - (v3.y - v1.y) * (v2.x - v1.x)) < 0 ) { + if (((v3.x - v1.x) * (v2.y - v1.y) - (v3.y - v1.y) * (v2.x - v1.x)) < 0) { Vector2 hv = v1; v1 = v2; v2 = hv; // swap pos @@ -4523,8 +4276,7 @@ private: } // compute unit inward normals for each edge. - void computeUnitInwardNormals() - { + void computeUnitInwardNormals() { n1 = v1 - v2; n1 = Vector2(-n1.y, n1.x); n1 = n1 * (1.0f / sqrtf(dot(n1, n1))); @@ -4542,8 +4294,7 @@ private: }; // Process the given triangle. Returns false if rasterization was interrupted by the callback. -static bool drawTriangle(const Vector2 &extents, const Vector2 v[3], SamplingCallback cb, void *param) -{ +static bool drawTriangle(const Vector2 &extents, const Vector2 v[3], SamplingCallback cb, void *param) { Triangle tri(v[0], v[1], v[2]); // @@ It would be nice to have a conservative drawing mode that enlarges the triangle extents by one texel and is able to handle degenerate triangles. // @@ Maybe the simplest thing to do would be raster triangle edges. @@ -4566,18 +4317,16 @@ namespace sparse { * elements for each row of the matrix. As with the FullVector the * dimension of the matrix is constant. **/ -class Matrix -{ +class Matrix { public: // An element of the sparse array. - struct Coefficient - { - uint32_t x; // column + struct Coefficient { + uint32_t x; // column float v; // value }; - Matrix(uint32_t d) : m_width(d), m_array(MemTag::Matrix) - { + Matrix(uint32_t d) : + m_width(d), m_array(MemTag::Matrix) { m_array.resize(d); m_array.runCtors(); #if XA_DEBUG_HEAP @@ -4585,9 +4334,9 @@ public: m_array[i].setMemTag(MemTag::Matrix); #endif } - - Matrix(uint32_t w, uint32_t h) : m_width(w), m_array(MemTag::Matrix) - { + + Matrix(uint32_t w, uint32_t h) : + m_width(w), m_array(MemTag::Matrix) { m_array.resize(h); m_array.runCtors(); #if XA_DEBUG_HEAP @@ -4595,9 +4344,8 @@ public: m_array[i].setMemTag(MemTag::Matrix); #endif } - - ~Matrix() - { + + ~Matrix() { m_array.runDtors(); } @@ -4608,10 +4356,9 @@ public: bool isSquare() const { return width() == height(); } // x is column, y is row - float getCoefficient(uint32_t x, uint32_t y) const - { - XA_DEBUG_ASSERT( x < width() ); - XA_DEBUG_ASSERT( y < height() ); + float getCoefficient(uint32_t x, uint32_t y) const { + XA_DEBUG_ASSERT(x < width()); + XA_DEBUG_ASSERT(y < height()); const uint32_t count = m_array[y].size(); for (uint32_t i = 0; i < count; i++) { if (m_array[y][i].x == x) return m_array[y][i].v; @@ -4619,10 +4366,9 @@ public: return 0.0f; } - void setCoefficient(uint32_t x, uint32_t y, float f) - { - XA_DEBUG_ASSERT( x < width() ); - XA_DEBUG_ASSERT( y < height() ); + void setCoefficient(uint32_t x, uint32_t y, float f) { + XA_DEBUG_ASSERT(x < width()); + XA_DEBUG_ASSERT(y < height()); const uint32_t count = m_array[y].size(); for (uint32_t i = 0; i < count; i++) { if (m_array[y][i].x == x) { @@ -4632,13 +4378,12 @@ public: } if (f != 0.0f) { Coefficient c = { x, f }; - m_array[y].push_back( c ); + m_array[y].push_back(c); } } - float dotRow(uint32_t y, const FullVector &v) const - { - XA_DEBUG_ASSERT( y < height() ); + float dotRow(uint32_t y, const FullVector &v) const { + XA_DEBUG_ASSERT(y < height()); const uint32_t count = m_array[y].size(); float sum = 0; for (uint32_t i = 0; i < count; i++) { @@ -4647,8 +4392,7 @@ public: return sum; } - void madRow(uint32_t y, float alpha, FullVector &v) const - { + void madRow(uint32_t y, float alpha, FullVector &v) const { XA_DEBUG_ASSERT(y < height()); const uint32_t count = m_array[y].size(); for (uint32_t i = 0; i < count; i++) { @@ -4656,9 +4400,8 @@ public: } } - void clearRow(uint32_t y) - { - XA_DEBUG_ASSERT( y < height() ); + void clearRow(uint32_t y) { + XA_DEBUG_ASSERT(y < height()); m_array[y].clear(); } @@ -4669,12 +4412,11 @@ private: const uint32_t m_width; /// Array of matrix elements. - Array< Array > m_array; + Array> m_array; }; // y = a * x + y -static void saxpy(float a, const FullVector &x, FullVector &y) -{ +static void saxpy(float a, const FullVector &x, FullVector &y) { XA_DEBUG_ASSERT(x.dimension() == y.dimension()); const uint32_t dim = x.dimension(); for (uint32_t i = 0; i < dim; i++) { @@ -4682,8 +4424,7 @@ static void saxpy(float a, const FullVector &x, FullVector &y) } } -static void copy(const FullVector &x, FullVector &y) -{ +static void copy(const FullVector &x, FullVector &y) { XA_DEBUG_ASSERT(x.dimension() == y.dimension()); const uint32_t dim = x.dimension(); for (uint32_t i = 0; i < dim; i++) { @@ -4691,16 +4432,14 @@ static void copy(const FullVector &x, FullVector &y) } } -static void scal(float a, FullVector &x) -{ +static void scal(float a, FullVector &x) { const uint32_t dim = x.dimension(); for (uint32_t i = 0; i < dim; i++) { x[i] *= a; } } -static float dot(const FullVector &x, const FullVector &y) -{ +static float dot(const FullVector &x, const FullVector &y) { XA_DEBUG_ASSERT(x.dimension() == y.dimension()); const uint32_t dim = x.dimension(); float sum = 0; @@ -4711,24 +4450,22 @@ static float dot(const FullVector &x, const FullVector &y) } // y = M * x -static void mult(const Matrix &M, const FullVector &x, FullVector &y) -{ +static void mult(const Matrix &M, const FullVector &x, FullVector &y) { uint32_t w = M.width(); uint32_t h = M.height(); - XA_DEBUG_ASSERT( w == x.dimension() ); + XA_DEBUG_ASSERT(w == x.dimension()); XA_UNUSED(w); - XA_DEBUG_ASSERT( h == y.dimension() ); + XA_DEBUG_ASSERT(h == y.dimension()); for (uint32_t i = 0; i < h; i++) y[i] = M.dotRow(i, x); } // y = alpha*A*x + beta*y -static void sgemv(float alpha, const Matrix &A, const FullVector &x, float beta, FullVector &y) -{ +static void sgemv(float alpha, const Matrix &A, const FullVector &x, float beta, FullVector &y) { const uint32_t w = A.width(); const uint32_t h = A.height(); - XA_DEBUG_ASSERT( w == x.dimension() ); - XA_DEBUG_ASSERT( h == y.dimension() ); + XA_DEBUG_ASSERT(w == x.dimension()); + XA_DEBUG_ASSERT(h == y.dimension()); XA_UNUSED(w); XA_UNUSED(h); for (uint32_t i = 0; i < h; i++) @@ -4736,8 +4473,7 @@ static void sgemv(float alpha, const Matrix &A, const FullVector &x, float beta, } // dot y-row of A by x-column of B -static float dotRowColumn(int y, const Matrix &A, int x, const Matrix &B) -{ +static float dotRowColumn(int y, const Matrix &A, int x, const Matrix &B) { const Array &row = A.getRow(y); const uint32_t count = row.size(); float sum = 0.0f; @@ -4748,8 +4484,7 @@ static float dotRowColumn(int y, const Matrix &A, int x, const Matrix &B) return sum; } -static void transpose(const Matrix &A, Matrix &B) -{ +static void transpose(const Matrix &A, Matrix &B) { XA_DEBUG_ASSERT(A.width() == B.height()); XA_DEBUG_ASSERT(B.width() == A.height()); const uint32_t w = A.width(); @@ -4768,8 +4503,7 @@ static void transpose(const Matrix &A, Matrix &B) } } -static void sgemm(float alpha, const Matrix &A, const Matrix &B, float beta, Matrix &C) -{ +static void sgemm(float alpha, const Matrix &A, const Matrix &B, float beta, Matrix &C) { const uint32_t w = C.width(); const uint32_t h = C.height(); #if XA_DEBUG @@ -4793,8 +4527,7 @@ static void sgemm(float alpha, const Matrix &A, const Matrix &B, float beta, Mat } // C = A * B -static void mult(const Matrix &A, const Matrix &B, Matrix &C) -{ +static void mult(const Matrix &A, const Matrix &B, Matrix &C) { sgemm(1.0f, A, B, 0.0f, C); } @@ -4804,22 +4537,19 @@ namespace segment { // - Insertion is o(n) // - Smallest element goes at the end, so that popping it is o(1). -struct CostQueue -{ - CostQueue(uint32_t size = UINT32_MAX) : m_maxSize(size), m_pairs(MemTag::SegmentAtlasChartCandidates) {} +struct CostQueue { + CostQueue(uint32_t size = UINT32_MAX) : + m_maxSize(size), m_pairs(MemTag::SegmentAtlasChartCandidates) {} - float peekCost() const - { + float peekCost() const { return m_pairs.back().cost; } - uint32_t peekFace() const - { + uint32_t peekFace() const { return m_pairs.back().face; } - void push(float cost, uint32_t face) - { + void push(float cost, uint32_t face) { const Pair p = { cost, face }; if (m_pairs.isEmpty() || cost < peekCost()) m_pairs.push_back(p); @@ -4836,29 +4566,25 @@ struct CostQueue } } - uint32_t pop() - { + uint32_t pop() { XA_DEBUG_ASSERT(!m_pairs.isEmpty()); uint32_t f = m_pairs.back().face; m_pairs.pop_back(); return f; } - XA_INLINE void clear() - { + XA_INLINE void clear() { m_pairs.clear(); } - XA_INLINE uint32_t count() const - { + XA_INLINE uint32_t count() const { return m_pairs.size(); } private: const uint32_t m_maxSize; - struct Pair - { + struct Pair { float cost; uint32_t face; }; @@ -4866,9 +4592,9 @@ private: Array m_pairs; }; -struct Chart -{ - Chart() : faces(MemTag::SegmentAtlasChartFaces) {} +struct Chart { + Chart() : + faces(MemTag::SegmentAtlasChartFaces) {} int id = -1; Basis basis; // Best fit normal. @@ -4882,12 +4608,11 @@ struct Chart CostQueue candidates; }; -struct Atlas -{ - Atlas() : m_edgeLengths(MemTag::SegmentAtlasMeshData), m_faceAreas(MemTag::SegmentAtlasMeshData), m_faceNormals(MemTag::SegmentAtlasMeshData), m_texcoords(MemTag::SegmentAtlasMeshData), m_bestTriangles(10), m_nextPlanarRegionFace(MemTag::SegmentAtlasPlanarRegions), m_facePlanarRegionId(MemTag::SegmentAtlasPlanarRegions) {} +struct Atlas { + Atlas() : + m_edgeLengths(MemTag::SegmentAtlasMeshData), m_faceAreas(MemTag::SegmentAtlasMeshData), m_faceNormals(MemTag::SegmentAtlasMeshData), m_texcoords(MemTag::SegmentAtlasMeshData), m_bestTriangles(10), m_nextPlanarRegionFace(MemTag::SegmentAtlasPlanarRegions), m_facePlanarRegionId(MemTag::SegmentAtlasPlanarRegions) {} - ~Atlas() - { + ~Atlas() { const uint32_t chartCount = m_charts.size(); for (uint32_t i = 0; i < chartCount; i++) { m_charts[i]->~Chart(); @@ -4900,8 +4625,7 @@ struct Atlas const Array &chartFaces(uint32_t i) const { return m_charts[i]->faces; } const Basis &chartBasis(uint32_t chartIndex) const { return m_charts[chartIndex]->basis; } - void reset(uint32_t meshId, uint32_t chartGroupId, const Mesh *mesh, const ChartOptions &options) - { + void reset(uint32_t meshId, uint32_t chartGroupId, const Mesh *mesh, const ChartOptions &options) { XA_UNUSED(meshId); XA_UNUSED(chartGroupId); XA_PROFILE_START(buildAtlasInit) @@ -4995,8 +4719,7 @@ struct Atlas XA_PROFILE_END(buildAtlasInit) } - void placeSeeds(float threshold) - { + void placeSeeds(float threshold) { XA_PROFILE_START(buildAtlasPlaceSeeds) // Instead of using a predefiened number of seeds: // - Add seeds one by one, growing chart until a certain treshold. @@ -5010,8 +4733,7 @@ struct Atlas } // Returns true if any of the charts can grow more. - void growCharts(float threshold) - { + void growCharts(float threshold) { XA_PROFILE_START(buildAtlasGrowCharts) for (;;) { if (m_facesLeft == 0) @@ -5057,8 +4779,7 @@ struct Atlas XA_PROFILE_END(buildAtlasGrowCharts) } - void resetCharts() - { + void resetCharts() { XA_PROFILE_START(buildAtlasResetCharts) const uint32_t faceCount = m_mesh->faceCount(); for (uint32_t i = 0; i < faceCount; i++) @@ -5083,8 +4804,7 @@ struct Atlas XA_PROFILE_END(buildAtlasResetCharts) } - bool relocateSeeds() - { + bool relocateSeeds() { XA_PROFILE_START(buildAtlasRelocateSeeds) bool anySeedChanged = false; const uint32_t chartCount = m_charts.size(); @@ -5097,8 +4817,7 @@ struct Atlas return anySeedChanged; } - void fillHoles(float threshold) - { + void fillHoles(float threshold) { XA_PROFILE_START(buildAtlasFillHoles) while (m_facesLeft > 0) createRandomChart(threshold); @@ -5106,8 +4825,7 @@ struct Atlas } #if XA_MERGE_CHARTS - void mergeCharts() - { + void mergeCharts() { XA_PROFILE_START(buildAtlasMergeCharts) const uint32_t chartCount = m_charts.size(); // Merge charts progressively until there's none left to merge. @@ -5165,7 +4883,7 @@ struct Atlas // Merge if chart2 has a single face. // chart1 must have more than 1 face. // chart2 area must be <= 10% of chart1 area. - if (m_sharedBoundaryLengthsNoSeams[cc] > 0.0f && chart->faces.size() > 1 && chart2->faces.size() == 1 && chart2->area <= chart->area * 0.1f) + if (m_sharedBoundaryLengthsNoSeams[cc] > 0.0f && chart->faces.size() > 1 && chart2->faces.size() == 1 && chart2->area <= chart->area * 0.1f) goto merge; // Merge if chart2 has two faces (probably a quad), and chart1 bounds at least 2 of its edges. if (chart2->faces.size() == 2 && m_sharedBoundaryEdgeCountNoSeams[cc] >= 2) @@ -5173,8 +4891,8 @@ struct Atlas // Merge if chart2 is wholely inside chart1, ignoring seams. if (m_sharedBoundaryLengthsNoSeams[cc] > 0.0f && equal(m_sharedBoundaryLengthsNoSeams[cc], chart2->boundaryLength, kEpsilon)) goto merge; - if (m_sharedBoundaryLengths[cc] > 0.2f * max(0.0f, chart->boundaryLength - externalBoundaryLength) || - m_sharedBoundaryLengths[cc] > 0.75f * chart2->boundaryLength) + if (m_sharedBoundaryLengths[cc] > 0.2f * max(0.0f, chart->boundaryLength - externalBoundaryLength) || + m_sharedBoundaryLengths[cc] > 0.75f * chart2->boundaryLength) goto merge; continue; merge: @@ -5212,8 +4930,7 @@ struct Atlas #endif private: - void createRandomChart(float threshold) - { + void createRandomChart(float threshold) { Chart *chart = XA_NEW(MemTag::Default, Chart); chart->id = (int)m_charts.size(); m_charts.push_back(chart); @@ -5239,15 +4956,13 @@ private: } } - bool isChartBoundaryEdge(const Chart *chart, uint32_t edge) const - { + bool isChartBoundaryEdge(const Chart *chart, uint32_t edge) const { const uint32_t oppositeEdge = m_mesh->oppositeEdge(edge); const uint32_t oppositeFace = meshEdgeFace(oppositeEdge); return oppositeEdge == UINT32_MAX || m_faceCharts[oppositeFace] != chart->id; } - bool computeChartBasis(Chart *chart, Basis *basis) - { + bool computeChartBasis(Chart *chart, Basis *basis) { const uint32_t faceCount = chart->faces.size(); m_tempPoints.resize(chart->faces.size() * 3); for (uint32_t i = 0; i < faceCount; i++) { @@ -5258,8 +4973,7 @@ private: return Fit::computeBasis(m_tempPoints.data(), m_tempPoints.size(), basis); } - bool isFaceFlipped(uint32_t face) const - { + bool isFaceFlipped(uint32_t face) const { const Vector2 &v1 = m_texcoords[face * 3 + 0]; const Vector2 &v2 = m_texcoords[face * 3 + 1]; const Vector2 &v3 = m_texcoords[face * 3 + 2]; @@ -5267,8 +4981,7 @@ private: return parametricArea < 0.0f; } - void parameterizeChart(const Chart *chart) - { + void parameterizeChart(const Chart *chart) { const uint32_t faceCount = chart->faces.size(); for (uint32_t i = 0; i < faceCount; i++) { const uint32_t face = chart->faces[i]; @@ -5281,8 +4994,7 @@ private: } // m_faceCharts for the chart faces must be set to the chart ID. Needed to compute boundary edges. - bool isChartParameterizationValid(const Chart *chart) - { + bool isChartParameterizationValid(const Chart *chart) { const uint32_t faceCount = chart->faces.size(); // Check for flipped faces in the parameterization. OK if all are flipped. uint32_t flippedFaceCount = 0; @@ -5307,15 +5019,14 @@ private: return true; } - bool addFaceToChart(Chart *chart, uint32_t face) - { + bool addFaceToChart(Chart *chart, uint32_t face) { XA_DEBUG_ASSERT(m_faceCharts[face] == -1); const uint32_t oldFaceCount = chart->faces.size(); const bool firstFace = oldFaceCount == 0; // Append the face and any coplanar connected faces to the chart faces array. chart->faces.push_back(face); uint32_t coplanarFace = m_nextPlanarRegionFace[face]; - while (coplanarFace != face) { + while (coplanarFace != face) { XA_DEBUG_ASSERT(m_faceCharts[coplanarFace] == -1); chart->faces.push_back(coplanarFace); coplanarFace = m_nextPlanarRegionFace[coplanarFace]; @@ -5327,7 +5038,7 @@ private: // Use the first face normal. // Use any edge as the tangent vector. basis.normal = m_faceNormals[face]; - basis.tangent = normalize(m_mesh->position(m_mesh->vertexAt(face * 3 + 0)) - m_mesh->position(m_mesh->vertexAt(face * 3 + 1)), kEpsilon); + basis.tangent = normalize(m_mesh->position(m_mesh->vertexAt(face * 3 + 0)) - m_mesh->position(m_mesh->vertexAt(face * 3 + 1)), 0); basis.bitangent = cross(basis.normal, basis.tangent); } else { // Use best fit normal. @@ -5385,8 +5096,7 @@ private: } // Returns true if the seed has changed. - bool relocateSeed(Chart *chart) - { + bool relocateSeed(Chart *chart) { // Find the first N triangles that fit the proxy best. const uint32_t faceCount = chart->faces.size(); m_bestTriangles.clear(); @@ -5425,8 +5135,7 @@ private: } // Evaluate combined metric. - float evaluateCost(Chart *chart, uint32_t face) const - { + float evaluateCost(Chart *chart, uint32_t face) const { if (dot(m_faceNormals[face], chart->basis.normal) <= 0.26f) // ~75 degrees return FLT_MAX; // Estimate boundary length and area: @@ -5467,16 +5176,14 @@ private: } // Returns a value in [0-1]. - float evaluateProxyFitMetric(Chart *chart, uint32_t face) const - { + float evaluateProxyFitMetric(Chart *chart, uint32_t face) const { // All faces in coplanar regions have the same normal, can use any face. const Vector3 faceNormal = m_faceNormals[face]; // Use plane fitting metric for now: return 1 - dot(faceNormal, chart->basis.normal); // @@ normal deviations should be weighted by face area } - float evaluateRoundnessMetric(Chart *chart, float newBoundaryLength, float newChartArea) const - { + float evaluateRoundnessMetric(Chart *chart, float newBoundaryLength, float newChartArea) const { const float roundness = square(chart->boundaryLength) / chart->area; const float newBoundaryLengthSq = square(newBoundaryLength); const float newRoundness = newBoundaryLengthSq / newChartArea; @@ -5486,12 +5193,11 @@ private: return 0; } - float evaluateStraightnessMetric(Chart *chart, uint32_t firstFace) const - { + float evaluateStraightnessMetric(Chart *chart, uint32_t firstFace) const { float l_out = 0.0f, l_in = 0.0f; const uint32_t planarRegionId = m_facePlanarRegionId[firstFace]; uint32_t face = firstFace; - for (;;) { + for (;;) { for (Mesh::FaceEdgeIterator it(m_mesh, face); !it.isDone(); it.advance()) { const float l = m_edgeLengths[it.edge()]; if (it.isBoundary()) { @@ -5512,8 +5218,7 @@ private: return min(ratio, 0.0f); // Only use the straightness metric to close gaps. } - bool isNormalSeam(uint32_t edge) const - { + bool isNormalSeam(uint32_t edge) const { const uint32_t oppositeEdge = m_mesh->oppositeEdge(edge); if (oppositeEdge == UINT32_MAX) return false; // boundary edge @@ -5533,11 +5238,10 @@ private: return !equal(m_faceNormals[f0], m_faceNormals[f1], kNormalEpsilon); } - float evaluateNormalSeamMetric(Chart *chart, uint32_t firstFace) const - { + float evaluateNormalSeamMetric(Chart *chart, uint32_t firstFace) const { float seamFactor = 0.0f, totalLength = 0.0f; uint32_t face = firstFace; - for (;;) { + for (;;) { for (Mesh::FaceEdgeIterator it(m_mesh, face); !it.isDone(); it.advance()) { if (it.isBoundary()) continue; @@ -5574,11 +5278,10 @@ private: return seamFactor / totalLength; } - float evaluateTextureSeamMetric(Chart *chart, uint32_t firstFace) const - { + float evaluateTextureSeamMetric(Chart *chart, uint32_t firstFace) const { float seamLength = 0.0f, totalLength = 0.0f; uint32_t face = firstFace; - for (;;) { + for (;;) { for (Mesh::FaceEdgeIterator it(m_mesh, face); !it.isDone(); it.advance()) { if (it.isBoundary()) continue; @@ -5601,11 +5304,10 @@ private: return seamLength / totalLength; } - float computeArea(Chart *chart, uint32_t firstFace) const - { + float computeArea(Chart *chart, uint32_t firstFace) const { float area = chart->area; uint32_t face = firstFace; - for (;;) { + for (;;) { area += m_faceAreas[face]; face = m_nextPlanarRegionFace[face]; if (face == firstFace) @@ -5614,13 +5316,12 @@ private: return area; } - float computeBoundaryLength(Chart *chart, uint32_t firstFace) const - { + float computeBoundaryLength(Chart *chart, uint32_t firstFace) const { float boundaryLength = chart->boundaryLength; // Add new edges, subtract edges shared with the chart. const uint32_t planarRegionId = m_facePlanarRegionId[firstFace]; uint32_t face = firstFace; - for (;;) { + for (;;) { for (Mesh::FaceEdgeIterator it(m_mesh, face); !it.isDone(); it.advance()) { const float edgeLength = m_edgeLengths[it.edge()]; if (it.isBoundary()) { @@ -5636,11 +5337,10 @@ private: if (face == firstFace) break; } - return max(0.0f, boundaryLength); // @@ Hack! + return max(0.0f, boundaryLength); // @@ Hack! } - bool mergeChart(Chart *owner, Chart *chart, float sharedBoundaryLength) - { + bool mergeChart(Chart *owner, Chart *chart, float sharedBoundaryLength) { const uint32_t oldOwnerFaceCount = owner->faces.size(); const uint32_t chartFaceCount = chart->faces.size(); owner->faces.push_back(chart->faces); @@ -5706,11 +5406,10 @@ private: namespace param { -class JacobiPreconditioner -{ +class JacobiPreconditioner { public: - JacobiPreconditioner(const sparse::Matrix &M, bool symmetric) : m_inverseDiagonal(M.width()) - { + JacobiPreconditioner(const sparse::Matrix &M, bool symmetric) : + m_inverseDiagonal(M.width()) { XA_ASSERT(M.isSquare()); for (uint32_t x = 0; x < M.width(); x++) { float elem = M.getCoefficient(x, x); @@ -5723,8 +5422,7 @@ public: } } - void apply(const FullVector &x, FullVector &y) const - { + void apply(const FullVector &x, FullVector &y) const { XA_DEBUG_ASSERT(x.dimension() == m_inverseDiagonal.dimension()); XA_DEBUG_ASSERT(y.dimension() == m_inverseDiagonal.dimension()); // @@ Wrap vector component-wise product into a separate function. @@ -5739,12 +5437,10 @@ private: }; // Linear solvers. -class Solver -{ +class Solver { public: // Solve the symmetric system: At·A·x = At·b - static bool LeastSquaresSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon = 1e-5f) - { + static bool LeastSquaresSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon = 1e-5f) { XA_DEBUG_ASSERT(A.width() == x.dimension()); XA_DEBUG_ASSERT(A.height() == b.dimension()); XA_DEBUG_ASSERT(A.height() >= A.width()); // @@ If height == width we could solve it directly... @@ -5759,8 +5455,7 @@ public: } // See section 10.4.3 in: Mesh Parameterization: Theory and Practice, Siggraph Course Notes, August 2007 - static bool LeastSquaresSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, const uint32_t *lockedParameters, uint32_t lockedCount, float epsilon = 1e-5f) - { + static bool LeastSquaresSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, const uint32_t *lockedParameters, uint32_t lockedCount, float epsilon = 1e-5f) { XA_DEBUG_ASSERT(A.width() == x.dimension()); XA_DEBUG_ASSERT(A.height() == b.dimension()); XA_DEBUG_ASSERT(A.height() >= A.width() - lockedCount); @@ -5859,18 +5554,17 @@ private: * **/ // Conjugate gradient with preconditioner. - static bool ConjugateGradientSolver(const JacobiPreconditioner &preconditioner, const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon) - { - XA_DEBUG_ASSERT( A.isSquare() ); - XA_DEBUG_ASSERT( A.width() == b.dimension() ); - XA_DEBUG_ASSERT( A.width() == x.dimension() ); + static bool ConjugateGradientSolver(const JacobiPreconditioner &preconditioner, const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon) { + XA_DEBUG_ASSERT(A.isSquare()); + XA_DEBUG_ASSERT(A.width() == b.dimension()); + XA_DEBUG_ASSERT(A.width() == x.dimension()); int i = 0; const int D = A.width(); - const int i_max = 4 * D; // Convergence should be linear, but in some cases, it's not. - FullVector r(D); // residual - FullVector p(D); // search direction - FullVector q(D); // - FullVector s(D); // preconditioned + const int i_max = 4 * D; // Convergence should be linear, but in some cases, it's not. + FullVector r(D); // residual + FullVector p(D); // search direction + FullVector q(D); // + FullVector s(D); // preconditioned float delta_0; float delta_old; float delta_new; @@ -5896,7 +5590,7 @@ private: // x = alfa·p + x sparse::saxpy(alpha, p, x); if ((i & 31) == 0) { // recompute r after 32 steps - // r = b - A·x + // r = b - A·x sparse::copy(b, r); sparse::sgemv(-1, A, x, 1, r); } else { @@ -5906,7 +5600,7 @@ private: // s = M^-1 · r preconditioner.apply(r, s); delta_old = delta_new; - delta_new = sparse::dot( r, s ); + delta_new = sparse::dot(r, s); beta = delta_new / delta_old; // p = s + beta·p sparse::scal(beta, p); @@ -5915,8 +5609,7 @@ private: return delta_new <= epsilon * epsilon * delta_0; } - static bool SymmetricSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon = 1e-5f) - { + static bool SymmetricSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon = 1e-5f) { XA_DEBUG_ASSERT(A.height() == A.width()); XA_DEBUG_ASSERT(A.height() == b.dimension()); XA_DEBUG_ASSERT(b.dimension() == x.dimension()); @@ -5926,8 +5619,7 @@ private: }; // Fast sweep in 3 directions -static bool findApproximateDiameterVertices(Mesh *mesh, uint32_t *a, uint32_t *b) -{ +static bool findApproximateDiameterVertices(Mesh *mesh, uint32_t *a, uint32_t *b) { XA_DEBUG_ASSERT(a != nullptr); XA_DEBUG_ASSERT(b != nullptr); const uint32_t vertexCount = mesh->vertexCount(); @@ -5984,28 +5676,24 @@ static bool findApproximateDiameterVertices(Mesh *mesh, uint32_t *a, uint32_t *b // Conformal relations from Brecht Van Lommel (based on ABF): -static float vec_angle_cos(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3) -{ +static float vec_angle_cos(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3) { Vector3 d1 = v1 - v2; Vector3 d2 = v3 - v2; return clamp(dot(d1, d2) / (length(d1) * length(d2)), -1.0f, 1.0f); } -static float vec_angle(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3) -{ +static float vec_angle(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3) { float dot = vec_angle_cos(v1, v2, v3); return acosf(dot); } -static void triangle_angles(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3, float *a1, float *a2, float *a3) -{ +static void triangle_angles(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3, float *a1, float *a2, float *a3) { *a1 = vec_angle(v3, v1, v2); *a2 = vec_angle(v1, v2, v3); *a3 = kPi - *a2 - *a1; } -static void setup_abf_relations(sparse::Matrix &A, int row, int id0, int id1, int id2, const Vector3 &p0, const Vector3 &p1, const Vector3 &p2) -{ +static void setup_abf_relations(sparse::Matrix &A, int row, int id0, int id1, int id2, const Vector3 &p0, const Vector3 &p1, const Vector3 &p2) { // @@ IC: Wouldn't it be more accurate to return cos and compute 1-cos^2? // It does indeed seem to be a little bit more robust. // @@ Need to revisit this more carefully! @@ -6055,8 +5743,7 @@ static void setup_abf_relations(sparse::Matrix &A, int row, int id0, int id1, in A.setCoefficient(v2_id, 2 * row + 1, 1); } -static bool computeLeastSquaresConformalMap(Mesh *mesh) -{ +static bool computeLeastSquaresConformalMap(Mesh *mesh) { // For this to work properly, mesh should not have colocals that have the same // attributes, unless you want the vertices to actually have different texcoords. const uint32_t vertexCount = mesh->vertexCount(); @@ -6114,10 +5801,8 @@ static bool computeLeastSquaresConformalMap(Mesh *mesh) } #if XA_RECOMPUTE_CHARTS -struct PiecewiseParam -{ - void reset(const Mesh *mesh, uint32_t faceCount) - { +struct PiecewiseParam { + void reset(const Mesh *mesh, uint32_t faceCount) { m_mesh = mesh; m_faceCount = faceCount; const uint32_t vertexCount = m_mesh->vertexCount(); @@ -6134,8 +5819,7 @@ struct PiecewiseParam ConstArrayView chartFaces() const { return m_patch; } const Vector2 *texcoords() const { return m_texcoords.data(); } - bool computeChart() - { + bool computeChart() { m_patch.clear(); m_faceInvalid.zeroOutMemory(); m_faceInPatch.zeroOutMemory(); @@ -6242,8 +5926,7 @@ struct PiecewiseParam } private: - struct Candidate - { + struct Candidate { uint32_t face, vertex; uint32_t next; // The next candidate with the same vertex. Vector2 position; @@ -6253,10 +5936,12 @@ private: float patchVertexOrient; }; - struct CandidateIterator - { - CandidateIterator(Array &candidates, uint32_t first) : m_candidates(candidates), m_current(first) {} - void advance() { if (m_current != UINT32_MAX) m_current = m_candidates[m_current].next; } + struct CandidateIterator { + CandidateIterator(Array &candidates, uint32_t first) : + m_candidates(candidates), m_current(first) {} + void advance() { + if (m_current != UINT32_MAX) m_current = m_candidates[m_current].next; + } bool isDone() const { return m_current == UINT32_MAX; } Candidate ¤t() { return m_candidates[m_current]; } @@ -6277,8 +5962,7 @@ private: UniformGrid2 m_boundaryGrid; // Find candidate faces on the patch front. - void findCandidates() - { + void findCandidates() { m_candidates.clear(); m_faceInCandidates.zeroOutMemory(); for (uint32_t i = 0; i < m_patch.size(); i++) { @@ -6335,8 +6019,7 @@ private: } } - void addCandidateFace(uint32_t patchEdge, float patchVertexOrient, uint32_t face, uint32_t edge, uint32_t freeVertex) - { + void addCandidateFace(uint32_t patchEdge, float patchVertexOrient, uint32_t face, uint32_t edge, uint32_t freeVertex) { Vector2 texcoords[3]; orthoProjectFace(face, texcoords); // Find corresponding vertices between the patch edge and candidate edge. @@ -6412,8 +6095,7 @@ private: m_faceInCandidates.set(face); } - void orthoProjectFace(uint32_t face, Vector2 *texcoords) const - { + void orthoProjectFace(uint32_t face, Vector2 *texcoords) const { const Vector3 normal = m_mesh->computeFaceNormal(face); const Vector3 tangent = normalize(m_mesh->position(m_mesh->vertexAt(face * 3 + 1)) - m_mesh->position(m_mesh->vertexAt(face * 3 + 0)), kEpsilon); const Vector3 bitangent = cross(normal, tangent); @@ -6423,16 +6105,14 @@ private: } } - float parametricArea(const Vector2 *texcoords) const - { + float parametricArea(const Vector2 *texcoords) const { const Vector2 &v1 = texcoords[0]; const Vector2 &v2 = texcoords[1]; const Vector2 &v3 = texcoords[2]; return ((v2.x - v1.x) * (v3.y - v1.y) - (v3.x - v1.x) * (v2.y - v1.y)) * 0.5f; } - float computeStretch(Vector3 p1, Vector3 p2, Vector3 p3, Vector2 t1, Vector2 t2, Vector2 t3) const - { + float computeStretch(Vector3 p1, Vector3 p2, Vector3 p3, Vector2 t1, Vector2 t2, Vector2 t3) const { float parametricArea = ((t2.y - t1.y) * (t3.x - t1.x) - (t3.y - t1.y) * (t2.x - t1.x)) * 0.5f; if (isZero(parametricArea, kAreaEpsilon)) return FLT_MAX; @@ -6446,16 +6126,14 @@ private: } // Return value is positive if the point is one side of the edge, negative if on the other side. - float orientToEdge(Vector2 edgeVertex0, Vector2 edgeVertex1, Vector2 point) const - { + float orientToEdge(Vector2 edgeVertex0, Vector2 edgeVertex1, Vector2 point) const { return (edgeVertex0.x - point.x) * (edgeVertex1.y - point.y) - (edgeVertex0.y - point.y) * (edgeVertex1.x - point.x); } }; #endif // Estimate quality of existing parameterization. -struct Quality -{ +struct Quality { // computeBoundaryIntersection bool boundaryIntersection = false; @@ -6472,8 +6150,7 @@ struct Quality float conformalMetric = 0.0f; float authalicMetric = 0.0f; - void computeBoundaryIntersection(const Mesh *mesh, UniformGrid2 &boundaryGrid) - { + void computeBoundaryIntersection(const Mesh *mesh, UniformGrid2 &boundaryGrid) { const Array &boundaryEdges = mesh->boundaryEdges(); const uint32_t boundaryEdgeCount = boundaryEdges.size(); boundaryGrid.reset(mesh->texcoords(), mesh->indices(), boundaryEdgeCount); @@ -6489,8 +6166,7 @@ struct Quality #endif } - void computeFlippedFaces(const Mesh *mesh, uint32_t faceCount, Array *flippedFaces) - { + void computeFlippedFaces(const Mesh *mesh, uint32_t faceCount, Array *flippedFaces) { totalTriangleCount = flippedTriangleCount = zeroAreaTriangleCount = 0; if (flippedFaces) flippedFaces->clear(); @@ -6525,8 +6201,7 @@ struct Quality flippedFaces->clear(); flippedTriangleCount = 0; } - if (flippedTriangleCount > totalTriangleCount / 2) - { + if (flippedTriangleCount > totalTriangleCount / 2) { // If more than half the triangles are flipped, reverse the flipped / not flipped classification. flippedTriangleCount = totalTriangleCount - flippedTriangleCount; if (flippedFaces) { @@ -6548,8 +6223,7 @@ struct Quality } } - void computeMetrics(const Mesh *mesh, uint32_t faceCount) - { + void computeMetrics(const Mesh *mesh, uint32_t faceCount) { totalGeometricArea = totalParametricArea = 0.0f; stretchMetric = maxStretchMetric = conformalMetric = authalicMetric = 0.0f; for (uint32_t f = 0; f < faceCount; f++) { @@ -6580,7 +6254,7 @@ struct Quality const float a = dot(Ss, Ss); // E const float b = dot(Ss, St); // F const float c = dot(St, St); // G - // Compute eigen-values of the first fundamental form: + // Compute eigen-values of the first fundamental form: const float sigma1 = sqrtf(0.5f * max(0.0f, a + c - sqrtf(square(a - c) + 4 * square(b)))); // gamma uppercase, min eigenvalue. const float sigma2 = sqrtf(0.5f * max(0.0f, a + c + sqrtf(square(a - c) + 4 * square(b)))); // gamma lowercase, max eigenvalue. XA_ASSERT(sigma2 > sigma1 || equal(sigma1, sigma2, kEpsilon)); @@ -6611,37 +6285,33 @@ struct Quality if (totalGeometricArea > 0.0f) { const float normFactor = sqrtf(totalParametricArea / totalGeometricArea); stretchMetric = sqrtf(stretchMetric / totalGeometricArea) * normFactor; - maxStretchMetric *= normFactor; + maxStretchMetric *= normFactor; conformalMetric = sqrtf(conformalMetric / totalGeometricArea); authalicMetric = sqrtf(authalicMetric / totalGeometricArea); } } }; -struct ChartWarningFlags -{ - enum Enum - { - CloseHolesFailed = 1<<1, - FixTJunctionsDuplicatedEdge = 1<<2, - FixTJunctionsFailed = 1<<3, - TriangulateDuplicatedEdge = 1<<4, +struct ChartWarningFlags { + enum Enum { + CloseHolesFailed = 1 << 1, + FixTJunctionsDuplicatedEdge = 1 << 2, + FixTJunctionsFailed = 1 << 3, + TriangulateDuplicatedEdge = 1 << 4, }; }; -struct ChartCtorBuffers -{ +struct ChartCtorBuffers { Array chartMeshIndices; Array unifiedMeshIndices; Array boundaryLoops; }; /// A chart is a connected set of faces with a certain topology (usually a disk). -class Chart -{ +class Chart { public: - Chart(ChartCtorBuffers &buffers, const Basis &basis, ConstArrayView faces, const Mesh *originalMesh, uint32_t meshId, uint32_t chartGroupId, uint32_t chartId) : m_basis(basis), m_mesh(nullptr), m_unifiedMesh(nullptr), m_unmodifiedUnifiedMesh(nullptr), m_type(ChartType::LSCM), m_warningFlags(0), m_closedHolesCount(0), m_fixedTJunctionsCount(0) - { + Chart(ChartCtorBuffers &buffers, const Basis &basis, ConstArrayView faces, const Mesh *originalMesh, uint32_t meshId, uint32_t chartGroupId, uint32_t chartId) : + m_basis(basis), m_mesh(nullptr), m_unifiedMesh(nullptr), m_unmodifiedUnifiedMesh(nullptr), m_type(ChartType::LSCM), m_warningFlags(0), m_closedHolesCount(0), m_fixedTJunctionsCount(0) { XA_UNUSED(meshId); XA_UNUSED(chartGroupId); XA_UNUSED(chartId); @@ -6780,8 +6450,8 @@ public: } #if XA_RECOMPUTE_CHARTS - Chart(ChartCtorBuffers &buffers, const Chart *parent, const Mesh *parentMesh, ConstArrayView faces, const Vector2 *texcoords, const Mesh *originalMesh, uint32_t meshId, uint32_t chartGroupId, uint32_t chartId) : m_mesh(nullptr), m_unifiedMesh(nullptr), m_unmodifiedUnifiedMesh(nullptr), m_type(ChartType::Piecewise), m_warningFlags(0), m_closedHolesCount(0), m_fixedTJunctionsCount(0) - { + Chart(ChartCtorBuffers &buffers, const Chart *parent, const Mesh *parentMesh, ConstArrayView faces, const Vector2 *texcoords, const Mesh *originalMesh, uint32_t meshId, uint32_t chartGroupId, uint32_t chartId) : + m_mesh(nullptr), m_unifiedMesh(nullptr), m_unmodifiedUnifiedMesh(nullptr), m_type(ChartType::Piecewise), m_warningFlags(0), m_closedHolesCount(0), m_fixedTJunctionsCount(0) { XA_UNUSED(meshId); XA_UNUSED(chartGroupId); XA_UNUSED(chartId); @@ -6846,8 +6516,7 @@ public: } #endif - ~Chart() - { + ~Chart() { if (m_mesh) { m_mesh->~Mesh(); XA_FREE(m_mesh); @@ -6880,8 +6549,7 @@ public: const Mesh *unmodifiedUnifiedMesh() const { return m_unmodifiedUnifiedMesh; } uint32_t mapChartVertexToOriginalVertex(uint32_t i) const { return m_chartToOriginalMap[i]; } - void evaluateOrthoQuality(UniformGrid2 &boundaryGrid) - { + void evaluateOrthoQuality(UniformGrid2 &boundaryGrid) { XA_PROFILE_START(parameterizeChartsEvaluateQuality) m_quality.computeBoundaryIntersection(m_unifiedMesh, boundaryGrid); m_quality.computeFlippedFaces(m_unifiedMesh, m_initialFaceCount, nullptr); @@ -6892,8 +6560,7 @@ public: m_type = ChartType::Ortho; } - void evaluateQuality(UniformGrid2 &boundaryGrid) - { + void evaluateQuality(UniformGrid2 &boundaryGrid) { XA_PROFILE_START(parameterizeChartsEvaluateQuality) m_quality.computeBoundaryIntersection(m_unifiedMesh, boundaryGrid); #if XA_DEBUG_EXPORT_OBJ_INVALID_PARAMETERIZATION @@ -6906,15 +6573,13 @@ public: } // Transfer parameterization from unified mesh to chart mesh. - void transferParameterization() - { + void transferParameterization() { const uint32_t vertexCount = m_mesh->vertexCount(); for (uint32_t v = 0; v < vertexCount; v++) m_mesh->texcoord(v) = m_unifiedMesh->texcoord(m_chartToUnifiedMap[v]); } - Vector2 computeParametricBounds() const - { + Vector2 computeParametricBounds() const { Vector2 minCorner(FLT_MAX, FLT_MAX); Vector2 maxCorner(-FLT_MAX, -FLT_MAX); const uint32_t vertexCount = m_mesh->vertexCount(); @@ -6949,8 +6614,7 @@ private: #endif }; -struct CreateChartTaskArgs -{ +struct CreateChartTaskArgs { const Mesh *mesh; const Basis *basis; ConstArrayView faces; @@ -6961,27 +6625,23 @@ struct CreateChartTaskArgs Chart **chart; }; -static void runCreateChartTask(void *userData) -{ +static void runCreateChartTask(void *userData) { XA_PROFILE_START(createChartMeshesThread) auto args = (CreateChartTaskArgs *)userData; *(args->chart) = XA_NEW_ARGS(MemTag::Default, Chart, args->chartBuffers->get(), *(args->basis), args->faces, args->mesh, args->meshId, args->chartGroupId, args->chartId); XA_PROFILE_END(createChartMeshesThread) } -struct ParameterizeChartTaskArgs -{ +struct ParameterizeChartTaskArgs { Chart *chart; ParameterizeFunc func; ThreadLocal *boundaryGrid; }; -static void runParameterizeChartTask(void *userData) -{ +static void runParameterizeChartTask(void *userData) { auto args = (ParameterizeChartTaskArgs *)userData; Mesh *mesh = args->chart->unifiedMesh(); - XA_PROFILE_START(parameterizeChartsOrthogonal) - { + XA_PROFILE_START(parameterizeChartsOrthogonal) { // Project vertices to plane. const uint32_t vertexCount = mesh->vertexCount(); const Basis &basis = args->chart->basis(); @@ -7006,11 +6666,10 @@ static void runParameterizeChartTask(void *userData) } // Set of charts corresponding to mesh faces in the same face group. -class ChartGroup -{ +class ChartGroup { public: - ChartGroup(uint32_t id, const Mesh *sourceMesh, uint16_t faceGroup) : m_sourceId(sourceMesh->id()), m_id(id), m_isVertexMap(faceGroup == Mesh::kInvalidFaceGroup), m_paramAddedChartsCount(0), m_paramDeletedChartsCount(0) - { + ChartGroup(uint32_t id, const Mesh *sourceMesh, uint16_t faceGroup) : + m_sourceId(sourceMesh->id()), m_id(id), m_isVertexMap(faceGroup == Mesh::kInvalidFaceGroup), m_paramAddedChartsCount(0), m_paramDeletedChartsCount(0) { // Create new mesh from the source mesh, using faces that belong to this group. const uint32_t sourceFaceCount = sourceMesh->faceCount(); if (!m_isVertexMap) { @@ -7072,8 +6731,7 @@ public: #endif } - ~ChartGroup() - { + ~ChartGroup() { m_mesh->~Mesh(); XA_FREE(m_mesh); for (uint32_t i = 0; i < m_charts.size(); i++) { @@ -7151,8 +6809,7 @@ public: - emphasize roundness metrics to prevent those cases. - If interior self-overlaps: preserve boundary parameterization and use mean-value map. */ - void computeCharts(TaskScheduler *taskScheduler, const ChartOptions &options, segment::Atlas &atlas, ThreadLocal *chartBuffers) - { + void computeCharts(TaskScheduler *taskScheduler, const ChartOptions &options, segment::Atlas &atlas, ThreadLocal *chartBuffers) { m_chartOptions = options; // This function may be called multiple times, so destroy existing charts. for (uint32_t i = 0; i < m_charts.size(); i++) { @@ -7222,7 +6879,7 @@ public: #if XA_RECOMPUTE_CHARTS void parameterizeCharts(TaskScheduler *taskScheduler, ParameterizeFunc func, ThreadLocal *boundaryGrid, ThreadLocal *chartBuffers, ThreadLocal *piecewiseParam) #else - void parameterizeCharts(TaskScheduler* taskScheduler, ParameterizeFunc func, ThreadLocal* boundaryGrid, ThreadLocal* /*chartBuffers*/) + void parameterizeCharts(TaskScheduler *taskScheduler, ParameterizeFunc func, ThreadLocal *boundaryGrid, ThreadLocal * /*chartBuffers*/) #endif { m_paramAddedChartsCount = 0; @@ -7316,8 +6973,7 @@ public: } private: - void buildAtlas(segment::Atlas &atlas, const ChartOptions &options) - { + void buildAtlas(segment::Atlas &atlas, const ChartOptions &options) { if (atlas.facesLeft() == 0) return; // Create initial charts greedely. @@ -7347,8 +7003,7 @@ private: XA_DEBUG_ASSERT(atlas.facesLeft() == 0); } - void removeChart(const Chart *chart) - { + void removeChart(const Chart *chart) { for (uint32_t i = 0; i < m_charts.size(); i++) { if (m_charts[i] == chart) { m_charts.removeAt(i); @@ -7368,24 +7023,21 @@ private: uint32_t m_paramDeletedChartsCount; // Number of charts with invalid parameterizations that were deleted, after charts were recomputed. }; -struct CreateChartGroupTaskArgs -{ +struct CreateChartGroupTaskArgs { uint16_t faceGroup; uint32_t groupId; const Mesh *mesh; ChartGroup **chartGroup; }; -static void runCreateChartGroupTask(void *userData) -{ +static void runCreateChartGroupTask(void *userData) { XA_PROFILE_START(addMeshCreateChartGroupsThread) auto args = (CreateChartGroupTaskArgs *)userData; *(args->chartGroup) = XA_NEW_ARGS(MemTag::Default, ChartGroup, args->groupId, args->mesh, args->faceGroup); XA_PROFILE_END(addMeshCreateChartGroupsThread) } -struct ComputeChartsTaskArgs -{ +struct ComputeChartsTaskArgs { TaskScheduler *taskScheduler; ChartGroup *chartGroup; ThreadLocal *atlas; @@ -7394,8 +7046,7 @@ struct ComputeChartsTaskArgs Progress *progress; }; -static void runComputeChartsJob(void *userData) -{ +static void runComputeChartsJob(void *userData) { auto args = (ComputeChartsTaskArgs *)userData; if (args->progress->cancel) return; @@ -7406,8 +7057,7 @@ static void runComputeChartsJob(void *userData) args->progress->update(); } -struct ParameterizeChartsTaskArgs -{ +struct ParameterizeChartsTaskArgs { TaskScheduler *taskScheduler; ChartGroup *chartGroup; ParameterizeFunc func; @@ -7419,8 +7069,7 @@ struct ParameterizeChartsTaskArgs Progress *progress; }; -static void runParameterizeChartsJob(void *userData) -{ +static void runParameterizeChartsJob(void *userData) { auto args = (ParameterizeChartsTaskArgs *)userData; if (args->progress->cancel) return; @@ -7436,13 +7085,12 @@ static void runParameterizeChartsJob(void *userData) } /// An atlas is a set of chart groups. -class Atlas -{ +class Atlas { public: - Atlas() : m_meshCount(0), m_chartsComputed(false), m_chartsParameterized(false) {} + Atlas() : + m_meshCount(0), m_chartsComputed(false), m_chartsParameterized(false) {} - ~Atlas() - { + ~Atlas() { for (uint32_t i = 0; i < m_chartGroups.size(); i++) { m_chartGroups[i]->~ChartGroup(); XA_FREE(m_chartGroups[i]); @@ -7454,8 +7102,7 @@ public: uint32_t chartGroupCount() const { return m_chartGroups.size(); } const ChartGroup *chartGroupAt(uint32_t index) const { return m_chartGroups[index]; } - uint32_t chartGroupCount(uint32_t mesh) const - { + uint32_t chartGroupCount(uint32_t mesh) const { uint32_t count = 0; for (uint32_t i = 0; i < m_chartGroups.size(); i++) { if (m_chartGroupSourceMeshes[i] == mesh) @@ -7464,8 +7111,7 @@ public: return count; } - const ChartGroup *chartGroupAt(uint32_t mesh, uint32_t group) const - { + const ChartGroup *chartGroupAt(uint32_t mesh, uint32_t group) const { for (uint32_t c = 0; c < m_chartGroups.size(); c++) { if (m_chartGroupSourceMeshes[c] != mesh) continue; @@ -7477,8 +7123,7 @@ public: } // This function is thread safe. - void addMesh(TaskScheduler *taskScheduler, const Mesh *mesh) - { + void addMesh(TaskScheduler *taskScheduler, const Mesh *mesh) { // Create one chart group per face group. // If there's any ignored faces in the mesh, create an extra face group for that (vertex map). // Chart group creation is slow since it copies a chunk of the source mesh, so use tasks. @@ -7513,8 +7158,7 @@ public: // Chart id/index is determined by depth-first hierarchy of mesh -> chart group -> chart. // For chart index to be consistent here, chart groups needs to sorted by mesh index. Since addMesh is called by multithreaded tasks, order is indeterminate, so chart groups need to be explicitly sorted after all meshes are added. - void sortChartGroups() - { + void sortChartGroups() { Array oldChartGroups; oldChartGroups.resize(m_chartGroups.size()); memcpy(oldChartGroups.data(), m_chartGroups.data(), sizeof(ChartGroup *) * m_chartGroups.size()); @@ -7533,8 +7177,7 @@ public: } } - bool computeCharts(TaskScheduler *taskScheduler, const ChartOptions &options, ProgressFunc progressFunc, void *progressUserData) - { + bool computeCharts(TaskScheduler *taskScheduler, const ChartOptions &options, ProgressFunc progressFunc, void *progressUserData) { m_chartsComputed = false; m_chartsParameterized = false; // Ignore vertex maps. @@ -7582,8 +7225,7 @@ public: return true; } - bool parameterizeCharts(TaskScheduler *taskScheduler, ParameterizeFunc func, ProgressFunc progressFunc, void *progressUserData) - { + bool parameterizeCharts(TaskScheduler *taskScheduler, ParameterizeFunc func, ProgressFunc progressFunc, void *progressUserData) { m_chartsParameterized = false; // Ignore vertex maps. uint32_t chartGroupCount = 0; @@ -7643,17 +7285,15 @@ private: namespace pack { -class AtlasImage -{ +class AtlasImage { public: - AtlasImage(uint32_t width, uint32_t height) : m_width(width), m_height(height) - { + AtlasImage(uint32_t width, uint32_t height) : + m_width(width), m_height(height) { m_data.resize(m_width * m_height); memset(m_data.data(), 0, sizeof(uint32_t) * m_data.size()); } - void resize(uint32_t width, uint32_t height) - { + void resize(uint32_t width, uint32_t height) { Array data; data.resize(width * height); memset(data.data(), 0, sizeof(uint32_t) * data.size()); @@ -7664,8 +7304,7 @@ public: data.moveTo(m_data); } - void addChart(uint32_t chartIndex, const BitImage *image, const BitImage *imageBilinear, const BitImage *imagePadding, int atlas_w, int atlas_h, int offset_x, int offset_y) - { + void addChart(uint32_t chartIndex, const BitImage *image, const BitImage *imageBilinear, const BitImage *imagePadding, int atlas_w, int atlas_h, int offset_x, int offset_y) { const int w = image->width(); const int h = image->height(); for (int y = 0; y < h; y++) { @@ -7691,15 +7330,13 @@ public: } } - void copyTo(uint32_t *dest, uint32_t destWidth, uint32_t destHeight, int padding) const - { + void copyTo(uint32_t *dest, uint32_t destWidth, uint32_t destHeight, int padding) const { for (uint32_t y = 0; y < destHeight; y++) memcpy(&dest[y * destWidth], &m_data[padding + (y + padding) * m_width], destWidth * sizeof(uint32_t)); } #if XA_DEBUG_EXPORT_ATLAS_IMAGES - void writeTga(const char *filename, uint32_t width, uint32_t height) const - { + void writeTga(const char *filename, uint32_t width, uint32_t height) const { Array image; image.resize(width * height * 3); for (uint32_t y = 0; y < height; y++) { @@ -7741,8 +7378,7 @@ private: Array m_data; }; -struct Chart -{ +struct Chart { int32_t atlasIndex; uint32_t material; uint32_t indexCount; @@ -7764,15 +7400,13 @@ struct Chart uint32_t uniqueVertexCount() const { return uniqueVertices.isEmpty() ? vertexCount : uniqueVertices.size(); } }; -struct AddChartTaskArgs -{ +struct AddChartTaskArgs { ThreadLocal *boundingBox; param::Chart *paramChart; Chart *chart; // out }; -static void runAddChartTask(void *userData) -{ +static void runAddChartTask(void *userData) { XA_PROFILE_START(packChartsAddChartsThread) auto args = (AddChartTaskArgs *)userData; param::Chart *paramChart = args->paramChart; @@ -7811,10 +7445,8 @@ static void runAddChartTask(void *userData) XA_PROFILE_END(packChartsAddChartsThread) } -struct Atlas -{ - ~Atlas() - { +struct Atlas { + ~Atlas() { for (uint32_t i = 0; i < m_atlasImages.size(); i++) { m_atlasImages[i]->~AtlasImage(); XA_FREE(m_atlasImages[i]); @@ -7838,8 +7470,7 @@ struct Atlas const Array &getImages() const { return m_atlasImages; } float getUtilization(uint32_t atlas) const { return m_utilization[atlas]; } - void addCharts(TaskScheduler *taskScheduler, param::Atlas *paramAtlas) - { + void addCharts(TaskScheduler *taskScheduler, param::Atlas *paramAtlas) { // Count charts. uint32_t chartCount = 0; const uint32_t chartGroupsCount = paramAtlas->chartGroupCount(); @@ -7880,8 +7511,7 @@ struct Atlas m_charts[i] = taskArgs[i].chart; } - void addUvMeshCharts(UvMeshInstance *mesh) - { + void addUvMeshCharts(UvMeshInstance *mesh) { BitArray vertexUsed(mesh->texcoords.size()); BoundingBox2D boundingBox; for (uint32_t c = 0; c < mesh->mesh->charts.size(); c++) { @@ -7942,8 +7572,7 @@ struct Atlas } // Pack charts in the smallest possible rectangle. - bool packCharts(const PackOptions &options, ProgressFunc progressFunc, void *progressUserData) - { + bool packCharts(const PackOptions &options, ProgressFunc progressFunc, void *progressUserData) { if (progressFunc) { if (!progressFunc(ProgressCategory::PackCharts, 0, progressUserData)) return false; @@ -8178,8 +7807,7 @@ struct Atlas int best_x = 0, best_y = 0; int best_cw = 0, best_ch = 0; int best_r = 0; - for (;;) - { + for (;;) { bool firstChartInBitImage = false; XA_UNUSED(firstChartInBitImage); if (currentAtlas + 1 > m_bitImages.size()) { @@ -8212,8 +7840,7 @@ struct Atlas if (best_x + best_cw > atlasSizes[currentAtlas].x || best_y + best_ch > atlasSizes[currentAtlas].y) { for (uint32_t j = 0; j < chartStartPositions.size(); j++) chartStartPositions[j] = Vector2i(0, 0); - } - else { + } else { chartStartPositions[currentAtlas] = Vector2i(best_x, best_y); } } @@ -8312,8 +7939,7 @@ struct Atlas } if (m_utilization.size() > 1) { XA_PRINT(" %u: %f%% utilization\n", i, m_utilization[i] * 100.0f); - } - else { + } else { XA_PRINT(" %f%% utilization\n", m_utilization[i] * 100.0f); } } @@ -8336,16 +7962,14 @@ private: // is occupied at this point. At the end we have many small charts and a large atlas with sparse holes. Finding those holes randomly is slow. A better approach would be to // start stacking large charts as if they were tetris pieces. Once charts get small try to place them randomly. It may be interesting to try a intermediate strategy, first try // along one axis and then try exhaustively along that axis. - bool findChartLocation(const Vector2i &startPosition, bool bruteForce, const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, bool blockAligned, uint32_t maxResolution, bool allowRotate) - { + bool findChartLocation(const Vector2i &startPosition, bool bruteForce, const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, bool blockAligned, uint32_t maxResolution, bool allowRotate) { const int attempts = 4096; if (bruteForce || attempts >= w * h) return findChartLocation_bruteForce(startPosition, atlasBitImage, chartBitImage, chartBitImageRotated, w, h, best_x, best_y, best_w, best_h, best_r, blockAligned, maxResolution, allowRotate); return findChartLocation_random(atlasBitImage, chartBitImage, chartBitImageRotated, w, h, best_x, best_y, best_w, best_h, best_r, attempts, blockAligned, maxResolution, allowRotate); } - bool findChartLocation_bruteForce(const Vector2i &startPosition, const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, bool blockAligned, uint32_t maxResolution, bool allowRotate) - { + bool findChartLocation_bruteForce(const Vector2i &startPosition, const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, bool blockAligned, uint32_t maxResolution, bool allowRotate) { const int stepSize = blockAligned ? 4 : 1; int best_metric = INT_MAX; // Try two different orientations. @@ -8390,8 +8014,7 @@ private: return best_metric != INT_MAX; } - bool findChartLocation_random(const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, int minTrialCount, bool blockAligned, uint32_t maxResolution, bool allowRotate) - { + bool findChartLocation_random(const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, int minTrialCount, bool blockAligned, uint32_t maxResolution, bool allowRotate) { bool result = false; const int BLOCK_SIZE = 4; int best_metric = INT_MAX; @@ -8446,8 +8069,7 @@ private: return result; } - void addChart(BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int atlas_w, int atlas_h, int offset_x, int offset_y, int r) - { + void addChart(BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int atlas_w, int atlas_h, int offset_x, int offset_y, int r) { XA_DEBUG_ASSERT(r == 0 || r == 1); const BitImage *image = r == 0 ? chartBitImage : chartBitImageRotated; const int w = image->width(); @@ -8470,8 +8092,7 @@ private: } } - void bilinearExpand(const Chart *chart, BitImage *source, BitImage *dest, BitImage *destRotated, UniformGrid2 &boundaryEdgeGrid) const - { + void bilinearExpand(const Chart *chart, BitImage *source, BitImage *dest, BitImage *destRotated, UniformGrid2 &boundaryEdgeGrid) const { boundaryEdgeGrid.reset(chart->vertices, chart->indices); if (chart->boundaryEdges) { const uint32_t edgeCount = chart->boundaryEdges->size(); @@ -8526,13 +8147,11 @@ private: } } - struct DrawTriangleCallbackArgs - { + struct DrawTriangleCallbackArgs { BitImage *chartBitImage, *chartBitImageRotated; }; - static bool drawTriangleCallback(void *param, int x, int y) - { + static bool drawTriangleCallback(void *param, int x, int y) { auto args = (DrawTriangleCallbackArgs *)param; args->chartBitImage->set(x, y); if (args->chartBitImageRotated) @@ -8554,8 +8173,7 @@ private: } // namespace pack } // namespace internal -struct Context -{ +struct Context { Atlas atlas; uint32_t meshCount = 0; internal::Progress *addMeshProgress = nullptr; @@ -8568,16 +8186,14 @@ struct Context internal::Array uvMeshInstances; }; -Atlas *Create() -{ +Atlas *Create() { Context *ctx = XA_NEW(internal::MemTag::Default, Context); memset(&ctx->atlas, 0, sizeof(Atlas)); ctx->taskScheduler = XA_NEW(internal::MemTag::Default, internal::TaskScheduler); return &ctx->atlas; } -static void DestroyOutputMeshes(Context *ctx) -{ +static void DestroyOutputMeshes(Context *ctx) { if (!ctx->atlas.meshes) return; for (int i = 0; i < (int)ctx->atlas.meshCount; i++) { @@ -8598,8 +8214,7 @@ static void DestroyOutputMeshes(Context *ctx) ctx->atlas.meshes = nullptr; } -void Destroy(Atlas *atlas) -{ +void Destroy(Atlas *atlas) { XA_DEBUG_ASSERT(atlas); Context *ctx = (Context *)atlas; if (atlas->utilization) @@ -8634,14 +8249,12 @@ void Destroy(Atlas *atlas) #endif } -struct AddMeshTaskArgs -{ +struct AddMeshTaskArgs { Context *ctx; internal::Mesh *mesh; }; -static void runAddMeshTask(void *userData) -{ +static void runAddMeshTask(void *userData) { XA_PROFILE_START(addMeshThread) auto args = (AddMeshTaskArgs *)userData; // Responsible for freeing this. internal::Mesh *mesh = args->mesh; @@ -8710,37 +8323,32 @@ cleanup: XA_PROFILE_END(addMeshThread) } -static internal::Vector3 DecodePosition(const MeshDecl &meshDecl, uint32_t index) -{ +static internal::Vector3 DecodePosition(const MeshDecl &meshDecl, uint32_t index) { XA_DEBUG_ASSERT(meshDecl.vertexPositionData); XA_DEBUG_ASSERT(meshDecl.vertexPositionStride > 0); return *((const internal::Vector3 *)&((const uint8_t *)meshDecl.vertexPositionData)[meshDecl.vertexPositionStride * index]); } -static internal::Vector3 DecodeNormal(const MeshDecl &meshDecl, uint32_t index) -{ +static internal::Vector3 DecodeNormal(const MeshDecl &meshDecl, uint32_t index) { XA_DEBUG_ASSERT(meshDecl.vertexNormalData); XA_DEBUG_ASSERT(meshDecl.vertexNormalStride > 0); return *((const internal::Vector3 *)&((const uint8_t *)meshDecl.vertexNormalData)[meshDecl.vertexNormalStride * index]); } -static internal::Vector2 DecodeUv(const MeshDecl &meshDecl, uint32_t index) -{ +static internal::Vector2 DecodeUv(const MeshDecl &meshDecl, uint32_t index) { XA_DEBUG_ASSERT(meshDecl.vertexUvData); XA_DEBUG_ASSERT(meshDecl.vertexUvStride > 0); return *((const internal::Vector2 *)&((const uint8_t *)meshDecl.vertexUvData)[meshDecl.vertexUvStride * index]); } -static uint32_t DecodeIndex(IndexFormat::Enum format, const void *indexData, int32_t offset, uint32_t i) -{ +static uint32_t DecodeIndex(IndexFormat::Enum format, const void *indexData, int32_t offset, uint32_t i) { XA_DEBUG_ASSERT(indexData); if (format == IndexFormat::UInt16) return uint16_t((int32_t)((const uint16_t *)indexData)[i] + offset); return uint32_t((int32_t)((const uint32_t *)indexData)[i] + offset); } -AddMeshError::Enum AddMesh(Atlas *atlas, const MeshDecl &meshDecl, uint32_t meshCountHint) -{ +AddMeshError::Enum AddMesh(Atlas *atlas, const MeshDecl &meshDecl, uint32_t meshCountHint) { XA_DEBUG_ASSERT(atlas); if (!atlas) { XA_PRINT_WARNING("AddMesh: atlas is null.\n"); @@ -8758,8 +8366,7 @@ AddMeshError::Enum AddMesh(Atlas *atlas, const MeshDecl &meshDecl, uint32_t mesh // Don't know how many times AddMesh will be called, so progress needs to adjusted each time. if (!ctx->addMeshProgress) { ctx->addMeshProgress = XA_NEW_ARGS(internal::MemTag::Default, internal::Progress, ProgressCategory::AddMesh, ctx->progressFunc, ctx->progressUserData, 1); - } - else { + } else { ctx->addMeshProgress->setMaxValue(internal::max(ctx->meshCount + 1, meshCountHint)); } XA_PROFILE_START(addMeshCopyData) @@ -8875,8 +8482,7 @@ AddMeshError::Enum AddMesh(Atlas *atlas, const MeshDecl &meshDecl, uint32_t mesh return AddMeshError::Success; } -void AddMeshJoin(Atlas *atlas) -{ +void AddMeshJoin(Atlas *atlas) { XA_DEBUG_ASSERT(atlas); if (!atlas) { XA_PRINT_WARNING("AddMeshJoin: atlas is null.\n"); @@ -8904,19 +8510,19 @@ void AddMeshJoin(Atlas *atlas) XA_PRINT_MEM_USAGE } -struct EdgeKey -{ +struct EdgeKey { EdgeKey() {} - EdgeKey(const EdgeKey &k) : v0(k.v0), v1(k.v1) {} - EdgeKey(uint32_t v0, uint32_t v1) : v0(v0), v1(v1) {} + EdgeKey(const EdgeKey &k) : + v0(k.v0), v1(k.v1) {} + EdgeKey(uint32_t v0, uint32_t v1) : + v0(v0), v1(v1) {} bool operator==(const EdgeKey &k) const { return v0 == k.v0 && v1 == k.v1; } uint32_t v0; uint32_t v1; }; -AddMeshError::Enum AddUvMesh(Atlas *atlas, const UvMeshDecl &decl) -{ +AddMeshError::Enum AddUvMesh(Atlas *atlas, const UvMeshDecl &decl) { XA_DEBUG_ASSERT(atlas); if (!atlas) { XA_PRINT_WARNING("AddUvMesh: atlas is null.\n"); @@ -9026,8 +8632,7 @@ AddMeshError::Enum AddUvMesh(Atlas *atlas, const UvMeshDecl &decl) return AddMeshError::Success; } -void ComputeCharts(Atlas *atlas, ChartOptions chartOptions) -{ +void ComputeCharts(Atlas *atlas, ChartOptions chartOptions) { if (!atlas) { XA_PRINT_WARNING("ComputeCharts: atlas is null.\n"); return; @@ -9100,8 +8705,7 @@ void ComputeCharts(Atlas *atlas, ChartOptions chartOptions) XA_PRINT_MEM_USAGE } -void ParameterizeCharts(Atlas *atlas, ParameterizeFunc func) -{ +void ParameterizeCharts(Atlas *atlas, ParameterizeFunc func) { if (!atlas) { XA_PRINT_WARNING("ParameterizeCharts: atlas is null.\n"); return; @@ -9132,7 +8736,7 @@ void ParameterizeCharts(Atlas *atlas, ParameterizeFunc func) XA_PROFILE_START(parameterizeChartsReal) if (!ctx->paramAtlas.parameterizeCharts(ctx->taskScheduler, func, ctx->progressFunc, ctx->progressUserData)) { XA_PRINT(" Cancelled by user\n"); - return; + return; } XA_PROFILE_END(parameterizeChartsReal) uint32_t chartCount = 0, orthoChartsCount = 0, planarChartsCount = 0, lscmChartsCount = 0, piecewiseChartsCount = 0, chartsAddedCount = 0, chartsDeletedCount = 0; @@ -9234,8 +8838,7 @@ void ParameterizeCharts(Atlas *atlas, ParameterizeFunc func) XA_PRINT_MEM_USAGE } -void PackCharts(Atlas *atlas, PackOptions packOptions) -{ +void PackCharts(Atlas *atlas, PackOptions packOptions) { // Validate arguments and context state. if (!atlas) { XA_PRINT_WARNING("PackCharts: atlas is null.\n"); @@ -9277,8 +8880,7 @@ void PackCharts(Atlas *atlas, PackOptions packOptions) if (!ctx->uvMeshInstances.isEmpty()) { for (uint32_t i = 0; i < ctx->uvMeshInstances.size(); i++) packAtlas.addUvMeshCharts(ctx->uvMeshInstances[i]); - } - else + } else packAtlas.addCharts(ctx->taskScheduler, &ctx->paramAtlas); XA_PROFILE_END(packChartsAddCharts) XA_PROFILE_START(packCharts) @@ -9479,8 +9081,7 @@ void PackCharts(Atlas *atlas, PackOptions packOptions) XA_PRINT_MEM_USAGE } -void Generate(Atlas *atlas, ChartOptions chartOptions, ParameterizeFunc paramFunc, PackOptions packOptions) -{ +void Generate(Atlas *atlas, ChartOptions chartOptions, ParameterizeFunc paramFunc, PackOptions packOptions) { if (!atlas) { XA_PRINT_WARNING("Generate: atlas is null.\n"); return; @@ -9499,8 +9100,7 @@ void Generate(Atlas *atlas, ChartOptions chartOptions, ParameterizeFunc paramFun PackCharts(atlas, packOptions); } -void SetProgressCallback(Atlas *atlas, ProgressFunc progressFunc, void *progressUserData) -{ +void SetProgressCallback(Atlas *atlas, ProgressFunc progressFunc, void *progressUserData) { if (!atlas) { XA_PRINT_WARNING("SetProgressCallback: atlas is null.\n"); return; @@ -9510,20 +9110,17 @@ void SetProgressCallback(Atlas *atlas, ProgressFunc progressFunc, void *progress ctx->progressUserData = progressUserData; } -void SetAlloc(ReallocFunc reallocFunc, FreeFunc freeFunc) -{ +void SetAlloc(ReallocFunc reallocFunc, FreeFunc freeFunc) { internal::s_realloc = reallocFunc; internal::s_free = freeFunc; } -void SetPrint(PrintFunc print, bool verbose) -{ +void SetPrint(PrintFunc print, bool verbose) { internal::s_print = print; internal::s_printVerbose = verbose; } -const char *StringForEnum(AddMeshError::Enum error) -{ +const char *StringForEnum(AddMeshError::Enum error) { if (error == AddMeshError::Error) return "Unspecified error"; if (error == AddMeshError::IndexOutOfRange) @@ -9533,8 +9130,7 @@ const char *StringForEnum(AddMeshError::Enum error) return "Success"; } -const char *StringForEnum(ProgressCategory::Enum category) -{ +const char *StringForEnum(ProgressCategory::Enum category) { if (category == ProgressCategory::AddMesh) return "Adding mesh(es)"; if (category == ProgressCategory::ComputeCharts)