Skip to content

Commit 55aba95

Browse files
committed
[FORK][WA] Fall back avx512 gemm to avx2 gemm when __BUILD_GEMM_AVX512 is false.
[FORK][FEATURE] cpu: remove gemm legacy on avx512.
1 parent a9e6db9 commit 55aba95

8 files changed

+148
-129
lines changed

src/cpu/gemm/gemm.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
241241
if (*M == 0 || *N == 0 || *K == 0) return dnnl_success;
242242

243243
#if DNNL_X64 && !__BUILD_GEMM_NONE
244-
bool use_jit = mayiuse(avx512_core);
244+
bool use_jit = avx512_gemm_available();
245245
bool use_s8u8 = true
246246
&& utils::everyone_is(0, *ao, *bo) // so far a requirement
247247
&& IMPLICATION(USE_MKL_IGEMM == 0, mayiuse(sse41));
@@ -299,7 +299,7 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb,
299299
bfloat16_t *dummy_bo = nullptr;
300300
float *dummy_co = nullptr;
301301

302-
if (mayiuse(avx512_core)) {
302+
if (avx512_gemm_available()) {
303303
auto status = gemm_driver(transa, transb, dummyOffsetC, M, N, K, alpha,
304304
(const bfloat16_t *)A, lda, dummy_ao, (const bfloat16_t *)B,
305305
ldb, dummy_bo, beta, (float *)C, ldc, dummy_co, false);

src/cpu/gemm/gemm.hpp

+17-2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@
3535
#define __BUILD_GEMM_AVX2 __BUILD_GEMM_AVX512 || BUILD_GEMM_AVX2
3636
#define __BUILD_GEMM_SSE41 __BUILD_GEMM_AVX2 || BUILD_GEMM_SSE41
3737
#define __BUILD_GEMM_NONE BUILD_GEMM_KERNELS_NONE
38+
39+
// Availability helpers for the AVX-512 family of GEMM kernels.
// When __BUILD_GEMM_AVX512 is non-zero, each helper expands to the mayiuse()
// query for the matching ISA; otherwise every helper expands to the literal
// `false`, so callers that branch on these helpers never select an AVX-512
// GEMM path in such builds.
#if __BUILD_GEMM_AVX512
40+
#define avx512_gemm_available() mayiuse(avx512_core)
41+
#define avx512_amx_gemm_available() mayiuse(avx512_core_amx)
42+
#define avx512_bf16_gemm_available() mayiuse(avx512_core_bf16)
43+
#define avx512_vnni_gemm_available() mayiuse(avx512_core_vnni)
44+
#define avx512_bf16_ymm_gemm_available() mayiuse(avx512_core_bf16_ymm)
45+
#else
46+
// AVX-512 GEMM kernels are not built: report every variant as unavailable.
#define avx512_gemm_available() false
47+
#define avx512_amx_gemm_available() false
48+
#define avx512_bf16_gemm_available() false
49+
#define avx512_vnni_gemm_available() false
50+
#define avx512_bf16_ymm_gemm_available() false
51+
#endif
52+
3853
#else
3954
#define __BUILD_GEMM_AMX 0
4055
#define __BUILD_GEMM_AVX512 0
@@ -91,9 +106,9 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb,
91106
#if !defined(USE_MKL_IGEMM) && defined(DNNL_X64)
92107
#define IGEMM_S8U8S32_ISA_STR \
93108
JIT_IMPL_NAME_HELPER(IGEMM_S8U8S32_IMPL_STR ":", \
94-
mayiuse(avx512_core_vnni) \
109+
avx512_vnni_gemm_available() \
95110
? avx512_core_vnni \
96-
: (mayiuse(avx512_core) ? avx512_core : isa_undef), \
111+
: (avx512_gemm_available() ? avx512_core : isa_undef), \
97112
"")
98113
#else
99114
#define IGEMM_S8U8S32_ISA_STR IGEMM_S8U8S32_IMPL_STR

src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp

+19-17
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ namespace dnnl {
2626
namespace impl {
2727
namespace cpu {
2828
namespace x64 {
29+
30+
// File-local override: every avx512_gemm_available() check in this kernel
// expands to `false`, so only the non-AVX-512 branches are taken here.
// NOTE(review): if a definition of this macro with a different replacement is
// already visible in this translation unit, this line is a macro redefinition;
// confirm against the includes and consider an #undef before it.
#define avx512_gemm_available() false
2931

3032
int jit_avx2_kernel_sgemm_kern::next_acc(int idx, int um, int un) const {
3133
while (!(((idx / unroll_n_) < std::max(1, um / nelt_per_vecreg_))
@@ -36,7 +38,7 @@ int jit_avx2_kernel_sgemm_kern::next_acc(int idx, int um, int un) const {
3638

3739
void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload(
3840
int um, int un, int k_idx, int n_idx) {
39-
if (!mayiuse(avx512_core)) {
41+
if (!avx512_gemm_available()) {
4042
if ((n_idx == 0) && (k_idx == 0) && (un == unroll_n_) && (um != 16)) {
4143
prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]);
4244
offb_ += 16;
@@ -46,7 +48,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload(
4648

4749
void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA(
4850
int um, int un, int k_idx, int n_idx, int m_idx) {
49-
if (!mayiuse(avx512_core)) {
51+
if (!avx512_gemm_available()) {
5052
if ((um == 16) || (un < unroll_n_)) {
5153
if ((k_idx + m_idx + n_idx) == 0) {
5254
prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]);
@@ -63,7 +65,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA(
6365

6466
void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA(
6567
int um, int un, int k_idx, int n_idx, int m_idx) {
66-
if (mayiuse(avx512_core)) {
68+
if (avx512_gemm_available()) {
6769
if ((um < unroll_m_) && (m_idx == 0)) {
6870
if (((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 0) && (n_idx % 6 == 0))
6971
|| ((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 1)
@@ -87,7 +89,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA(
8789

8890
void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload(
8991
int um, int un, int k_idx, int n_idx) {
90-
if (!mayiuse(avx512_core)) {
92+
if (!avx512_gemm_available()) {
9193
if ((um == unroll_m_) && (un == 2)) {
9294
if (k_idx % 3 == 0) {
9395
if (n_idx == 1) {
@@ -111,7 +113,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload(
111113

112114
void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA(
113115
int k_idx, int n_idx, int m_idx) {
114-
if (mayiuse(avx512_core)) {
116+
if (avx512_gemm_available()) {
115117
if (((m_idx + (k_idx % (nb_zmm_a_ / unroll_m_reg_)) * unroll_m_reg_)
116118
== 0)
117119
&& (n_idx == 1)) {
@@ -126,7 +128,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA(
126128

127129
void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA(
128130
int um, int un, int k_idx, int n_idx, int m_idx) {
129-
if (!mayiuse(avx512_core)) {
131+
if (!avx512_gemm_available()) {
130132
if ((um == unroll_m_) && (un == unroll_n_)) {
131133
if (((k_idx == 0) && (n_idx % 2 == 1) && (m_idx == 0))
132134
|| ((k_idx == 1) && (n_idx == 2) && (m_idx == 0))
@@ -160,7 +162,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA(
160162

161163
void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload(
162164
int um, int un, int k_idx, int n_idx) {
163-
if (mayiuse(avx512_core)) {
165+
if (avx512_gemm_available()) {
164166
if (um == unroll_m_) {
165167
if (n_idx == std::min(1, un - 1)) {
166168
if (k_idx == unroll_k_ - 1)
@@ -173,7 +175,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload(
173175
}
174176

175177
void jit_avx2_kernel_sgemm_kern::prefetchC_beforeKloop(int um) {
176-
if (mayiuse(avx512_core)) {
178+
if (avx512_gemm_available()) {
177179
if (um < unroll_m_) {
178180
prefetchw(ptr[CO2_ + elt_size_ * 0]);
179181
prefetchw(ptr[CO2_ + elt_size_ * 8]);
@@ -228,7 +230,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
228230
mov(C_, ptr[rsp + get_size_of_abi_save_regs() + C_off]);
229231
mov(LDC_, ptr[rsp + get_size_of_abi_save_regs() + LDC_off]);
230232

231-
if (mayiuse(avx512_core)) {
233+
if (avx512_gemm_available()) {
232234
for (i = zmm_acc_idx_; i < unroll_m_reg_ * unroll_n_ + zmm_acc_idx_;
233235
i++)
234236
vpxorq(Xbyak::Zmm(i), Xbyak::Zmm(i), Xbyak::Zmm(i));
@@ -267,7 +269,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
267269
add(AA_, A_);
268270
mov(CO1_, C_);
269271

270-
if ((unroll_x == unroll_m_) || (!mayiuse(avx512_core)))
272+
if ((unroll_x == unroll_m_) || (!avx512_gemm_available()))
271273
lea(CO2_, ptr[C_ + LDC_ * 2]);
272274

273275
add(C_, unroll_x * elt_size_);
@@ -292,12 +294,12 @@ void jit_avx2_kernel_sgemm_kern::generate() {
292294
T_NEAR);
293295
}
294296

295-
if (!mayiuse(avx512_core))
297+
if (!avx512_gemm_available())
296298
prefetcht2(ptr[AA_ - addr_off_ * elt_size_]);
297299

298300
switch (unroll_x) {
299301
case 8:
300-
if (mayiuse(avx512_core)) {
302+
if (avx512_gemm_available()) {
301303
loop<Xbyak::Zmm, Xbyak::Zmm, Xbyak::Address, Xbyak::Xmm,
302304
Xbyak::Operand>(unroll_x, unroll_y,
303305
&Xbyak::CodeGenerator::vbroadcastf64x4,
@@ -319,7 +321,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
319321

320322
break;
321323
case 4:
322-
if (mayiuse(avx512_core)) {
324+
if (avx512_gemm_available()) {
323325
loop<Xbyak::Zmm, Xbyak::Ymm, Xbyak::Address, Xbyak::Xmm,
324326
Xbyak::Operand>(unroll_x, unroll_y,
325327
&Xbyak::CodeGenerator::vbroadcastf32x4,
@@ -340,7 +342,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
340342

341343
break;
342344
case 2:
343-
if (mayiuse(avx512_core)) {
345+
if (avx512_gemm_available()) {
344346
loop<Xbyak::Zmm, Xbyak::Ymm, Xbyak::Operand, Xbyak::Xmm,
345347
Xbyak::Operand>(unroll_x, unroll_y,
346348
&Xbyak::CodeGenerator::vbroadcastsd,
@@ -357,7 +359,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
357359
&Xbyak::CodeGenerator::vmovsd);
358360
break;
359361
case 1:
360-
if (mayiuse(avx512_core)) {
362+
if (avx512_gemm_available()) {
361363
loop<Xbyak::Zmm, Xbyak::Xmm, Xbyak::Operand, Xbyak::Xmm,
362364
Xbyak::Operand>(unroll_x, unroll_y,
363365
&Xbyak::CodeGenerator::vbroadcastss,
@@ -377,7 +379,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
377379

378380
break;
379381
default:
380-
if (mayiuse(avx512_core)) {
382+
if (avx512_gemm_available()) {
381383
loop<Xbyak::Zmm, Xbyak::Xmm, Xbyak::Operand, Xbyak::Xmm,
382384
Xbyak::Operand>(unroll_x, unroll_y,
383385
&Xbyak::CodeGenerator::vmovups,
@@ -400,7 +402,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
400402
break;
401403
}
402404

403-
if (mayiuse(avx512_core)) {
405+
if (avx512_gemm_available()) {
404406
sub(AA_, -16 * elt_size_);
405407
} else {
406408
if ((unroll_y != unroll_n_) || (unroll_x <= 4)) {

0 commit comments

Comments
 (0)