// jit_brgemm_kernel.cpp
/*******************************************************************************
* Copyright 2020-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include <vector>
#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/platform.hpp"
#include "cpu/x64/brgemm/brgemm_types.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_avx512_core_fp8cvt.hpp"
#include "cpu/x64/jit_generator.hpp"
#define GET_OFF(field) offsetof(brgemm_kernel_params_t, field)
#define GET_OFF_BATCH_ELEMENT(field) offsetof(brgemm_batch_element_t, field)
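// GET_OFF(x) resolves to the byte offset of field 'x' in the kernel call
// parameters (brgemm_kernel_params_t) passed via abi_param1;
// GET_OFF_BATCH_ELEMENT(x) does the same for brgemm_batch_element_t.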
namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
using namespace dnnl::impl::utils;
using namespace Xbyak;
template <typename Wmm>
struct jit_brgemm_kernel_t : public jit_generator {
jit_brgemm_kernel_t(const brgemm_desc_t &abrg)
: jit_generator(jit_name(), abrg.isa_impl)
, brg(abrg)
, postops_injector_(nullptr)
, max_effective_vregs(get_max_effective_vregs(abrg)) {
        // The implementation uses the is_superset()/is_subset() utilities,
        // so avoid isa_all and isa_undef in these comparisons.
assert(!utils::one_of(brg.isa_impl, isa_all, isa_undef));
const int is_ldb2_tail = brg.ldb2_tail ? 1 : 0;
const int is_ldb_tail = brg.ldb_tail ? 1 : 0;
is_ldb_loop_ = brg.ldb2 + is_ldb2_tail + is_ldb_tail > 1;
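        // The ldb loop is only needed when there is more than one chunk
        // along N: full ld_block2 chunks plus the optional ldb2/ldb tails.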
bool has_f8_e5m2_binary_postops = false;
bool has_f8_e4m3_binary_postops = false;
if (brg.with_binary) {
const auto &post_ops = brg.attr()->post_ops_;
for (int i = 0; i < post_ops.len(); i++) {
const auto &entry = post_ops.entry_[i];
if (!entry.is_binary()) continue;
has_f8_e5m2_binary_postops
= entry.binary.src1_desc.data_type == data_type::f8_e5m2
|| has_f8_e5m2_binary_postops;
has_f8_e4m3_binary_postops
= entry.binary.src1_desc.data_type == data_type::f8_e4m3
|| has_f8_e4m3_binary_postops;
}
}
if (brg.is_fp8_via_convert() || has_f8_e5m2_binary_postops
|| has_f8_e4m3_binary_postops) {
if (one_of(data_type::f8_e5m2, brg.dt_a, brg.dt_b, brg.dt_c,
brg.dt_d)
|| has_f8_e5m2_binary_postops)
            // Note: avoid using 'vmm0' since it is used as the
            // 'fp8_to_f16_upconvert()' param and would collide with these
            // emulation vmms.
f8_e5m2_emulator_ = utils::make_unique<fp8_emulation_e5m2_t>(
this, xmm_fp8_emu_aux1, xmm_fp8_emu_aux2,
xmm_fp8_emu_aux3, kmask_fp8_aux, reg64_fp8_aux);
if (one_of(data_type::f8_e4m3, brg.dt_a, brg.dt_b, brg.dt_c,
brg.dt_d)
|| has_f8_e4m3_binary_postops)
f8_e4m3_emulator_ = utils::make_unique<fp8_emulation_e4m3_t>(
this, xmm_fp8_emu_aux1, xmm_fp8_emu_aux2,
xmm_fp8_emu_aux3, xmm_fp8_emu_aux4, xmm_fp8_emu_aux5,
reg64_fp8_aux);
}
if (brg.with_eltwise || brg.with_binary || brg.with_sum) {
static constexpr bool preserve_gpr = true;
static constexpr bool preserve_vmm = true;
static constexpr bool use_exact_tail_scalar_bcast = false;
const auto dst_md_wrapper = memory_desc_wrapper(brg.dst_md());
const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(vmm_tmp(0).getIdx()), this->r14,
this->r15, this->r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
dst_md_wrapper, static_cast<size_t>(brg.ldb_tail),
ld_tail_mask, use_exact_tail_scalar_bcast};
const binary_injector::static_params_t bsp {this->param1,
binary_injector::get_all_strategies_supported_by_injector(),
rhs_sp, f8_e5m2_emulator_.get(), f8_e4m3_emulator_.get()};
auto st = safe_ptr_assign(postops_injector_,
po_injector_t::create(
this, brg.isa_impl, brg.attr()->post_ops_, bsp));
if (st != status::success) {
assert(!"postops_injector creation failed");
}
with_binary_non_scalar_bcast_ = binary_injector::
any_binary_postop_rhs_non_scalar_broadcast(
brg.attr()->post_ops_, dst_md_wrapper);
}
if (brg.is_bf16_emu)
bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
bf16_emu_reserv_1(), bf16_emu_reserv_2(),
bf16_emu_reserv_3(), bf16_emu_scratch, bf16_emu_reserv_4(),
bf16_emu_reserv_4());
}
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_kernel_t)
brgemm_desc_t brg;
private:
enum matrix_kind_t { matrix_A, matrix_B };
static constexpr int zmm_width_in_bytes_
= cpu_isa_traits<avx512_core>::vlen;
using Vmm =
typename utils::conditional<std::is_same<Wmm, Xbyak::Tmm>::value,
Xbyak::Zmm, Wmm>::type;
using Vmm_lower_t = typename vreg_traits<Vmm>::Vmm_lower_t;
using po_injector_t = injector::jit_uni_postops_injector_base_t<Vmm>;
std::unique_ptr<po_injector_t> postops_injector_;
std::unique_ptr<bf16_emulation_t> bf16_emu_;
std::unique_ptr<fp8_emulation_e5m2_t> f8_e5m2_emulator_;
std::unique_ptr<fp8_emulation_e4m3_t> f8_e4m3_emulator_;
Xbyak::Label avx_tail_mask_;
Xbyak::Label sum_zp_scale_data_;
Xbyak::Label f16_perm_even_table_;
Xbyak::Label f16_perm_odd_table_;
using reg64_t = const Xbyak::Reg64;
// Register decomposition
const reg64_t param1 = abi_param1;
const reg64_t reg_C = r15;
const reg64_t reg_aux_C = r14;
const reg64_t reg_addr_batch = r13;
const reg64_t reg_A = r13;
const reg64_t reg_B = r12;
const reg64_t reg_aux_A = r11;
const reg64_t reg_aux_B = r10;
const reg64_t reg_aux_A_vpad = reg_aux_A;
const reg64_t reg_bdb_loop = r9;
const reg64_t reg_ldb_loop = r8;
const reg64_t reg_stride_lda = reg_bdb_loop;
const reg64_t reg_stride_ldb = reg_ldb_loop;
const reg64_t reg_stride_ld_block = reg_ldb_loop;
const reg64_t reg_s8_input_shift = reg_bdb_loop;
const reg64_t reg_zp_a_input_shift = reg_bdb_loop;
const reg64_t reg_BS_loop = rax;
const reg64_t reg_rdb_loop = rbx;
const reg64_t reg_BS = abi_not_param1;
const reg64_t reg_a_offset = rdx;
const reg64_t reg_b_offset = rsi;
const reg64_t reg_aux1_batch = rbp;
const reg64_t reg_aux1_A = rbp;
const reg64_t reg_aux1_B = abi_param1;
const reg64_t reg_offs_batch = reg_aux1_A;
const reg64_t reg_strd_batch = reg_rdb_loop;
const reg64_t reg_bias = reg_rdb_loop;
const reg64_t reg_scales = reg_rdb_loop;
const reg64_t reg_aux_bias = reg_rdb_loop;
const reg64_t reg_dst_scales = reg_rdb_loop;
const reg64_t reg_zp_comp_a = reg_rdb_loop;
const reg64_t reg_aux_zp_comp_a = reg_rdb_loop;
const reg64_t reg_zp_comp_b = reg_rdb_loop;
const reg64_t reg_aux_zp_comp_b = reg_rdb_loop;
const reg64_t reg_zp_c_values = reg_rdb_loop;
const reg64_t reg_aux_zp_c_values = reg_rdb_loop;
const reg64_t reg_wei_scales = reg_rdb_loop;
const reg64_t reg_aux_wei_scales = reg_rdb_loop;
const reg64_t reg_wei_zp = reg_rdb_loop;
const reg64_t reg_aux_wei_zp = reg_rdb_loop;
const reg64_t reg_ic = reg_rdb_loop;
const reg64_t reg_src_scales = reg_rdb_loop;
const reg64_t reg_src_grouped_sum = reg_rdb_loop;
const reg64_t reg_tmp_read_values = reg_rdb_loop;
const reg64_t reg_aux_scales = reg_aux_B;
const reg64_t reg_aux_dst_scales = reg_aux_B;
const reg64_t reg_do_post_ops = reg_rdb_loop;
const reg64_t reg_do_comp = reg_rdb_loop;
const reg64_t reg_skip_accm = reg_rdb_loop;
const reg64_t reg_tmp_gpr = reg_rdb_loop;
const reg64_t reg_ptr_sum_scale = reg_rdb_loop;
const reg64_t reg_ptr_sum_zp = reg_bdb_loop;
const reg64_t reg_zp_a_val = reg_rdb_loop;
const reg64_t reg_buf = reg_rdb_loop;
const reg64_t reg_buf_aux = abi_param1;
const reg64_t reg_compensation = reg_rdb_loop;
const reg64_t reg_aux_compensation = reg_rdb_loop;
const reg64_t reg_D = reg_aux_A;
const reg64_t reg_aux_D = reg_BS_loop;
/* bf16 emulation */
const reg64_t bf16_emu_scratch = reg_rdb_loop;
    // FP8 convert
    // Note: registers (GPR and VMM) used in the fp8 emulator must not
    // intersect with the set of registers used in the binary injector,
    // because the fp8 emulator may be called from the injector.
const reg64_t reg_converted_stride = reg_rdb_loop;
const reg64_t reg64_fp8_aux = reg_a_offset;
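    // Stack frame layout: one 8-byte slot per spilled value; the offsets
    // below are consecutive, and stack_space_needed_ must cover the last
    // slot. origin_offs_batch_offs_ and origin_strd_batch_offs_
    // intentionally alias slot 0: offset-based and stride-based batch kinds
    // are mutually exclusive.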
constexpr static int origin_offs_batch_offs_ = 0;
constexpr static int origin_strd_batch_offs_ = 0;
constexpr static int reg_bias_offs_ = 8;
constexpr static int reg_aux_bias_offs_ = 16;
constexpr static int reg_do_post_ops_offs_ = 24;
constexpr static int reg_D_offs_ = 32;
constexpr static int reg_aux_D_offs_ = 40;
constexpr static int reg_scales_offs_ = 48;
constexpr static int reg_aux_scales_offs_ = 56;
constexpr static int reg_bdb_loop_offs_ = 64;
constexpr static int reg_ldb_loop_offs_ = 72;
constexpr static int reg_buf_offs_ = 80;
constexpr static int reg_comp_offs_ = reg_buf_offs_;
constexpr static int reg_aux_comp_offs_ = 88;
constexpr static int abi_param1_offs_ = 96;
constexpr static int reg_zp_comp_a_offs_ = 104;
constexpr static int reg_aux_zp_comp_a_offs_ = 112;
constexpr static int reg_zp_comp_b_offs_ = 120;
constexpr static int reg_aux_zp_comp_b_offs_ = 128;
constexpr static int reg_zp_c_values_offs_ = 136;
constexpr static int reg_aux_zp_c_values_offs_ = 144;
constexpr static int reg_data_C_ptr_ = 152;
constexpr static int reg_skip_accm_offs_ = 160;
constexpr static int reg_zp_a_val_offs_ = 168;
constexpr static int reg_do_comp_offs_ = 176;
constexpr static int reg_dst_scales_offs_ = 184;
constexpr static int reg_C_shift_bytes_offs_ = 192;
constexpr static int reg_aux_C_backup_offs_ = 200;
constexpr static int reg_aux_C_bdb_loop_backup_offs_ = 208;
constexpr static int reg_aux_C_bdb_loop_shift_offs_ = 216;
constexpr static int reg_D_shift_bytes_offs_ = 224;
constexpr static int reg_aux_D_backup_offs_ = 232;
constexpr static int reg_aux_D_bdb_loop_backup_offs_ = 240;
constexpr static int reg_aux_D_bdb_loop_shift_offs_ = 248;
constexpr static int reg_wei_scales_offs_ = 256;
constexpr static int reg_aux_wei_scales_offs_ = 264;
constexpr static int reg_wei_zero_points_offs_ = 272;
constexpr static int reg_aux_wei_zero_points_offs_ = 280;
constexpr static int reg_ic_offs_ = 288;
constexpr static int reg_aux2_D_offs_ = 296;
constexpr static int reg_aux2_wei_scales_offs_ = 304;
constexpr static int reg_aux2_wei_zero_points_offs_ = 312;
constexpr static int reg_aux_ic_offs_ = 320;
constexpr static int reg_reg_a_offset_offs_ = 328;
constexpr static int reg_src_scales_offs_ = 336;
constexpr static int reg_aux_src_scales_offs_ = 344;
constexpr static int reg_aux2_src_scales_offs_ = 352;
constexpr static int reg_src_grouped_sum_offs_ = 360;
constexpr static int reg_aux_src_grouped_sum_offs_ = 368;
constexpr static int reg_aux2_src_grouped_sum_offs_ = 376;
    // These stack slots are used by the FP8 path as temporary push/pop space.
constexpr static int reg_val_tmp_1_ = 384;
constexpr static int reg_val_tmp_2_ = 392;
constexpr static int stack_space_needed_ = 400;
bool is_ldb_loop_ = false;
bool with_binary_non_scalar_bcast_ = false;
const int max_effective_vregs;
Xbyak::Opmask ld_full_mask = Xbyak::Opmask(2);
Xbyak::Opmask ld_tail_mask = Xbyak::Opmask(3);
Xbyak::Opmask fp8_col_mask = Xbyak::Opmask(4);
Xbyak::Opmask kmask_fp8_aux = Xbyak::Opmask(5);
static int get_max_effective_vregs(const brgemm_desc_t &brg) {
auto used_vregs = 0;
if (brg.is_int8 && !brg.has_int8_vnni)
used_vregs = 2;
else if (brg.is_fp8_via_convert())
used_vregs = 5;
else if (brg.is_f16_b_non_amx_vnni())
used_vregs = 2;
if (one_of(brg.dt_b, data_type::nf4) && brg.isa_impl == avx2) {
used_vregs += 5;
}
if (one_of(brg.dt_b, data_type::f4_e2m1) && brg.isa_impl == avx2) {
used_vregs += 2;
}
        if (one_of(brg.dt_b, data_type::nf4, data_type::f4_e2m1)
                && brg.isa_impl != avx2) {
            used_vregs += 1;
        }
        if (brg.with_wei_decomp_zero_points
                && brg.wei_decomp_zero_points_stride == 0
                && !brg.with_src_dyn_quant) {
            used_vregs += 1;
        }
if (brg.with_src_dyn_quant) {
used_vregs += 1;
}
return isa_num_vregs(brg.isa_impl) - used_vregs;
}
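    // Register budget example (illustrative numbers): for int8 without VNNI
    // on avx512_core, isa_num_vregs() == 32 and two vregs are reserved,
    // leaving 30 effective vregs. Accumulators are allocated from the top of
    // that budget (accm), the broadcast/load temporaries sit just below
    // them, and which of bcst()/load() gets the shared Vmm(0) depends on
    // n_bcast_1_load.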
Vmm accm(int ld_block, int bd, int ld) {
return Vmm(max_effective_vregs - 1 - (bd * ld_block + ld));
}
Vmm bcst(int bd = 0) {
if (n_bcast_1_load) {
int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block)
- bd;
assert(idx > 0);
return Vmm(idx);
} else
return Vmm(0);
}
Vmm load(int ld = 0) {
if (n_bcast_1_load) {
return Vmm(0);
} else {
int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block)
- ld;
assert(idx > 0);
return Vmm(idx);
}
}
Vmm vmm_tmp(int i) {
assert(IMPLICATION(!brg.is_tmm,
i >= 0
&& i < max_effective_vregs
- brg.bd_block * brg.ld_block2));
return Vmm(i);
}
Vmm vmm_tail_mask() { return vmm_tmp(1); }
Vmm vmm_one_bytes() const noexcept { return Vmm(3); }
Vmm vmm_zp_a_shift() const noexcept { return Vmm(2); }
Vmm vmm_inp_shift() const noexcept { return Vmm(1); }
/* bf16 emulation */
Zmm bf16_emu_reserv_1() const noexcept { return Zmm(0); }
Zmm bf16_emu_reserv_2() const noexcept { return Zmm(1); }
Zmm bf16_emu_reserv_3() const noexcept { return Zmm(2); }
Zmm bf16_emu_reserv_4() const noexcept { return Zmm(3); }
// note: zmm reserv_5 is not necessary since it's only used for 'vdpbf16ps'
// fp8 emulation convert
Vmm xmm_fp8_emu_aux1 = Vmm(1);
Vmm xmm_fp8_emu_aux2 = Vmm(2);
Vmm xmm_fp8_emu_aux3 = Vmm(3);
Vmm xmm_fp8_emu_aux4 = Vmm(4);
Vmm xmm_fp8_emu_aux5 = Vmm(5);
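    // Note: these fp8 aux vmms overlap the bf16-emulation and compensation
    // temporaries above; this assumes the corresponding features are never
    // enabled simultaneously.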
// Required in every dot product for INT8 non-VNNI computation.
Vmm int8_ones_words() const noexcept {
return Vmm(isa_num_vregs(brg.isa_impl) - 1);
}
Vmm int8_dot_product_temp() const noexcept {
return Vmm(isa_num_vregs(brg.isa_impl) - 2);
}
Zmm f16_perm_even_vreg_ = Zmm(isa_num_vregs(brg.isa_impl) - 1);
Zmm f16_perm_odd_vreg_ = Zmm(isa_num_vregs(brg.isa_impl) - 2);
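    // Permutation vregs for the f16 non-AMX, non-VNNI B path (loaded from
    // f16_perm_even_table_ / f16_perm_odd_table_); they live in the two
    // vregs reserved by get_max_effective_vregs() for this case.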
Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store,
Xbyak::Opmask ktail_mask) const;
Vmm_lower_t vmm_lower_mask(const Vmm_lower_t vmm_lower_in, bool mask_flag,
bool store, Xbyak::Opmask ktail_mask) const;
void maybe_set_avx_mask(bool is_ld_tail);
void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Xbyak::Operand &op,
bool mask_flag, bool store, Xbyak::Opmask ktail_mask,
int tail_size);
void advance_ldb_post_op_regs();
void restore_ldb_post_op_regs(int ld_block2);
void advance_bdb_post_op_regs(int adj_bd_block);
void restore_bdb_post_op_regs(int bd_block2);
void ldb_regs_shift(int ld_block2, bool is_tail = false);
void advance_bd_block2_post_op_regs(int bd_block2);
void copy_post_ops_stack_values_to_aux(bool is_reg_tail);
void read_params();
void zero_accumulators(int bd_block2, bool is_bdb_tail, int ld_block,
bool is_ld_tail, bool skip_accumulation);
void fp8_to_f16_upconvert(int num_rows, int tile_num_col_bytes,
reg64_t reg_base, int offset, reg64_t reg_data_stride,
data_type_t dt, bool is_rd_tail);
void fp8_to_f16_upconvert_to_vnni(int num_rows, int tile_num_col_bytes,
reg64_t reg_base, int offset, reg64_t reg_data_stride,
data_type_t dt, bool is_rd_tail);
void store_accumulators(int bd_block2, bool is_bdb_tail, int ld_block,
bool is_ld_tail, bool skip_accumulation);
void store_accumulators_without_post_ops(
int bd_block, int ld_block, bool is_ld_tail);
void store_accumulators_apply_post_ops(int bd_block, int ld_block,
int ldb_and_bdb_offset, bool is_ld_tail);
void apply_compensation(int bd_block, int ld_block, bool is_ld_tail);
void apply_alpha_beta(int bd_block, int ld_block, bool is_ld_tail);
void apply_post_ops(int bd_block, int ld_block2, int ldb_and_bdb_offset,
bool is_ld_tail);
void restore_A_B_matrices();
void set_A_B_matrices();
void compute_int8_compensation(int rd_loop, int bd_b, int bd_e,
int bd_block, int ld_block2, bool is_ld_tail, int vpad);
void maybe_pre_process_data(matrix_kind_t matrix_kind, const Tmm &t1,
reg64_t reg_base, size_t offset, reg64_t reg_stride, int num_rows,
int num_col_bytes, bool is_rd_tail);
void maybe_tileloadd_nt(matrix_kind_t matrix_kind, int idx, int offset,
bool is_rd_tail, bool is_tail);
void dot_product(Vmm v1, Vmm v2, Vmm v3);
void gemm_microkernel(int bd_block2, bool is_bdb_tail, int ld_block,
bool is_rd_tail, bool is_ld_tail, int vpad, int rows_for_rd_tail);
void gemm_microkernel_amx(int bd_block2, bool is_bdb_tail, int ld_block,
bool is_rd_tail, bool is_ld_tail);
void gemm_microkernel_dyn_quant(int bd_block2, bool is_bdb_tail, int ld_block,
bool is_rd_tail, bool is_ld_tail, int vpad, int rows_for_rd_tail);
void ldb_loop(int bd_block2, bool is_bdb_tail, int ld_block,
int ldb_loop_length, bool is_reg_tail, bool is_ld_tail,
bool check_top_vpad, bool check_bottom_vpad, int rows_for_rd_tail,
bool skip_accumulation);
void bdb_loop();
void generate() override;
int A_offset(int bd, int rd, bool is_amx = false) const noexcept;
int B_offset(int ld, int rd, bool is_amx = false) const noexcept;
int C_offset(int bd, int ld) const noexcept;
int D_offset(int bd, int ld) const noexcept;
int rdb_A_offset() const noexcept;
int rdb_B_offset() const noexcept;
int ldb_B_offset(int ld_block2, bool is_tail = false) const noexcept;
int ldb_C_offset(int ld_block2, bool is_tail = false) const noexcept;
int ldb_D_offset(int ld_block2, bool is_tail = false) const noexcept;
int ldb_po_offset(int ld_block2, bool is_tail = false) const noexcept;
int bdb_A_offset(int bd_block2) const noexcept;
int bdb_C_offset(int bd_block2) const noexcept;
int bdb_D_offset(int bd_block2) const noexcept;
int bdb_po_offset(int bd_block2) const noexcept;
int bias_offset(int ld, bool is_tail = false) const noexcept;
int oc_logical_offset(int ld, bool is_tail = false) const noexcept;
int compensations_offset(int ld, bool is_tail = false) const noexcept;
int bdb_compensation_offset(int bd_block2) const noexcept;
int bd_compensation_offset(int ld, int bd) const noexcept;
int scales_offset(int ld, bool is_tail = false) const noexcept;
int wei_scales_offset(int ld, bool is_tail = false) const noexcept;
int wei_zp_offset(int ld, bool is_tail = false) const noexcept;
int zp_comp_a_offset(int ld, bool is_tail = false) const noexcept;
int bd_zp_comp_a_offset(int ld, int bd) const noexcept;
int bdb_zp_comp_a_offset(int bd_block2) const noexcept;
int zp_comp_b_offset(int bd) const noexcept;
int bdb_zp_comp_b_offset(int bd_block2) const noexcept;
int zp_c_values_offset(int ld, bool is_tail = false) const noexcept;
bool n_bcast_1_load = false;
bool vpad_exist = false;
bool need_comp_pads = false;
};
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::A_offset(
int bd, int rd, bool is_amx) const noexcept {
return (is_amx) ? brg.typesize_A * (bd * brg.bd_block * brg.LDA)
: brg.typesize_A * (bd * brg.LDA + rd);
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::B_offset(
int ld, int rd, bool is_amx) const noexcept {
    const int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4,
                                      data_type::u4, data_type::f4_e2m1)
            ? 2
            : 1;
    if (is_amx) {
        return brg.typesize_B * (brg.rd_step * ld * brg.ld_block)
                / typesize_scale;
    } else {
        const int data_vnni_granularity = brg.is_f16_b_non_amx_vnni()
                ? data_type_vnni_granularity(data_type::f16)
                : brg.ld_step;
        const int rdb0 = rd / data_vnni_granularity;
        // Note: offsets for elements within vnni_granularity are expected to
        // be handled within gemm_microkernel (e.g. odd-even converts), hence
        // no `rd % data_vnni_granularity` here.
        return brg.typesize_B
                * (rdb0 * data_vnni_granularity * brg.LDB
                        + data_vnni_granularity * ld * brg.ld_block)
                / typesize_scale;
}
}
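// Illustration (hypothetical values): for bf16 B (typesize_B = 2, vnni
// granularity = 2) with LDB = 64 and ld_block = 16,
//   B_offset(ld = 1, rd = 3) = 2 * ((3 / 2) * 2 * 64 + 2 * 1 * 16) = 320.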
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::C_offset(int bd, int ld) const noexcept {
const auto bd_shift = brg.is_runtime_ldc ? 0 : bd * brg.LDC;
return brg.typesize_C * (bd_shift + ld * brg.ld_block);
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::D_offset(int bd, int ld) const noexcept {
const auto bd_shift = brg.is_runtime_ldd ? 0 : bd * brg.LDD;
return brg.typesize_D * (bd_shift + ld * brg.ld_block);
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::rdb_A_offset() const noexcept {
return brg.typesize_A * brg.rd_block;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::rdb_B_offset() const noexcept {
    const int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4,
                                      data_type::u4, data_type::f4_e2m1)
            ? 2
            : 1;
    return brg.typesize_B * brg.rd_block * brg.LDB / typesize_scale;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::ldb_B_offset(
int ld_block2, bool is_tail) const noexcept {
    const int data_vnni_granularity = brg.is_f16_b_non_amx_vnni()
            ? data_type_vnni_granularity(data_type::f16)
            : brg.ld_step;
    const int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4,
                                      data_type::u4, data_type::f4_e2m1)
            ? 2
            : 1;
    return (is_tail)
            ? brg.typesize_B * brg.ldb_tail * data_vnni_granularity
                    / typesize_scale
            : brg.typesize_B * ld_block2 * brg.ld_block
                    * data_vnni_granularity / typesize_scale;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::ldb_C_offset(
int ld_block2, bool is_tail) const noexcept {
return (is_tail) ? brg.typesize_C * brg.ldb_tail
: brg.typesize_C * ld_block2 * brg.ld_block;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::ldb_D_offset(
int ld_block2, bool is_tail) const noexcept {
return (is_tail) ? brg.typesize_D * brg.ldb_tail
: brg.typesize_D * ld_block2 * brg.ld_block;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::ldb_po_offset(
int ld_block2, bool is_tail) const noexcept {
return (is_tail) ? brg.ldb_tail : ld_block2 * brg.ld_block;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::bdb_A_offset(int bd_block2) const noexcept {
return brg.typesize_A * bd_block2 * brg.bd_block * brg.LDA;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::bdb_C_offset(int bd_block2) const noexcept {
return bd_block2 * brg.bd_block
* (brg.is_runtime_ldc ? 1 : brg.typesize_C * brg.LDC);
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::bdb_D_offset(int bd_block2) const noexcept {
return bd_block2 * brg.bd_block
* (brg.is_runtime_ldd ? 1 : brg.typesize_D * brg.LDD);
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::bdb_po_offset(int bd_block2) const noexcept {
return bd_block2 * brg.bd_block * brg.LDD;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::bias_offset(int ld, bool is_tail) const noexcept {
return (is_tail) ? brg.typesize_bias * brg.ldb_tail
: brg.typesize_bias * ld * brg.ld_block;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::oc_logical_offset(
int ld, bool is_tail) const noexcept {
return (is_tail) ? brg.ldb_tail : ld * brg.ld_block;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::compensations_offset(
int ld, bool is_tail) const noexcept {
return (is_tail) ? sizeof(int32_t) * brg.ldb_tail
: sizeof(int32_t) * ld * brg.ld_block;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::bdb_compensation_offset(
int bd_block2) const noexcept {
return sizeof(int32_t) * bd_block2 * brg.bd_block * brg.LDB;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::bd_compensation_offset(
int ld, int bd) const noexcept {
return sizeof(int32_t) * (ld * brg.ld_block + bd * brg.LDB);
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::scales_offset(
int ld, bool is_tail) const noexcept {
return (is_tail) ? brg.is_oc_scale * sizeof(float) * brg.ldb_tail
: brg.is_oc_scale * sizeof(float) * ld * brg.ld_block;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::wei_scales_offset(
int ld, bool is_tail) const noexcept {
    return (is_tail)
            ? types::data_type_size(brg.wei_decomp_scales_dt) * brg.ldb_tail
            : types::data_type_size(brg.wei_decomp_scales_dt) * ld
                    * brg.ld_block;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::wei_zp_offset(
int ld, bool is_tail) const noexcept {
    return (is_tail)
            ? types::data_type_size(brg.wei_decomp_zero_points_dt)
                    * brg.ldb_tail
            : types::data_type_size(brg.wei_decomp_zero_points_dt) * ld
                    * brg.ld_block;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::zp_comp_a_offset(
int ld, bool is_tail) const noexcept {
return (is_tail) ? sizeof(int32_t) * brg.ldb_tail
: sizeof(int32_t) * ld * brg.ld_block;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::bdb_zp_comp_a_offset(
int bd_block2) const noexcept {
return sizeof(int32_t) * bd_block2 * brg.bd_block * brg.LDB;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::bd_zp_comp_a_offset(
int ld, int bd) const noexcept {
return sizeof(int32_t) * (ld * brg.ld_block + bd * brg.LDB);
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::zp_comp_b_offset(int bd) const noexcept {
return sizeof(int32_t) * bd;
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::bdb_zp_comp_b_offset(
int bd_block2) const noexcept {
return zp_comp_b_offset(bd_block2 * brg.bd_block);
}
template <typename Wmm>
int jit_brgemm_kernel_t<Wmm>::zp_c_values_offset(
int ld, bool is_tail) const noexcept {
if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
return (is_tail) ? sizeof(int32_t) * brg.ldb_tail
: sizeof(int32_t) * ld * brg.ld_block;
}
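    // For per_tensor (or none) a single zero-point value is broadcast, so no
    // per-block shift is required.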
return 0;
}
template <typename Wmm>
typename jit_brgemm_kernel_t<Wmm>::Vmm jit_brgemm_kernel_t<Wmm>::vmm_mask(
const Vmm vmm_in, bool mask_flag, bool store,
Xbyak::Opmask ktail_mask) const {
return mask_flag && is_superset(brg.isa_impl, avx512_core)
? (store ? vmm_in | ktail_mask : vmm_in | ktail_mask | T_z)
: vmm_in;
}
template <typename Wmm>
typename jit_brgemm_kernel_t<Wmm>::Vmm_lower_t
jit_brgemm_kernel_t<Wmm>::vmm_lower_mask(const Vmm_lower_t vmm_lower_in,
bool mask_flag, bool store, Xbyak::Opmask ktail_mask) const {
return mask_flag && is_superset(brg.isa_impl, avx512_core)
? (store ? vmm_lower_in | ktail_mask
: vmm_lower_in | ktail_mask | T_z)
: vmm_lower_in;
}
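// On ISAs without opmask support the ld tail is handled with the vector mask
// preloaded below from avx_tail_mask_ (see vmm_tail_mask()).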
template <typename Wmm>
void jit_brgemm_kernel_t<Wmm>::maybe_set_avx_mask(bool is_ld_tail) {
if (IMPLICATION(is_ld_tail, isa_has_masks(brg.isa_impl))) return;
vmovups(vmm_tail_mask(), ptr[rip + avx_tail_mask_]);
}
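// cvt2ps: load 'op' into 'vmm_in' converted to f32; integral inputs are
// converted with uni_vcvtdq2ps, and tails fall back to load_data() when
// opmasks are unavailable.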
template <typename Wmm>
void jit_brgemm_kernel_t<Wmm>::cvt2ps(data_type_t type_in, const Vmm vmm_in,
const Xbyak::Operand &op, bool mask_flag, bool store,
Xbyak::Opmask ktail_mask, int tail_size) {
Vmm vmm = vmm_in;
const bool has_tail
= op.isMEM() && tail_size != vreg_traits<Vmm>::vlen / sizeof(float);
if (IMPLICATION(has_tail, is_superset(brg.isa_impl, avx512_core))) {
vmm = vmm_mask(vmm_in, mask_flag, store, ktail_mask);
} else {
load_data(type_in, vmm_in, op.getAddress(), tail_size);
if (types::is_integral_dt(type_in)) uni_vcvtdq2ps(vmm_in, vmm_in);
return;
}
switch (type_in) {
case data_type::f32:
case data_type::s32: uni_vmovups(vmm, op); break;
case data_type::bf16:
uni_vpmovzxwd(vmm, op);
uni_vpslld(vmm, vmm, 16);
break;
case data_type::f16: vcvtph2ps(vmm, op); break;
case data_type::s8: uni_vpmovsxbd(vmm, op); break;
case data_type::u8: uni_vpmovzxbd(vmm, op); break;
case data_type::f8_e5m2:
if (brg.is_fp8_via_convert())
f8_e5m2_emulator_->vcvt_f8_to_f32(vmm, op);
else
assert(!"Error, native conversion unsupported");
break;
case data_type::f8_e4m3:
if (brg.is_fp8_via_convert())
f8_e4m3_emulator_->vcvt_f8_to_f32(vmm, op);
else
assert(!"Error, native conversion unsupported");
break;
default: assert(!"unsupported data type");
}
if (types::is_integral_dt(type_in)) uni_vcvtdq2ps(vmm_in, vmm_in);
}
template <typename Wmm>
void jit_brgemm_kernel_t<Wmm>::advance_ldb_post_op_regs() {
if (brg.with_bias) {
mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]);
add(reg_aux_bias, bias_offset(1));
mov(ptr[rsp + reg_aux_bias_offs_], reg_aux_bias);
}
if (brg.with_scales) {
mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]);
add(reg_aux_scales, scales_offset(1));
mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
}
if (brg.zp_type_a != brgemm_broadcast_t::none) {
mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
add(reg_aux_zp_comp_a, zp_comp_a_offset(1));
mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_aux_zp_comp_a);
}
if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
add(reg_aux_zp_c_values, zp_c_values_offset(1));
mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_aux_zp_c_values);
}
}
template <typename Wmm>
void jit_brgemm_kernel_t<Wmm>::restore_ldb_post_op_regs(int ld_block2) {
if (brg.with_bias) {
mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]);
sub(reg_aux_bias, bias_offset(ld_block2 - 1));
mov(ptr[rsp + reg_aux_bias_offs_], reg_aux_bias);
}
if (brg.with_scales) {
mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]);
sub(reg_aux_scales, scales_offset(ld_block2 - 1));
mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
}
if (brg.zp_type_a != brgemm_broadcast_t::none) {
mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
sub(reg_aux_zp_comp_a, zp_comp_a_offset(ld_block2 - 1));
mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_aux_zp_comp_a);
}
if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
sub(reg_aux_zp_c_values, zp_c_values_offset(ld_block2 - 1));
mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_aux_zp_c_values);
}
}
template <typename Wmm>
void jit_brgemm_kernel_t<Wmm>::advance_bdb_post_op_regs(int adj_bd_block) {
if (brg.zp_type_b != brgemm_broadcast_t::none) {
mov(reg_aux_zp_comp_b, ptr[rsp + reg_aux_zp_comp_b_offs_]);
add(reg_aux_zp_comp_b, bdb_zp_comp_b_offset(1));
mov(ptr[rsp + reg_aux_zp_comp_b_offs_], reg_aux_zp_comp_b);
}
if (brg.req_comp_pads_with_bcast
&& brg.zp_type_a != brgemm_broadcast_t::none) {
mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
add(reg_aux_zp_comp_a, bdb_compensation_offset(1));
mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_aux_zp_comp_a);
}
}
template <typename Wmm>
void jit_brgemm_kernel_t<Wmm>::restore_bdb_post_op_regs(int bd_block2) {
bool post_processed = false;
if (bd_block2 > 1) {
if (brg.zp_type_b != brgemm_broadcast_t::none) {
post_processed = true;
mov(reg_aux_zp_comp_b, ptr[rsp + reg_aux_zp_comp_b_offs_]);
sub(reg_aux_zp_comp_b, bdb_zp_comp_b_offset(bd_block2 - 1));
mov(ptr[rsp + reg_aux_zp_comp_b_offs_], reg_aux_zp_comp_b);
}
        if (brg.req_comp_pads_with_bcast
                && brg.zp_type_a != brgemm_broadcast_t::none) {
            post_processed = true;
            mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
            sub(reg_aux_zp_comp_a, bdb_compensation_offset(bd_block2 - 1));
            mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_aux_zp_comp_a);
        }
}
if (post_processed) mov(reg_buf, ptr[rsp + reg_buf_offs_]);
}
template <typename Wmm>
void jit_brgemm_kernel_t<Wmm>::ldb_regs_shift(int ld_block2, bool is_tail) {
int C_offset = (is_tail) ? ldb_C_offset(1, true) : ldb_C_offset(ld_block2);
int D_offset = (is_tail) ? ldb_D_offset(1, true) : ldb_D_offset(ld_block2);
add(reg_aux_C, C_offset);
add(reg_aux_D, D_offset);
add(reg_b_offset,
(is_tail) ? ldb_B_offset(1, true) : ldb_B_offset(ld_block2));
if (brg.with_bias) {
mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]);
add(reg_aux_bias,
(is_tail) ? bias_offset(1, true) : bias_offset(ld_block2));
mov(ptr[rsp + reg_aux_bias_offs_], reg_aux_bias);
}
if (brg.req_s8s8_compensation) {
mov(reg_aux_compensation, ptr[rsp + reg_aux_comp_offs_]);
add(reg_aux_compensation,
(is_tail) ? compensations_offset(1, true)
: compensations_offset(ld_block2));
mov(ptr[rsp + reg_aux_comp_offs_], reg_aux_compensation);
}
if (brg.with_scales) {
mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]);
add(reg_aux_scales,
(is_tail) ? scales_offset(1, true) : scales_offset(ld_block2));
mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
}
    if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) {
        mov(reg_aux_wei_scales, ptr[rsp + reg_aux_wei_scales_offs_]);
        add(reg_aux_wei_scales,
                (is_tail) ? wei_scales_offset(1, true)
                          : wei_scales_offset(ld_block2));
        mov(ptr[rsp + reg_aux_wei_scales_offs_], reg_aux_wei_scales);
        mov(ptr[rsp + reg_aux2_wei_scales_offs_], reg_aux_wei_scales);
    }
    if (brg.with_wei_decomp_zero_points
            && brg.wei_decomp_zero_points_stride != 0) {
        mov(reg_aux_wei_zp, ptr[rsp + reg_aux_wei_zero_points_offs_]);
        add(reg_aux_wei_zp,
                (is_tail) ? wei_zp_offset(1, true) : wei_zp_offset(ld_block2));
        mov(ptr[rsp + reg_aux_wei_zero_points_offs_], reg_aux_wei_zp);
        mov(ptr[rsp + reg_aux2_wei_zero_points_offs_], reg_aux_wei_zp);
    }
if (brg.zp_type_a != brgemm_broadcast_t::none) {
mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
add(reg_aux_zp_comp_a,
(is_tail) ? zp_comp_a_offset(1, true)
: zp_comp_a_offset(ld_block2));
mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_aux_zp_comp_a);
}
if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
add(reg_aux_zp_c_values,
(is_tail) ? zp_c_values_offset(1, true)
: zp_c_values_offset(ld_block2));
mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_aux_zp_c_values);
}
}
template <typename Wmm>
void jit_brgemm_kernel_t<Wmm>::advance_bd_block2_post_op_regs(int bd_block2) {
if (brg.req_comp_pads_with_bcast && brg.req_s8s8_compensation) {
mov(reg_compensation, ptr[rsp + reg_comp_offs_]);
add(reg_compensation, bdb_compensation_offset(bd_block2));
mov(ptr[rsp + reg_comp_offs_], reg_compensation);
}
if (brg.req_comp_pads_with_bcast
&& brg.zp_type_a != brgemm_broadcast_t::none) {
mov(reg_zp_comp_a, ptr[rsp + reg_zp_comp_a_offs_]);
add(reg_zp_comp_a, bdb_zp_comp_a_offset(bd_block2));
mov(ptr[rsp + reg_zp_comp_a_offs_], reg_zp_comp_a);
}
if (brg.zp_type_b != brgemm_broadcast_t::none) {
mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]);
add(reg_zp_comp_b, bdb_zp_comp_b_offset(bd_block2));
mov(ptr[rsp + reg_zp_comp_b_offs_], reg_zp_comp_b);
}
}
template <typename Wmm>
void jit_brgemm_kernel_t<Wmm>::copy_post_ops_stack_values_to_aux(
bool is_reg_tail) {
if (!is_reg_tail) {
mov(reg_aux_C, reg_C);
mov(reg_aux_D, reg_D);
xor_(reg_b_offset, reg_b_offset);
if (brg.with_bias) {
mov(reg_bias, ptr[rsp + reg_bias_offs_]);
mov(ptr[rsp + reg_aux_bias_offs_], reg_bias);
}
if (brg.req_s8s8_compensation) {
mov(reg_compensation, ptr[rsp + reg_comp_offs_]);
mov(ptr[rsp + reg_aux_comp_offs_], reg_compensation);
}
if (brg.with_scales) {
mov(reg_scales, ptr[rsp + reg_scales_offs_]);
mov(ptr[rsp + reg_aux_scales_offs_], reg_scales);
}
if (brg.zp_type_a != brgemm_broadcast_t::none) {
mov(reg_zp_comp_a, ptr[rsp + reg_zp_comp_a_offs_]);
mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_zp_comp_a);
}
if (brg.zp_type_c != brgemm_broadcast_t::none) {
mov(reg_zp_c_values, ptr[rsp + reg_zp_c_values_offs_]);
mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_zp_c_values);
}
if (brg.with_wei_decomp_scales) {
mov(reg_wei_scales, ptr[rsp + reg_wei_scales_offs_]);
mov(ptr[rsp + reg_aux_wei_scales_offs_], reg_wei_scales);
mov(ptr[rsp + reg_aux2_wei_scales_offs_], reg_wei_scales);
}
if (brg.with_wei_decomp_zero_points) {
mov(reg_wei_zp, ptr[rsp + reg_wei_zero_points_offs_]);
mov(ptr[rsp + reg_aux_wei_zero_points_offs_], reg_wei_zp);
mov(ptr[rsp + reg_aux2_wei_zero_points_offs_], reg_wei_zp);
}
}
if (brg.with_src_dyn_quant) {
mov(reg_src_scales, ptr[rsp + reg_src_scales_offs_]);
mov(ptr[rsp + reg_aux_src_scales_offs_], reg_src_scales);
mov(ptr[rsp + reg_aux2_src_scales_offs_], reg_src_scales);
if (brg.with_wei_decomp_zero_points) {
mov(reg_src_grouped_sum, ptr[rsp + reg_src_grouped_sum_offs_]);
mov(ptr[rsp + reg_aux_src_grouped_sum_offs_], reg_src_grouped_sum);
mov(ptr[rsp + reg_aux2_src_grouped_sum_offs_], reg_src_grouped_sum);
}
}
if (brg.zp_type_b != brgemm_broadcast_t::none) {
mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]);
mov(ptr[rsp + reg_aux_zp_comp_b_offs_], reg_zp_comp_b);
}
}
template <typename Wmm>
void jit_brgemm_kernel_t<Wmm>::read_params() {
Label label_done;
if (brg.with_binary) mov(ptr[rsp + abi_param1_offs_], param1);
if (brg.type == brgemm_addr) {
mov(reg_addr_batch, ptr[param1 + GET_OFF(batch)]);
} else {
if (brg.layout == brgemm_row_major) {
mov(reg_A, ptr[param1 + GET_OFF(ptr_A)]);
mov(reg_B, ptr[param1 + GET_OFF(ptr_B)]);
} else {
mov(reg_A, ptr[param1 + GET_OFF(ptr_B)]);
mov(reg_B, ptr[param1 + GET_OFF(ptr_A)]);
}