@@ -26,6 +26,8 @@ namespace dnnl {
26
26
namespace impl {
27
27
namespace cpu {
28
28
namespace x64 {
29
+
30
+ #define avx512_gemm_available () false
29
31
30
32
int jit_avx2_kernel_sgemm_kern::next_acc (int idx, int um, int un) const {
31
33
while (!(((idx / unroll_n_) < std::max (1 , um / nelt_per_vecreg_))
@@ -36,7 +38,7 @@ int jit_avx2_kernel_sgemm_kern::next_acc(int idx, int um, int un) const {
36
38
37
39
void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload (
38
40
int um, int un, int k_idx, int n_idx) {
39
- if (!mayiuse (avx512_core )) {
41
+ if (!avx512_gemm_available ( )) {
40
42
if ((n_idx == 0 ) && (k_idx == 0 ) && (un == unroll_n_) && (um != 16 )) {
41
43
prefetcht0 (ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]);
42
44
offb_ += 16 ;
@@ -46,7 +48,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload(
46
48
47
49
void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA (
48
50
int um, int un, int k_idx, int n_idx, int m_idx) {
49
- if (!mayiuse (avx512_core )) {
51
+ if (!avx512_gemm_available ( )) {
50
52
if ((um == 16 ) || (un < unroll_n_)) {
51
53
if ((k_idx + m_idx + n_idx) == 0 ) {
52
54
prefetcht0 (ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]);
@@ -63,7 +65,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA(
63
65
64
66
void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA (
65
67
int um, int un, int k_idx, int n_idx, int m_idx) {
66
- if (mayiuse (avx512_core )) {
68
+ if (avx512_gemm_available ( )) {
67
69
if ((um < unroll_m_) && (m_idx == 0 )) {
68
70
if (((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 0 ) && (n_idx % 6 == 0 ))
69
71
|| ((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 1 )
@@ -87,7 +89,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA(
87
89
88
90
void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload (
89
91
int um, int un, int k_idx, int n_idx) {
90
- if (!mayiuse (avx512_core )) {
92
+ if (!avx512_gemm_available ( )) {
91
93
if ((um == unroll_m_) && (un == 2 )) {
92
94
if (k_idx % 3 == 0 ) {
93
95
if (n_idx == 1 ) {
@@ -111,7 +113,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload(
111
113
112
114
void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA (
113
115
int k_idx, int n_idx, int m_idx) {
114
- if (mayiuse (avx512_core )) {
116
+ if (avx512_gemm_available ( )) {
115
117
if (((m_idx + (k_idx % (nb_zmm_a_ / unroll_m_reg_)) * unroll_m_reg_)
116
118
== 0 )
117
119
&& (n_idx == 1 )) {
@@ -126,7 +128,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA(
126
128
127
129
void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA (
128
130
int um, int un, int k_idx, int n_idx, int m_idx) {
129
- if (!mayiuse (avx512_core )) {
131
+ if (!avx512_gemm_available ( )) {
130
132
if ((um == unroll_m_) && (un == unroll_n_)) {
131
133
if (((k_idx == 0 ) && (n_idx % 2 == 1 ) && (m_idx == 0 ))
132
134
|| ((k_idx == 1 ) && (n_idx == 2 ) && (m_idx == 0 ))
@@ -160,7 +162,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA(
160
162
161
163
void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload (
162
164
int um, int un, int k_idx, int n_idx) {
163
- if (mayiuse (avx512_core )) {
165
+ if (avx512_gemm_available ( )) {
164
166
if (um == unroll_m_) {
165
167
if (n_idx == std::min (1 , un - 1 )) {
166
168
if (k_idx == unroll_k_ - 1 )
@@ -173,7 +175,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload(
173
175
}
174
176
175
177
void jit_avx2_kernel_sgemm_kern::prefetchC_beforeKloop (int um) {
176
- if (mayiuse (avx512_core )) {
178
+ if (avx512_gemm_available ( )) {
177
179
if (um < unroll_m_) {
178
180
prefetchw (ptr[CO2_ + elt_size_ * 0 ]);
179
181
prefetchw (ptr[CO2_ + elt_size_ * 8 ]);
@@ -228,7 +230,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
228
230
mov (C_, ptr[rsp + get_size_of_abi_save_regs () + C_off]);
229
231
mov (LDC_, ptr[rsp + get_size_of_abi_save_regs () + LDC_off]);
230
232
231
- if (mayiuse (avx512_core )) {
233
+ if (avx512_gemm_available ( )) {
232
234
for (i = zmm_acc_idx_; i < unroll_m_reg_ * unroll_n_ + zmm_acc_idx_;
233
235
i++)
234
236
vpxorq (Xbyak::Zmm (i), Xbyak::Zmm (i), Xbyak::Zmm (i));
@@ -267,7 +269,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
267
269
add (AA_, A_);
268
270
mov (CO1_, C_);
269
271
270
- if ((unroll_x == unroll_m_) || (!mayiuse (avx512_core )))
272
+ if ((unroll_x == unroll_m_) || (!avx512_gemm_available ( )))
271
273
lea (CO2_, ptr[C_ + LDC_ * 2 ]);
272
274
273
275
add (C_, unroll_x * elt_size_);
@@ -292,12 +294,12 @@ void jit_avx2_kernel_sgemm_kern::generate() {
292
294
T_NEAR);
293
295
}
294
296
295
- if (!mayiuse (avx512_core ))
297
+ if (!avx512_gemm_available ( ))
296
298
prefetcht2 (ptr[AA_ - addr_off_ * elt_size_]);
297
299
298
300
switch (unroll_x) {
299
301
case 8 :
300
- if (mayiuse (avx512_core )) {
302
+ if (avx512_gemm_available ( )) {
301
303
loop<Xbyak::Zmm, Xbyak::Zmm, Xbyak::Address, Xbyak::Xmm,
302
304
Xbyak::Operand>(unroll_x, unroll_y,
303
305
&Xbyak::CodeGenerator::vbroadcastf64x4,
@@ -319,7 +321,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
319
321
320
322
break ;
321
323
case 4 :
322
- if (mayiuse (avx512_core )) {
324
+ if (avx512_gemm_available ( )) {
323
325
loop<Xbyak::Zmm, Xbyak::Ymm, Xbyak::Address, Xbyak::Xmm,
324
326
Xbyak::Operand>(unroll_x, unroll_y,
325
327
&Xbyak::CodeGenerator::vbroadcastf32x4,
@@ -340,7 +342,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
340
342
341
343
break ;
342
344
case 2 :
343
- if (mayiuse (avx512_core )) {
345
+ if (avx512_gemm_available ( )) {
344
346
loop<Xbyak::Zmm, Xbyak::Ymm, Xbyak::Operand, Xbyak::Xmm,
345
347
Xbyak::Operand>(unroll_x, unroll_y,
346
348
&Xbyak::CodeGenerator::vbroadcastsd,
@@ -357,7 +359,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
357
359
&Xbyak::CodeGenerator::vmovsd);
358
360
break ;
359
361
case 1 :
360
- if (mayiuse (avx512_core )) {
362
+ if (avx512_gemm_available ( )) {
361
363
loop<Xbyak::Zmm, Xbyak::Xmm, Xbyak::Operand, Xbyak::Xmm,
362
364
Xbyak::Operand>(unroll_x, unroll_y,
363
365
&Xbyak::CodeGenerator::vbroadcastss,
@@ -377,7 +379,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
377
379
378
380
break ;
379
381
default :
380
- if (mayiuse (avx512_core )) {
382
+ if (avx512_gemm_available ( )) {
381
383
loop<Xbyak::Zmm, Xbyak::Xmm, Xbyak::Operand, Xbyak::Xmm,
382
384
Xbyak::Operand>(unroll_x, unroll_y,
383
385
&Xbyak::CodeGenerator::vmovups,
@@ -400,7 +402,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
400
402
break ;
401
403
}
402
404
403
- if (mayiuse (avx512_core )) {
405
+ if (avx512_gemm_available ( )) {
404
406
sub (AA_, -16 * elt_size_);
405
407
} else {
406
408
if ((unroll_y != unroll_n_) || (unroll_x <= 4 )) {
0 commit comments