Skip to content

Commit 44bec63

Browse files
dmitry-gorokhovazhai219
authored andcommitted
[FORK][FEATURE] Added custom vesrion of JIT DW FP32/BF16 Convolution with 5D input support
1 parent 081ba59 commit 44bec63

15 files changed

+2552
-3
lines changed

include/oneapi/dnnl/dnnl.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,7 @@ struct memory : public handle<dnnl_memory_t> {
14541454
BAcde16b16a = dnnl_BAcde16b16a,
14551455
BAcde16a16b = dnnl_BAcde16a16b,
14561456
aBdec32b = dnnl_aBdec32b,
1457+
Abcdef8a = dnnl_Abcdef8a,
14571458
Abcdef16a = dnnl_Abcdef16a,
14581459
Abcdef32a = dnnl_Abcdef32a,
14591460
Acdb32a = dnnl_Acdb32a,
@@ -1701,6 +1702,7 @@ struct memory : public handle<dnnl_memory_t> {
17011702
IOdhw16i16o = dnnl_IOdhw16i16o,
17021703
gIOhw16i16o = dnnl_gIOhw16i16o,
17031704
gOhwi32o = dnnl_gOhwi32o,
1705+
Goidhw8g = dnnl_Goidhw8g,
17041706
Goidhw16g = dnnl_Goidhw16g,
17051707
IOw16o16i = dnnl_IOw16o16i,
17061708
OIw16i16o = dnnl_OIw16i16o,

include/oneapi/dnnl/dnnl_types.h

+2
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ typedef enum {
398398
dnnl_aCBdef16c16b,
399399
dnnl_aBdefc4b,
400400
dnnl_aBdefc8b,
401+
dnnl_Abcdef8a,
401402
dnnl_Abcdef16a,
402403
dnnl_Abcdef32a,
403404
dnnl_aBedc16b,
@@ -1620,6 +1621,7 @@ typedef enum {
16201621
dnnl_gIOdhw8o16i2o = dnnl_aCBdef8b16c2b,
16211622
dnnl_gOIdhw8o8i = dnnl_aBCdef8b8c,
16221623
dnnl_gOIdhw8o4i = dnnl_aBCdef8b4c,
1624+
dnnl_Goidhw8g = dnnl_Abcdef8a,
16231625
dnnl_Goidhw16g = dnnl_Abcdef16a,
16241626
dnnl_Goidhw32g = dnnl_Abcdef32a,
16251627
dnnl_gOIdhw2i4o2i = dnnl_aBCdef2c4b2c,

src/common/c_types_map.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,7 @@ const format_tag_t ABcd40a32b = dnnl_ABcd40a32b;
677677
const format_tag_t ABcde40a32b = dnnl_ABcde40a32b;
678678
const format_tag_t BAcde16b16a = dnnl_BAcde16b16a;
679679
const format_tag_t aBdec32b = dnnl_aBdec32b;
680+
const format_tag_t Abcdef8a = dnnl_Abcdef8a;
680681
const format_tag_t Abcdef16a = dnnl_Abcdef16a;
681682
const format_tag_t Abcdef32a = dnnl_Abcdef32a;
682683
const format_tag_t Acdb32a = dnnl_Acdb32a;
@@ -1174,6 +1175,7 @@ const format_tag_t IOhw16i16o = dnnl_IOhw16i16o;
11741175
const format_tag_t Ohwi32o = dnnl_Ohwi32o;
11751176
const format_tag_t gIOhw16i16o = dnnl_gIOhw16i16o;
11761177
const format_tag_t gOhwi32o = dnnl_gOhwi32o;
1178+
const format_tag_t Goidhw8g = dnnl_Goidhw8g;
11771179
const format_tag_t Goidhw16g = dnnl_Goidhw16g;
11781180
const format_tag_t IOw16o16i = dnnl_IOw16o16i;
11791181
const format_tag_t IOw16i16o = dnnl_IOw16i16o;

src/common/dnnl_debug_autogenerated.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
302302
if (v == dnnl_aCBdef16c16b) return "aCBdef16c16b";
303303
if (v == dnnl_aBdefc4b) return "aBdefc4b";
304304
if (v == dnnl_aBdefc8b) return "aBdefc8b";
305+
if (v == dnnl_Abcdef8a) return "Abcdef8a";
305306
if (v == dnnl_Abcdef16a) return "Abcdef16a";
306307
if (v == dnnl_Abcdef32a) return "Abcdef32a";
307308
if (v == dnnl_aBedc16b) return "aBedc16b";
@@ -1397,6 +1398,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
13971398
if (v == dnnl_gIOdhw8o16i2o) return "gIOdhw8o16i2o";
13981399
if (v == dnnl_gOIdhw8o8i) return "gOIdhw8o8i";
13991400
if (v == dnnl_gOIdhw8o4i) return "gOIdhw8o4i";
1401+
if (v == dnnl_Goidhw8g) return "Goidhw8g";
14001402
if (v == dnnl_Goidhw16g) return "Goidhw16g";
14011403
if (v == dnnl_Goidhw32g) return "Goidhw32g";
14021404
if (v == dnnl_gOIdhw2i4o2i) return "gOIdhw2i4o2i";

src/common/memory.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ status_t dnnl_memory_create(memory_t **memory, const memory_desc_t *md,
164164
: memory_flags_t::use_runtime_ptr;
165165
void *handle_ptr = (handle == DNNL_MEMORY_ALLOCATE) ? nullptr : handle;
166166
auto _memory = new memory_t(engine, md, flags, handle_ptr);
167-
if (_memory == nullptr) return out_of_memory;
167+
if (_memory == nullptr)
168+
return out_of_memory;
168169
if (_memory->memory_storage() == nullptr) {
169170
delete _memory;
170171
return out_of_memory;

src/common/memory_desc_wrapper.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ status_t memory_desc_wrapper::compute_blocking(
586586
C(aBdec32b, {0, 1, 3, 4, 2}, {32}, {1});
587587
C(aCBdef16c16b, {0, 2, 1, 3, 4, 5}, {16, 16}, {2, 1});
588588
C(aCBdef16b16c, {0, 2, 1, 3, 4, 5}, {16, 16}, {1, 2});
589+
C(Abcdef8a, {0, 1, 2, 3, 4, 5}, {8}, {0});
589590
C(Abcdef16a, {0, 1, 2, 3, 4, 5}, {16}, {0});
590591
C(Abcdef32a, {0, 1, 2, 3, 4, 5}, {32}, {0});
591592
C(aCBd16c16b, {0, 2, 1, 3}, {16, 16}, {2, 1});

src/common/tag_traits.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,7 @@ DECL_TRAITS(aBCde2b8c8b2c, _BC, _2b8c8b2c, 5);
795795
DECL_TRAITS(aBdec32b, _B, _32b, 5);
796796
DECL_TRAITS(aCBdef16c16b, _BC, _16c16b, 6);
797797
DECL_TRAITS(aCBdef16b16c, _BC, _16b16c, 6);
798+
DECL_TRAITS(Abcdef8a, _A, _8a, 6);
798799
DECL_TRAITS(Abcdef16a, _A, _16a, 6);
799800
DECL_TRAITS(aCBd16c16b, _BC, _16c16b, 4);
800801
DECL_TRAITS(aCBde16c16b, _BC, _16c16b, 4);

src/cpu/cpu_convolution_list.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "cpu/x64/jit_sse41_1x1_convolution.hpp"
5353
#include "cpu/x64/jit_sse41_convolution.hpp"
5454
#include "cpu/x64/jit_uni_dw_convolution.hpp"
55+
#include "cpu/x64/jit_uni_fork_dw_convolution.hpp"
5556
#include "cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp"
5657
#include "cpu/x64/jit_uni_x8s8s32x_convolution.hpp"
5758
using namespace dnnl::impl::cpu::x64;
@@ -126,13 +127,16 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
126127
CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core>)
127128
CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core>)
128129
CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_fwd_t)
130+
CPU_INSTANCE_AVX512(jit_avx512_common_fork_dw_convolution_fwd_t)
129131
CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_fwd_f32_t)
130132
CPU_INSTANCE_AVX512(jit_avx512_common_convolution_fwd_t<f32>)
131133
CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_fwd_t)
132134
CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2>)
133135
CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2>)
136+
CPU_INSTANCE_AVX2(jit_avx2_fork_dw_convolution_fwd_t)
134137
CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_fwd_t)
135138
CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_fwd_t)
139+
CPU_INSTANCE_SSE41(jit_sse41_fork_dw_convolution_fwd_t)
136140
CPU_INSTANCE_SSE41(jit_sse41_1x1_convolution_fwd_t)
137141
CPU_INSTANCE_AVX2(jit_avx2_convolution_fwd_t)
138142
CPU_INSTANCE_SSE41(jit_sse41_convolution_fwd_t)
@@ -166,6 +170,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
166170
CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_bf16>)
167171
CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_bf16>)
168172
CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t<avx512_core, bf16, f32>)
173+
CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_fwd_t<avx512_core, bf16, f32>)
169174
CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t<f32>)
170175
CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_fwd_t)
171176
CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t<f32>)
@@ -184,6 +189,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
184189
CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_bf16>)
185190
CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_bf16>)
186191
CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t<avx512_core, bf16, bf16>)
192+
CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_fwd_t<avx512_core, bf16, bf16>)
187193
CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t<bf16>)
188194
CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_fwd_t)
189195
CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t<bf16>)
@@ -246,13 +252,15 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
246252
CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core>)
247253
CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t<avx512_core>)
248254
CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_bwd_data_t)
255+
CPU_INSTANCE_AVX512(jit_avx512_common_fork_dw_convolution_bwd_data_t)
249256
CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_bwd_data_f32_t)
250257
CPU_INSTANCE_AVX512(jit_avx512_common_convolution_bwd_data_t<f32>)
251258
CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t<avx2>)
252259
CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t<avx2>)
253260
CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_bwd_data_t)
261+
CPU_INSTANCE_AVX2(jit_avx2_fork_dw_convolution_bwd_data_t)
254262
CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_bwd_data_t)
255-
CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_bwd_data_t)
263+
CPU_INSTANCE_SSE41(jit_sse41_fork_dw_convolution_bwd_data_t)
256264
CPU_INSTANCE_AVX2(jit_avx2_convolution_bwd_data_t)
257265
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_data_t<sve_512,data_type::f32>)
258266
CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_data_f32_t)
@@ -271,6 +279,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
271279
CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core_bf16>)
272280
CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t<avx512_core_bf16>)
273281
CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t<avx512_core, bf16, f32>)
282+
CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_bwd_data_t<avx512_core, bf16, f32>)
274283
CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t<f32>)
275284
CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_data_t)
276285
CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t<f32>)
@@ -287,6 +296,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
287296
CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core_bf16>)
288297
CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t<avx512_core_bf16>)
289298
CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t<avx512_core, bf16, bf16>)
299+
CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_bwd_data_t<avx512_core, bf16, bf16>)
290300
CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t<bf16>)
291301
CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_data_t)
292302
CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t<bf16>)
@@ -783,4 +793,4 @@ const impl_list_item_t *get_convolution_impl_list(
783793

784794
} // namespace cpu
785795
} // namespace impl
786-
} // namespace dnnl
796+
} // namespace dnnl

0 commit comments

Comments
 (0)