 /*******************************************************************************
-* Copyright 2021-2024 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -151,7 +151,10 @@ struct jit_brgemm_matmul_copy_a_impl_t : public jit_brgemm_matmul_copy_a_t,
 template <>
 void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_vmm(int idx, int offset) {
     const auto addr = EVEX_compress_addr(reg_src, offset);
-    if (conf_->isa == avx512_core_fp16) {
+    if (conf_->isa == avx512_core_fp16
+            && conf_->orig_wei_dt == data_type::f16) {
+        // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
+        // is used.
         vcvtph2psx(get_vmm_copy(idx), addr);
     } else
         vmovdqu8(get_vmm_copy(idx), addr);
@@ -186,8 +189,12 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_tail(
         }
     };
 
-    const size_t dt_step
-            = conf_->is_bf32 || conf_->isa == avx512_core_fp16 ? 1 : typesize_;
+    // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt` is used.
+    const size_t dt_step = conf_->is_bf32
+                    || (conf_->isa == avx512_core_fp16
+                            && conf_->orig_wei_dt == data_type::f16)
+            ? 1
+            : typesize_;
     const size_t tail_mask_load = size_t(((size_t)1 << (dt_step * k_tail)) - 1);
     kmovx(kTail_load, tail_mask_load);
     const int k_tail_st = rnd_up(k_tail, vnni_granularity_);
@@ -202,7 +209,10 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_tail(
     auto load_addr = EVEX_compress_addr(reg_src, offset * typesize_);
     if (conf_->is_bf32)
         vmovups(zmm_tail, load_addr);
-    else if (conf_->isa == avx512_core_fp16)
+    else if (conf_->isa == avx512_core_fp16
+            && conf_->orig_wei_dt == data_type::f16)
+        // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
+        // is used.
         vcvtph2psx(zmm_tail, load_addr);
     else
         vmovdqu8(zmm_tail, load_addr);
@@ -223,7 +233,10 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::store_tail(
         Ymm ymm_downcvt_bf16 = Ymm(get_vmm_copy(0).getIdx());
         vcvtneps2bf16(ymm_downcvt_bf16, get_vmm_copy(0));
         vmovdqu16(tr_src_addr, ymm_downcvt_bf16 | kTail_store);
-    } else if (conf_->isa == avx512_core_fp16) {
+    } else if (conf_->isa == avx512_core_fp16
+            && conf_->orig_wei_dt == data_type::f16) {
+        // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
+        // is used.
         vmovups(tr_src_addr, get_vmm_copy(0) | kTail_store);
     } else
         vmovdqu8(tr_src_addr, get_vmm_copy(0) | kTail_store);
@@ -943,13 +956,17 @@ void jit_brgemm_matmul_copy_a_transposed_impl_t<Xbyak::Zmm>::transpose_f32(
         const auto addr = is_dynamic_src_ld
                 ? ptr[i % 2 == 0 ? reg_aux_src0 : reg_aux_src1]
                 : EVEX_compress_addr(src, i * src_stride);
-        if (i < nrows)
-            if (conf_->isa == avx512_core_fp16)
+        if (i < nrows) {
+            if (conf_->isa == avx512_core_fp16
+                    && conf_->orig_wei_dt == data_type::f16)
+                // See the note in `create_brgemm_matmul_copy_b` why
+                // `orig_wei_dt` is used.
                 vcvtph2psx(src_zmm(i) | kTail | T_z, addr);
             else
                 vmovups(src_zmm(i) | kTail | T_z, addr);
-        else
+        } else {
             vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
+        }
     };
 
     auto store = [this, dst](Zmm r, int i) {
@@ -1075,7 +1092,11 @@ void jit_brgemm_matmul_copy_a_transposed_impl_t<Xbyak::Zmm>::transpose_f32(
 template <typename Vmm>
 void jit_brgemm_matmul_copy_a_transposed_impl_t<Vmm>::deploy_transpose(
         reg64_t dst, reg64_t src, int nrows, int ncolumns) {
-    if (is_f32 || conf_->isa == avx512_core_fp16)
+    if (is_f32
+            || (conf_->isa == avx512_core_fp16
+                    && conf_->orig_wei_dt == data_type::f16))
+        // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
+        // is used.
         transpose_f32(dst, src, nrows, ncolumns);
     else
         transpose_bf16(dst, src, nrows, ncolumns);
@@ -3714,7 +3735,12 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
         const int columns_tail, const bool use_int4_mask) {
     assert(IMPLICATION(use_int4_mask, is_src_int4_));
     if (columns_tail > 0) {
-        const int dt_step = req_cvtps2bf16_ || conf_->isa == avx512_core_fp16
+
+        const int dt_step = req_cvtps2bf16_
+                        || (conf_->isa == avx512_core_fp16
+                                && conf_->orig_wei_dt == data_type::f16)
+                // See the note in `create_brgemm_matmul_copy_b` why
+                // `orig_wei_dt` is used.
                 ? 1
                 : typesize_;
         const auto tail_mask = use_int4_mask
@@ -3870,11 +3896,14 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
 
         auto src_load = columns_tail > 0 ? src_reg | kTail | T_z : src_reg;
         const auto addr = EVEX_compress_addr(reg_src, i * src_stride_);
-        if (conf_->isa == avx512_core_fp16)
+        if (conf_->isa == avx512_core_fp16
+                && conf_->orig_wei_dt == data_type::f16) {
+            // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
+            // is used.
             vcvtph2psx(src_load, addr);
-        else
+        } else {
             vmovdqu8(src_load, addr);
-
+        }
         L(load_done);
     };
 
@@ -4687,7 +4716,10 @@ status_t create_brgemm_matmul_copy_b(
         else
             CHECK(safe_ptr_assign(copy_ker,
                     new jit_brgemm_matmul_copy_b_bf16_t<Ymm>(conf)));
-    } else if (is_f32 || conf->isa == avx512_core_fp16) {
+    } else if (is_f32
+            || (conf->isa == avx512_core_fp16
+                    && conf->orig_wei_dt == data_type::f16)) {
+        // See the note above why `orig_wei_dt` is used.
         if (is_superset(conf->isa, avx512_core))
             CHECK(safe_ptr_assign(copy_ker,
                     new jit_brgemm_matmul_copy_b_f32_t<Zmm>(conf)));
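
The same gating pattern recurs in every hunk above: a path that previously keyed the f16 up-conversion (`vcvtph2psx`) off the ISA alone now also requires that the weights were originally f16. A minimal standalone sketch of that predicate follows; the names `conf_t`, `cpu_isa_t`, and `use_f16_upconvert` are hypothetical stand-ins for illustration, not the actual oneDNN internals.

```cpp
#include <cassert>

// Hypothetical, simplified stand-ins for the real oneDNN enums and config.
enum class data_type_t { f32, bf16, f16, s8 };
enum class cpu_isa_t { avx512_core, avx512_core_fp16, avx512_core_amx };

struct conf_t {
    cpu_isa_t isa;
    data_type_t orig_wei_dt; // data type originally requested for weights
};

// Before the patch the f16 up-conversion path depended on the ISA alone;
// after it, both conditions must hold.
inline bool use_f16_upconvert(const conf_t &conf) {
    return conf.isa == cpu_isa_t::avx512_core_fp16
            && conf.orig_wei_dt == data_type_t::f16;
}

int main() {
    // On avx512_core_fp16 with non-f16 original weights, the copy routine
    // falls back to the plain byte move instead of vcvtph2psx.
    conf_t conf {cpu_isa_t::avx512_core_fp16, data_type_t::f32};
    assert(!use_f16_upconvert(conf));
    conf.orig_wei_dt = data_type_t::f16;
    assert(use_f16_upconvert(conf));
    return 0;
}
```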