Skip to content

Commit 7a6ee1c

Browse files
amakarev authored and liubo-intel committed
[rls-v3.6 backport] cpu: x64: matmul: copy_routines: associate fp16 support with dt, not only isa (uxlfoundation#2332)
1 parent b08b241 commit 7a6ee1c

File tree

1 file changed

+47
-15
lines changed

1 file changed

+47
-15
lines changed

src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp

+47-15
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2021-2024 Intel Corporation
2+
* Copyright 2021-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -151,7 +151,10 @@ struct jit_brgemm_matmul_copy_a_impl_t : public jit_brgemm_matmul_copy_a_t,
151151
template <>
152152
void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_vmm(int idx, int offset) {
153153
const auto addr = EVEX_compress_addr(reg_src, offset);
154-
if (conf_->isa == avx512_core_fp16) {
154+
if (conf_->isa == avx512_core_fp16
155+
&& conf_->orig_wei_dt == data_type::f16) {
156+
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
157+
// is used.
155158
vcvtph2psx(get_vmm_copy(idx), addr);
156159
} else
157160
vmovdqu8(get_vmm_copy(idx), addr);
@@ -186,8 +189,12 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_tail(
186189
}
187190
};
188191

189-
const size_t dt_step
190-
= conf_->is_bf32 || conf_->isa == avx512_core_fp16 ? 1 : typesize_;
192+
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt` is used.
193+
const size_t dt_step = conf_->is_bf32
194+
|| (conf_->isa == avx512_core_fp16
195+
&& conf_->orig_wei_dt == data_type::f16)
196+
? 1
197+
: typesize_;
191198
const size_t tail_mask_load = size_t(((size_t)1 << (dt_step * k_tail)) - 1);
192199
kmovx(kTail_load, tail_mask_load);
193200
const int k_tail_st = rnd_up(k_tail, vnni_granularity_);
@@ -202,7 +209,10 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_tail(
202209
auto load_addr = EVEX_compress_addr(reg_src, offset * typesize_);
203210
if (conf_->is_bf32)
204211
vmovups(zmm_tail, load_addr);
205-
else if (conf_->isa == avx512_core_fp16)
212+
else if (conf_->isa == avx512_core_fp16
213+
&& conf_->orig_wei_dt == data_type::f16)
214+
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
215+
// is used.
206216
vcvtph2psx(zmm_tail, load_addr);
207217
else
208218
vmovdqu8(zmm_tail, load_addr);
@@ -223,7 +233,10 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::store_tail(
223233
Ymm ymm_downcvt_bf16 = Ymm(get_vmm_copy(0).getIdx());
224234
vcvtneps2bf16(ymm_downcvt_bf16, get_vmm_copy(0));
225235
vmovdqu16(tr_src_addr, ymm_downcvt_bf16 | kTail_store);
226-
} else if (conf_->isa == avx512_core_fp16) {
236+
} else if (conf_->isa == avx512_core_fp16
237+
&& conf_->orig_wei_dt == data_type::f16) {
238+
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
239+
// is used.
227240
vmovups(tr_src_addr, get_vmm_copy(0) | kTail_store);
228241
} else
229242
vmovdqu8(tr_src_addr, get_vmm_copy(0) | kTail_store);
@@ -943,13 +956,17 @@ void jit_brgemm_matmul_copy_a_transposed_impl_t<Xbyak::Zmm>::transpose_f32(
943956
const auto addr = is_dynamic_src_ld
944957
? ptr[i % 2 == 0 ? reg_aux_src0 : reg_aux_src1]
945958
: EVEX_compress_addr(src, i * src_stride);
946-
if (i < nrows)
947-
if (conf_->isa == avx512_core_fp16)
959+
if (i < nrows) {
960+
if (conf_->isa == avx512_core_fp16
961+
&& conf_->orig_wei_dt == data_type::f16)
962+
// See the note in `create_brgemm_matmul_copy_b` why
963+
// `orig_wei_dt` is used.
948964
vcvtph2psx(src_zmm(i) | kTail | T_z, addr);
949965
else
950966
vmovups(src_zmm(i) | kTail | T_z, addr);
951-
else
967+
} else {
952968
vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
969+
}
953970
};
954971

955972
auto store = [this, dst](Zmm r, int i) {
@@ -1075,7 +1092,11 @@ void jit_brgemm_matmul_copy_a_transposed_impl_t<Xbyak::Zmm>::transpose_f32(
10751092
template <typename Vmm>
10761093
void jit_brgemm_matmul_copy_a_transposed_impl_t<Vmm>::deploy_transpose(
10771094
reg64_t dst, reg64_t src, int nrows, int ncolumns) {
1078-
if (is_f32 || conf_->isa == avx512_core_fp16)
1095+
if (is_f32
1096+
|| (conf_->isa == avx512_core_fp16
1097+
&& conf_->orig_wei_dt == data_type::f16))
1098+
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
1099+
// is used.
10791100
transpose_f32(dst, src, nrows, ncolumns);
10801101
else
10811102
transpose_bf16(dst, src, nrows, ncolumns);
@@ -3714,7 +3735,12 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
37143735
const int columns_tail, const bool use_int4_mask) {
37153736
assert(IMPLICATION(use_int4_mask, is_src_int4_));
37163737
if (columns_tail > 0) {
3717-
const int dt_step = req_cvtps2bf16_ || conf_->isa == avx512_core_fp16
3738+
3739+
const int dt_step = req_cvtps2bf16_
3740+
|| (conf_->isa == avx512_core_fp16
3741+
&& conf_->orig_wei_dt == data_type::f16)
3742+
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
3743+
// is used.
37183744
? 1
37193745
: typesize_;
37203746
const auto tail_mask = use_int4_mask
@@ -3870,11 +3896,14 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
38703896

38713897
auto src_load = columns_tail > 0 ? src_reg | kTail | T_z : src_reg;
38723898
const auto addr = EVEX_compress_addr(reg_src, i * src_stride_);
3873-
if (conf_->isa == avx512_core_fp16)
3899+
if (conf_->isa == avx512_core_fp16
3900+
&& conf_->orig_wei_dt == data_type::f16) {
3901+
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
3902+
// is used.
38743903
vcvtph2psx(src_load, addr);
3875-
else
3904+
} else {
38763905
vmovdqu8(src_load, addr);
3877-
3906+
}
38783907
L(load_done);
38793908
};
38803909

@@ -4687,7 +4716,10 @@ status_t create_brgemm_matmul_copy_b(
46874716
else
46884717
CHECK(safe_ptr_assign(copy_ker,
46894718
new jit_brgemm_matmul_copy_b_bf16_t<Ymm>(conf)));
4690-
} else if (is_f32 || conf->isa == avx512_core_fp16) {
4719+
} else if (is_f32
4720+
|| (conf->isa == avx512_core_fp16
4721+
&& conf->orig_wei_dt == data_type::f16)) {
4722+
// See the note above why `orig_wei_dt` is used.
46914723
if (is_superset(conf->isa, avx512_core))
46924724
CHECK(safe_ptr_assign(copy_ker,
46934725
new jit_brgemm_matmul_copy_b_f32_t<Zmm>(conf)));

0 commit comments

Comments
 (0)