52
52
#include " cpu/x64/jit_sse41_1x1_convolution.hpp"
53
53
#include " cpu/x64/jit_sse41_convolution.hpp"
54
54
#include " cpu/x64/jit_uni_dw_convolution.hpp"
55
+ #include " cpu/x64/jit_uni_fork_dw_convolution.hpp"
55
56
#include " cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp"
56
57
#include " cpu/x64/jit_uni_x8s8s32x_convolution.hpp"
57
58
using namespace dnnl ::impl::cpu::x64;
@@ -126,13 +127,16 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
126
127
CPU_INSTANCE_AVX512 (brgemm_1x1_convolution_fwd_t <avx512_core>)
127
128
CPU_INSTANCE_AVX512 (brgemm_convolution_fwd_t <avx512_core>)
128
129
CPU_INSTANCE_AVX512 (jit_avx512_common_dw_convolution_fwd_t )
130
+ CPU_INSTANCE_AVX512 (jit_avx512_common_fork_dw_convolution_fwd_t )
129
131
CPU_INSTANCE_AVX512 (jit_avx512_common_1x1_convolution_fwd_f32_t )
130
132
CPU_INSTANCE_AVX512 (jit_avx512_common_convolution_fwd_t <f32>)
131
133
CPU_INSTANCE_AVX2 (jit_avx2_dw_convolution_fwd_t )
132
134
CPU_INSTANCE_AVX2 (brgemm_1x1_convolution_fwd_t <avx2>)
133
135
CPU_INSTANCE_AVX2 (brgemm_convolution_fwd_t <avx2>)
136
+ CPU_INSTANCE_AVX2 (jit_avx2_fork_dw_convolution_fwd_t )
134
137
CPU_INSTANCE_AVX2 (jit_avx2_1x1_convolution_fwd_t )
135
138
CPU_INSTANCE_SSE41 (jit_sse41_dw_convolution_fwd_t )
139
+ CPU_INSTANCE_SSE41 (jit_sse41_fork_dw_convolution_fwd_t )
136
140
CPU_INSTANCE_SSE41 (jit_sse41_1x1_convolution_fwd_t )
137
141
CPU_INSTANCE_AVX2 (jit_avx2_convolution_fwd_t )
138
142
CPU_INSTANCE_SSE41 (jit_sse41_convolution_fwd_t )
@@ -166,6 +170,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
166
170
CPU_INSTANCE_AVX512 (brgemm_1x1_convolution_fwd_t <avx512_core_bf16>)
167
171
CPU_INSTANCE_AVX512 (brgemm_convolution_fwd_t <avx512_core_bf16>)
168
172
CPU_INSTANCE_AVX512 (jit_uni_dw_convolution_fwd_t <avx512_core, bf16, f32>)
173
+ CPU_INSTANCE_AVX512 (jit_uni_fork_dw_convolution_fwd_t <avx512_core, bf16, f32>)
169
174
CPU_INSTANCE_AVX512 (jit_avx512_core_bf16_1x1_convolution_fwd_t <f32>)
170
175
CPU_INSTANCE_AVX512 (jit_avx512_core_bf16_convolution_fwd_t )
171
176
CPU_INSTANCE_AVX512 (gemm_bf16_convolution_fwd_t <f32>)
@@ -184,6 +189,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
184
189
CPU_INSTANCE_AVX512 (brgemm_1x1_convolution_fwd_t <avx512_core_bf16>)
185
190
CPU_INSTANCE_AVX512 (brgemm_convolution_fwd_t <avx512_core_bf16>)
186
191
CPU_INSTANCE_AVX512 (jit_uni_dw_convolution_fwd_t <avx512_core, bf16, bf16>)
192
+ CPU_INSTANCE_AVX512 (jit_uni_fork_dw_convolution_fwd_t <avx512_core, bf16, bf16>)
187
193
CPU_INSTANCE_AVX512 (jit_avx512_core_bf16_1x1_convolution_fwd_t <bf16>)
188
194
CPU_INSTANCE_AVX512 (jit_avx512_core_bf16_convolution_fwd_t )
189
195
CPU_INSTANCE_AVX512 (gemm_bf16_convolution_fwd_t <bf16>)
@@ -246,13 +252,15 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
246
252
CPU_INSTANCE_AVX512 (brgemm_convolution_bwd_t <avx512_core>)
247
253
CPU_INSTANCE_AVX512 (brgemm_convolution_bwd_strided_t <avx512_core>)
248
254
CPU_INSTANCE_AVX512 (jit_avx512_common_dw_convolution_bwd_data_t )
255
+ CPU_INSTANCE_AVX512 (jit_avx512_common_fork_dw_convolution_bwd_data_t )
249
256
CPU_INSTANCE_AVX512 (jit_avx512_common_1x1_convolution_bwd_data_f32_t )
250
257
CPU_INSTANCE_AVX512 (jit_avx512_common_convolution_bwd_data_t <f32>)
251
258
CPU_INSTANCE_AVX2 (brgemm_convolution_bwd_t <avx2>)
252
259
CPU_INSTANCE_AVX2 (brgemm_convolution_bwd_strided_t <avx2>)
253
260
CPU_INSTANCE_AVX2 (jit_avx2_dw_convolution_bwd_data_t )
261
+ CPU_INSTANCE_AVX2 (jit_avx2_fork_dw_convolution_bwd_data_t )
254
262
CPU_INSTANCE_AVX2 (jit_avx2_1x1_convolution_bwd_data_t )
255
- CPU_INSTANCE_SSE41 (jit_sse41_dw_convolution_bwd_data_t )
263
+ CPU_INSTANCE_SSE41 (jit_sse41_fork_dw_convolution_bwd_data_t )
256
264
CPU_INSTANCE_AVX2 (jit_avx2_convolution_bwd_data_t )
257
265
CPU_INSTANCE_AARCH64 (jit_uni_dw_convolution_bwd_data_t <sve_512,data_type::f32>)
258
266
CPU_INSTANCE_AARCH64 (jit_sve_512_1x1_convolution_bwd_data_f32_t )
@@ -271,6 +279,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
271
279
CPU_INSTANCE_AVX512 (brgemm_convolution_bwd_t <avx512_core_bf16>)
272
280
CPU_INSTANCE_AVX512 (brgemm_convolution_bwd_strided_t <avx512_core_bf16>)
273
281
CPU_INSTANCE_AVX512 (jit_uni_dw_convolution_bwd_data_t <avx512_core, bf16, f32>)
282
+ CPU_INSTANCE_AVX512 (jit_uni_fork_dw_convolution_bwd_data_t <avx512_core, bf16, f32>)
274
283
CPU_INSTANCE_AVX512 (jit_avx512_core_bf16_1x1_convolution_bwd_data_t <f32>)
275
284
CPU_INSTANCE_AVX512 (jit_avx512_core_bf16_convolution_bwd_data_t )
276
285
CPU_INSTANCE_AVX512 (gemm_bf16_convolution_bwd_data_t <f32>)
@@ -287,6 +296,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
287
296
CPU_INSTANCE_AVX512 (brgemm_convolution_bwd_t <avx512_core_bf16>)
288
297
CPU_INSTANCE_AVX512 (brgemm_convolution_bwd_strided_t <avx512_core_bf16>)
289
298
CPU_INSTANCE_AVX512 (jit_uni_dw_convolution_bwd_data_t <avx512_core, bf16, bf16>)
299
+ CPU_INSTANCE_AVX512 (jit_uni_fork_dw_convolution_bwd_data_t <avx512_core, bf16, bf16>)
290
300
CPU_INSTANCE_AVX512 (jit_avx512_core_bf16_1x1_convolution_bwd_data_t <bf16>)
291
301
CPU_INSTANCE_AVX512 (jit_avx512_core_bf16_convolution_bwd_data_t )
292
302
CPU_INSTANCE_AVX512 (gemm_bf16_convolution_bwd_data_t <bf16>)
@@ -783,4 +793,4 @@ const impl_list_item_t *get_convolution_impl_list(
783
793
784
794
} // namespace cpu
785
795
} // namespace impl
786
- } // namespace dnnl
796
+ } // namespace dnnl
0 commit comments