jit_brgemm_amx_uker.cpp
/*******************************************************************************
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/platform.hpp"
#include "cpu/x64/brgemm/brgemm.hpp"
#include "cpu/x64/brgemm/brgemm_types.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_fp8cvt.hpp"
#define GET_OFF(field) offsetof(brgemm_kernel_params_t, field)
#define GET_OFF_BATCH_ELEMENT(field) offsetof(brgemm_batch_element_t, field)
namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
using namespace dnnl::impl::utils;
using namespace Xbyak;
struct jit_brgemm_amx_uker_base_t : public jit_base_brgemm_kernel_t {
jit_brgemm_amx_uker_base_t(const brgemm_desc_t &abrg)
: jit_base_brgemm_kernel_t(jit_name(), abrg.isa_impl)
, brg(abrg)
, postops_injector_(nullptr) {
bool has_f8_e5m2_binary_postops = false;
bool has_f8_e4m3_binary_postops = false;
if (brg.with_binary) {
const auto &post_ops = brg.attr()->post_ops_;
for (int i = 0; i < post_ops.len(); i++) {
const auto &entry = post_ops.entry_[i];
if (!entry.is_binary()) continue;
has_f8_e5m2_binary_postops
= entry.binary.src1_desc.data_type == data_type::f8_e5m2
|| has_f8_e5m2_binary_postops;
has_f8_e4m3_binary_postops
= entry.binary.src1_desc.data_type == data_type::f8_e4m3
|| has_f8_e4m3_binary_postops;
}
}
if (brg.is_fp8_via_convert() || has_f8_e5m2_binary_postops
|| has_f8_e4m3_binary_postops) {
if (one_of(data_type::f8_e5m2, brg.dt_a, brg.dt_b, brg.dt_d)
|| has_f8_e5m2_binary_postops)
f8_e5m2_emulator_ = utils::make_unique<fp8_emulation_e5m2_t>(
this, fp8_emu_xmm_1(), fp8_emu_xmm_2(), fp8_emu_xmm_3(),
fp8_tmp_mask, fp8_tmp_reg);
if (one_of(data_type::f8_e4m3, brg.dt_a, brg.dt_b, brg.dt_d)
|| has_f8_e4m3_binary_postops)
f8_e4m3_emulator_ = utils::make_unique<fp8_emulation_e4m3_t>(
this, fp8_emu_xmm_1(), fp8_emu_xmm_2(), fp8_emu_xmm_3(),
fp8_emu_xmm_4(), fp8_emu_xmm_5(), fp8_tmp_reg);
}
if (brg.with_eltwise || brg.with_binary || brg.with_sum) {
static constexpr bool preserve_gpr = true;
// we don't use zmm1 for storing vectors
// so we don't need to preserve vmm
static constexpr bool preserve_vmm = false;
static constexpr bool use_exact_tail_scalar_bcast = false;
const auto dst_md_wrapper = memory_desc_wrapper(brg.dst_md());
const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(Xbyak::Zmm(1).getIdx()), this->r14,
this->r15, this->r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
dst_md_wrapper, static_cast<size_t>(brg.ldb_tail),
ld_tail_mask, use_exact_tail_scalar_bcast};
const binary_injector::static_params_t bsp(this->param1,
binary_injector::get_all_strategies_supported_by_injector(),
rhs_sp, f8_e5m2_emulator_.get(), f8_e4m3_emulator_.get());
eltwise_injector::static_params_t esp;
esp.preserve_vmm = preserve_vmm;
esp.preserve_p_table = false;
auto st = safe_ptr_assign(postops_injector_,
po_injector_t::create(this, brg.isa_impl,
brg.attr()->post_ops_, bsp, esp));
if (st != status::success) {
assert(!"postops_injector creation failed");
}
using namespace dnnl::impl::cpu::binary_injector_utils;
std::tie(with_binary_per_oc_bcast_, with_binary_per_oc_sp_bcast_,
with_binary_per_mb_bcast_, with_binary_channel_bcast_,
with_binary_per_mb_w_bcast_, with_binary_per_w_bcast_,
with_binary_batch_bcast_, with_binary_spatial_bcast_,
with_binary_no_bcast_)
= bcast_strategies_present_tup(brg.attr()->post_ops_.entry_,
dst_md_wrapper, broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_mb,
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w,
broadcasting_strategy_t::per_w,
broadcasting_strategy_t::batch,
broadcasting_strategy_t::spatial,
broadcasting_strategy_t::no_broadcast);
handle_binary_po_offset_ = with_binary_per_oc_bcast_
|| with_binary_per_oc_sp_bcast_ || with_binary_per_mb_bcast_
|| with_binary_channel_bcast_ || with_binary_per_mb_w_bcast_
|| with_binary_per_w_bcast_ || with_binary_batch_bcast_
|| with_binary_spatial_bcast_ || with_binary_no_bcast_;
}
use_ils_ = brg.brgattr.use_interleave_stores;
}
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_amx_uker_base_t)
brgemm_desc_t brg;
const brgemm_desc_t &get_brg() const override { return brg; }
private:
using po_injector_t = injector::jit_uni_postops_injector_base_t<Zmm>;
std::unique_ptr<po_injector_t> postops_injector_;
std::unique_ptr<fp8_emulation_e5m2_t> f8_e5m2_emulator_;
std::unique_ptr<fp8_emulation_e4m3_t> f8_e4m3_emulator_;
using reg64_t = const Xbyak::Reg64;
enum {
simd_w = 16,
zmm_width_in_bytes = cpu_isa_traits_t<avx512_core>::vlen,
};
// Register decomposition
const reg64_t param1 = abi_param1;
const reg64_t reg_iter_label = r9;
const reg64_t reg_iter_labels_list = rax;
const reg64_t reg_addr_batch = r13;
const reg64_t reg_aux1_batch = rbp;
const reg64_t reg_A = r11;
const reg64_t reg_B = r10;
const reg64_t reg_stride_lda = r14;
const reg64_t reg_stride_ldb = abi_not_param1;
const reg64_t reg_C = r15;
const reg64_t reg_D = r12;
const reg64_t reg_buf = r8;
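// Note: several of the registers below intentionally alias the same GPRs
// (e.g. rbx, r9, rsi); the aliased roles are assumed to be live only in
// mutually exclusive code paths.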
const reg64_t reg_BS = rbx;
const reg64_t reg_BS_loop = r9;
const reg64_t reg_bias = rbx;
const reg64_t reg_scales = rbx;
const reg64_t reg_dst_scales = rbx;
const reg64_t reg_stride_ld_block = rdx;
const reg64_t reg_do_post_ops = rbx;
const reg64_t reg_do_skip_accum = reg_do_post_ops;
const reg64_t reg_tmp_gpr = rbx;
const reg64_t reg_ptr_sum_scale = rbx;
const reg64_t reg_zp_comp_a = rbx;
const reg64_t reg_aux_zp_comp_a = rbx;
const reg64_t reg_zp_a_values = rbx;
const reg64_t reg_zp_comp_b = rbx;
const reg64_t reg_zp_c_values = rbx;
const reg64_t reg_ptr_sum_zp = rbx;
const reg64_t reg_converted_stride = rsi;
const reg64_t reg_zp_comp_pad_a = rsi;
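// Byte offsets into the small stack frame used to spill pointers that share
// GPRs with other roles (see read_params()); stack_space_needed_ is the total
// frame size.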
constexpr static int abi_param1_offs_ = 0;
constexpr static int reg_zp_comp_a_offs_ = 8;
constexpr static int reg_zp_comp_b_offs_ = 16;
constexpr static int reg_zp_c_values_offs_ = 24;
constexpr static int reg_iter_labels_list_offs_ = 32;
constexpr static int reg_zp_a_values_offs_ = 40;
constexpr static int stack_space_needed_ = 48;
bool are_post_ops_applicable_ = false;
bool need_to_apply_alpha_beta_ = false;
bool may_load_accumulators_ = false;
bool handle_binary_po_offset_ = false;
bool with_binary_per_oc_bcast_ = false;
bool with_binary_per_oc_sp_bcast_ = false;
bool with_binary_channel_bcast_ = false;
bool with_binary_per_mb_bcast_ = false;
bool with_binary_per_mb_w_bcast_ = false;
bool with_binary_per_w_bcast_ = false;
bool with_binary_batch_bcast_ = false;
bool with_binary_spatial_bcast_ = false;
bool with_binary_no_bcast_ = false;
bool prepare_post_ops_registers_once_ = false;
const char *bd_mask_buffer_ptr_ = nullptr;
std::vector<size_t> adj_bd_mask_buffer_;
std::vector<size_t> skipped_bd_mask_buffer_;
palette_config_t palette_;
// Used to store offsets within the wsp buffer where the data has already been
// transformed (downconverted), so it can be reused when needed.
std::unordered_map<std::string, size_t> transform_buf_map_A_;
std::unordered_map<std::string, size_t> transform_buf_map_B_;
size_t LDA_size_ = 0, LDA2_size_ = 0;
size_t LDB_size_ = 0, LDB2_size_ = 0;
size_t LDC_size_ = 0, LDC2_size_M_ = 0, LDC2_size_N_ = 0;
size_t LDD_size_ = 0;
size_t ld_block_B_size_ = 0;
size_t ld_block_C_size_ = 0;
size_t ld_block_D_size_ = 0;
size_t ld_block_bias_size_ = 0;
size_t ld_block_scales_size_ = 0;
size_t ld_block_zp_size_ = 0;
size_t ldb_tail_B_size_ = 0;
size_t ldb_tail_C_size_ = 0;
size_t ldb_tail_D_size_ = 0;
size_t ldb_tail_zp_size_ = 0;
enum matrix_kind_t { matrix_A, matrix_B, matrix_C, matrix_D };
// Loops in the brgemm kernel are (the order of the two outermost loops depends on the loop order):
// by bd block2
// by ld block2
// by batch_size
// by rd block
// gemm_microkernel
// The structures below (iteration_block_t, dim_iteration_t, bs_iteration_t
// and iteration_map_t) describe the structure of these loops and are used
// for JIT code generation.
struct iteration_block_t {
int block = 0;
size_t pos = 0;
bool is_tail = false;
iteration_block_t(size_t pos_, int block_, bool is_tail_ = false)
: block(block_), pos(pos_), is_tail(is_tail_) {}
bool operator==(const iteration_block_t &rhs) const {
return block == rhs.block && is_tail == rhs.is_tail;
}
};
struct dim_iteration_t {
size_t idx = 0;
std::vector<iteration_block_t> blocks;
virtual bool operator==(const dim_iteration_t &rhs) const {
return blocks == rhs.blocks;
}
virtual bool operator!=(const dim_iteration_t &rhs) const {
return !operator==(rhs);
}
size_t pos(size_t b) const {
assert(b < blocks.size());
return blocks[b].pos;
}
size_t rel_pos(size_t b) const {
assert(b < blocks.size());
return (blocks[b].pos - blocks[0].pos);
}
int block(size_t b) const {
assert(b < blocks.size());
return blocks[b].block;
}
bool is_tail(size_t b) const {
assert(b < blocks.size());
return blocks[b].is_tail;
}
int block2() const { return static_cast<int>(blocks.size()); }
int length() const {
if (blocks.empty()) return 0;
auto n = blocks.size();
// only last block may be different
return ((n - 1) * blocks[0].block + blocks[n - 1].block);
}
dim_iteration_t() = default;
virtual ~dim_iteration_t() = default;
};
struct bd_iteration_t : public dim_iteration_t {
size_t A_shift {0};
size_t C_shift {0};
size_t D_shift {0};
size_t zp_comp_pad_a_shift {0};
std::vector<char> bd_mask;
std::vector<size_t> adj_bd_mask;
bd_iteration_t *similar {nullptr};
Label lstart;
bool operator==(const dim_iteration_t &_rhs) const override {
// `downcast` will catch a type mismatch in debug mode.
// Note: it supports only a pointer type so far.
const bd_iteration_t &rhs
= *utils::downcast<const bd_iteration_t *>(&_rhs);
bool res = dim_iteration_t::operator==(rhs)
&& A_shift == rhs.A_shift && C_shift == rhs.C_shift
&& D_shift == rhs.D_shift && bd_mask == rhs.bd_mask
&& zp_comp_pad_a_shift == rhs.zp_comp_pad_a_shift;
return res;
}
bool operator!=(const dim_iteration_t &_rhs) const override {
return !operator==(_rhs);
}
};
struct bs_iteration_t {
size_t idx = 0;
size_t pos = 0;
bool is_first = false;
bool is_last = false;
bs_iteration_t() = default;
bs_iteration_t(
size_t pos_, bool is_first_ = true, bool is_last_ = false)
: pos(pos_), is_first(is_first_), is_last(is_last_) {}
};
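// iteration_map_t keeps two top_loop_t descriptions, indexed by the
// apply_postops flag, so that separate iteration sequences can be built for
// the passes with and without post-ops.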
class iteration_map_t {
public:
struct top_loop_t {
std::vector<dim_iteration_t> ldis;
std::vector<bd_iteration_t> bdis;
std::vector<bs_iteration_t> bsis;
std::vector<dim_iteration_t> rdis;
int duplicated {0};
bool is_last_rdi(const dim_iteration_t *rdi) const {
return (rdi->idx == rdis.size() - 1);
}
};
iteration_map_t() : tloops(2) {}
inline top_loop_t &operator[](bool bidx) {
return tloops[static_cast<int>(bidx)];
}
inline const top_loop_t &operator[](bool bidx) const {
return tloops[static_cast<int>(bidx)];
}
private:
std::vector<top_loop_t> tloops;
};
struct brgemm_iteration_t {
const bd_iteration_t *bdi {nullptr};
const dim_iteration_t *ldi {nullptr};
const bs_iteration_t *bsi {nullptr};
const dim_iteration_t *rdi {nullptr};
bool apply_postops {false};
bool skip_accumulation {false};
bool first_bsi {false};
bool last_bsi {false};
brgemm_iteration_t() = default;
};
struct prf_t {
brgemm_kernel_prefetching_t pft = brgemm_prf_default;
int dist = -1;
int vec = 0;
void set(brgemm_kernel_prefetching_t pft_, int dist_) {
pft = pft_;
dist = dist_;
vec = 0;
}
void reset() { vec = 0; }
};
// iteration map
iteration_map_t imap_;
// interleave stores
bool use_ils_ = false;
bool was_prev_bi_ = false;
// saved parameters for storing
brgemm_iteration_t prev_bi_;
// current storing coordinates
int ils_vec_ = 0, ils_bdb_ = 0, ils_ldb_ = 0, ils_bd_start_ = 0;
int ils_bd_step_ = 3; // heuristic value
prf_t prf0A, prf1A, prf2A, prfntaA, prf0B, prf1B, prf2B, prfntaB, prf0C,
prf1C;
bool dt_requires_saturation_ = false;
bool ununroll_bd_loop = false;
Xbyak::Opmask ld_full_mask = Xbyak::Opmask(2);
Xbyak::Opmask ld_tail_mask = Xbyak::Opmask(3);
Xbyak::Opmask fp_col_mask = Xbyak::Opmask(4);
Xbyak::Opmask rd_tail_mask = Xbyak::Opmask(5);
// Zmm map below
const Xbyak::Zmm &zmm_tmp_1() const noexcept { return this->zmm0; }
const Xbyak::Zmm &zmm_tmp_2() const noexcept { return this->zmm1; }
const Xbyak::Zmm &zmm_tmp_3() const noexcept { return this->zmm2; }
/* fp8 emulation */
Xmm fp8_emu_xmm_1() const noexcept { return Xmm(1); }
Xmm fp8_emu_xmm_2() const noexcept { return Xmm(2); }
Xmm fp8_emu_xmm_3() const noexcept { return Xmm(3); }
Xmm fp8_emu_xmm_4() const noexcept { return Xmm(6); }
Xmm fp8_emu_xmm_5() const noexcept { return Xmm(7); }
Xbyak::Opmask fp8_tmp_mask = Xbyak::Opmask(6);
const reg64_t fp8_tmp_reg = rax;
const Xbyak::Zmm zmm_bf32_permute = zmm6;
const Xbyak::Zmm zmm_zp_comp_a = zmm6;
const Xbyak::Zmm zmm_zp_c = zmm7;
const Xbyak::Zmm zmm_lbound = zmm8;
const Xbyak::Zmm zmm_ubound = zmm9;
// zmm_bias, zmm_scales and accm must not overlap
Xbyak::Zmm accm(int bd) const {
assert(bd < 16);
return Xbyak::Zmm(31 - (bd % ils_bd_step_));
}
Xbyak::Zmm zmm_bias(int ldb) const {
assert(ldb < 5);
// zmm10 - zmm14
return Xbyak::Zmm(10 + ldb);
}
Xbyak::Zmm zmm_scales(int ldb) const {
assert(ldb < 5);
assert(ils_bd_step_ < 10);
// zmm15 - zmm19
return Xbyak::Zmm(15 + ldb);
}
Xbyak::Zmm zmm_mask(const Xbyak::Zmm &zmm_in, bool mask_flag, bool store,
Xbyak::Opmask ktail_mask) const;
Xbyak::Ymm ymm_mask(const Xbyak::Ymm &ymm_in, bool mask_flag, bool store,
Xbyak::Opmask ktail_mask) const;
Xbyak::Xmm xmm_mask(const Xbyak::Xmm &xmm_in, bool mask_flag, bool store,
Xbyak::Opmask ktail_mask) const;
void cvt2ps(data_type_t type_in, const Xbyak::Zmm &zmm_in,
const Xbyak::Operand &op, bool mask_flag, bool store,
Xbyak::Opmask ktail_mask);
void read_params();
void load_accumulators(brgemm_iteration_t &bi);
void maybe_saturation(Xbyak::Zmm &zmm);
void apply_alpha_beta_to_vector(
int idx, const Address &addr, bool is_ld_tail);
void apply_post_ops_to_range(brgemm_iteration_t &bi, int bd_start,
int bd_finish, int bdb, int ldb);
void store_vector_with_post_ops(
int idx, const Address &addr, bool is_ld_tail);
void prepare_post_ops_registers_ldb(brgemm_iteration_t &bi, int ldb);
void prepare_post_ops_registers(brgemm_iteration_t &bi);
bool bi_shift_output(
brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi);
bool bi_shift_A(
brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi);
bool bi_shift_B(
brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi);
void uni_prefetch(const Address &addr, brgemm_kernel_prefetching_t pft,
bool for_write);
void prefetch_CD_range(brgemm_iteration_t &bi,
brgemm_kernel_prefetching_t pft, int bd_start, int bd_finish,
int bdb, int ldb);
int calc_ops_CD(brgemm_iteration_t &bi) const noexcept;
void prefetch_CD(brgemm_iteration_t &bi, brgemm_iteration_t &pfo_bi,
prf_t &prf, bool prefetch_all);
void prefetch_A(brgemm_iteration_t &bi, brgemm_iteration_t &pfo_bi,
prf_t &prf, bool prefetch_all);
void prefetch_B(brgemm_iteration_t &bi, brgemm_iteration_t &pfo_bi,
prf_t &prf, bool prefetch_all);
void prefetching(brgemm_iteration_t &bi, bool prefetch_all);
void process_output_range(brgemm_iteration_t &bi, int bd_start,
int bd_finish, int bdb, int ldb);
void store_vector_without_post_ops(
int idx, const Address &addr, bool is_ld_tail);
void store_vector(brgemm_iteration_t &bi, int bdb, int bd, int ldb);
void apply_comp_pad_to_vector(brgemm_iteration_t &bi, int bdb, int inp_bd,
int ldb, const int idx);
void interleave_store(brgemm_iteration_t &bi, bool store_all);
void store_accumulators(brgemm_iteration_t &bi);
void set_A_B_matrices(int bs);
void set_A_B_matrices();
void bf32_downconvert(brgemm_iteration_t &bi, int num_rows,
int tile_num_col_bytes, reg64_t reg_data, int offset,
reg64_t reg_data_stride, reg64_t reg_buf);
void fp8_to_f16_upconvert(brgemm_iteration_t &bi, int num_rows,
int tile_num_col_bytes, reg64_t reg_data, int offset,
reg64_t reg_data_stride, reg64_t reg_buf, data_type_t dt);
void fp8_to_f16_upconvert_to_vnni(brgemm_iteration_t &bi, int num_rows,
int tile_num_col_bytes, reg64_t reg_data, int offset,
reg64_t reg_data_stride, reg64_t reg_buf, data_type_t dt);
void bf32_downconvert_to_vnni(brgemm_iteration_t &bi, int num_rows,
int tile_num_col_bytes, reg64_t reg_data, int offset,
reg64_t reg_data_stride, reg64_t reg_buf);
void maybe_pre_process_data(brgemm_iteration_t &bi, const Tmm &t1,
reg64_t reg_base, size_t offset, reg64_t reg_stride,
matrix_kind_t mk);
bool maybe_pre_process_k_tail(brgemm_iteration_t &bi, int bdb,
const Tmm &t1, reg64_t reg_base, size_t offset, reg64_t reg_stride,
matrix_kind_t mk);
void maybe_tileloadd_nt(
brgemm_iteration_t &bi, matrix_kind_t mk, int xdb, size_t offset);
void tdpbxxd(brgemm_iteration_t &bi, int bdb_idx, int ldb_idx,
bool do_pre_tilestore, bool do_post_tilestore);
void gemm_microkernel_amx(brgemm_iteration_t &bi);
void rdb_loop(brgemm_iteration_t &bi);
void bs_loop_body(brgemm_iteration_t &bi);
void bs_loop(brgemm_iteration_t &bi);
void ldb_loop_body(brgemm_iteration_t &bi);
void ldb_loop(brgemm_iteration_t &bi);
void bdb_loop_body(brgemm_iteration_t &bi);
void bdb_loop(brgemm_iteration_t &bi);
void init(brgemm_iteration_t &bi);
void generate() override;
void prepare_bd_mask() noexcept;
int skipped_bd_mask(int inp_bd) noexcept;
bool get_store_by_vectors(bool apply_post_ops) const {
const bool need_to_apply_post_ops
= are_post_ops_applicable_ && apply_post_ops;
const auto store_by_vectors = need_to_apply_alpha_beta_
|| need_to_apply_post_ops || brg.brgattr.bd_mask_level;
return store_by_vectors;
}
bool actual_ils(bool apply_post_ops, bool skip_accumulation = false) const {
return (use_ils_ && get_store_by_vectors(apply_post_ops)
&& !skip_accumulation);
}
size_t A_offset(const brgemm_iteration_t &bi, int bdb) const noexcept;
size_t B_offset(const brgemm_iteration_t &bi, int ldb) const noexcept;
size_t C_offset(const brgemm_iteration_t &bi, int bdb, int inp_bd,
int ldb) const noexcept;
size_t D_offset(const brgemm_iteration_t &bi, int bdb, int inp_bd,
int ldb) const noexcept;
size_t lda() const noexcept;
size_t ldb() const noexcept;
size_t bias_offset(int ldb) const noexcept;
size_t scales_offset(int ldb) const noexcept;
size_t zp_comp_a_offset(int ldb) const noexcept;
size_t zp_comp_pad_a_offset(const brgemm_iteration_t &bi, int bdb,
int inp_bd, int ldb) const noexcept;
size_t zp_comp_b_offset(int bd) const noexcept;
size_t zp_c_values_offset(brgemm_iteration_t &bi, int ldb) const noexcept;
bool is_out_bd(const bd_iteration_t *bdi, int bdb, int inp_bd) const;
int get_out_bd(const bd_iteration_t *bdi, int bdb, int inp_bd) const;
void maybe_tilestore(brgemm_iteration_t &bi, int bdb_idx, int ldb_idx,
bool do_pre_tilestore, bool do_post_tilestore);
int get_C_tensor(brgemm_iteration_t &bi, int m, int n) const noexcept;
void top_loop(brgemm_iteration_t &bi);
bd_iteration_t *find_similar(const bd_iteration_t *bdi, bool apply_postops);
void fill_imap();
};
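// bi_shift_output advances the (bd, ld) iteration pair by `shift` positions in
// the linearized order implied by brg.innermost_loop (e.g. for
// brgemm_ld_loop_innermost: lidx = bdi->idx * nldis + ldi->idx + shift).
// It returns false if the shifted index falls past the last iteration.
// bi_shift_A and bi_shift_B below do the same for the (bd, rd) and (ld, rd)
// pairs, respectively.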
bool jit_brgemm_amx_uker_base_t::bi_shift_output(
brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi) {
res_bi = bi;
if (shift == 0) return true;
const auto &tloop = imap_[bi.apply_postops];
const auto nldis = tloop.ldis.size();
const auto nbdis = tloop.bdis.size();
size_t lidx = 0;
size_t bd_idx = 0;
size_t ld_idx = 0;
if (brg.innermost_loop == brgemm_ld_loop_innermost) {
lidx = bi.bdi->idx * nldis + bi.ldi->idx;
lidx += shift;
bd_idx = lidx / nldis;
ld_idx = lidx % nldis;
} else if (brg.innermost_loop == brgemm_bd_loop_innermost) {
lidx = bi.ldi->idx * nbdis + bi.bdi->idx;
lidx += shift;
ld_idx = lidx / nbdis;
bd_idx = lidx % nbdis;
} else
assert(!"Unknown loop order!");
if (lidx >= nldis * nbdis) return false;
res_bi.bdi = &(tloop.bdis[bd_idx]);
res_bi.ldi = &(tloop.ldis[ld_idx]);
return true;
}
bool jit_brgemm_amx_uker_base_t::bi_shift_A(
brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi) {
res_bi = bi;
const auto &tloop = imap_[bi.apply_postops];
const auto nbdis = tloop.bdis.size();
const auto nrdis = tloop.rdis.size();
auto lidx = bi.bdi->idx * nrdis + bi.rdi->idx;
lidx += shift;
if (lidx >= nrdis * nbdis) return false;
const auto bd_idx = lidx / nrdis;
const auto rd_idx = lidx % nrdis;
res_bi.bdi = &(tloop.bdis[bd_idx]);
res_bi.rdi = &(tloop.rdis[rd_idx]);
return true;
}
bool jit_brgemm_amx_uker_base_t::bi_shift_B(
brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi) {
res_bi = bi;
const auto &tloop = imap_[bi.apply_postops];
const auto nldis = tloop.ldis.size();
const auto nrdis = tloop.rdis.size();
auto lidx = bi.ldi->idx * nrdis + bi.rdi->idx;
lidx += shift;
if (lidx >= nrdis * nldis) return false;
const auto ld_idx = lidx / nrdis;
const auto rd_idx = lidx % nrdis;
res_bi.ldi = &(tloop.ldis[ld_idx]);
res_bi.rdi = &(tloop.rdis[rd_idx]);
return true;
}
int jit_brgemm_amx_uker_base_t::get_C_tensor(
brgemm_iteration_t &bi, int m, int n) const noexcept {
return brg.get_C_tensor(m, n, bi.bdi->is_tail(m), bi.ldi->is_tail(n));
}
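// prepare_bd_mask precomputes two tables from brgattr.bd_mask:
//   adj_bd_mask_buffer_[i]     - number of unmasked rows before row i
//                                (i.e. the output row index of row i)
//   skipped_bd_mask_buffer_[i] - index of the first unmasked row at or after i
// For example, bd_mask = {1, 0, 1} gives adj = {0, 1, 1} and skipped = {0, 2, 2}.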
void jit_brgemm_amx_uker_base_t::prepare_bd_mask() noexcept {
if (!brg.brgattr.bd_mask_level) return;
bd_mask_buffer_ptr_ = brg.brgattr.bd_mask;
const auto bd_mask_size = brg.bcast_dim;
adj_bd_mask_buffer_.resize(bd_mask_size);
skipped_bd_mask_buffer_.resize(bd_mask_size);
if (bd_mask_buffer_ptr_ != nullptr) {
int out_ibd = 0;
for (int i = 0; i < bd_mask_size; i++) {
adj_bd_mask_buffer_[i] = out_ibd;
out_ibd += bd_mask_buffer_ptr_[i];
skipped_bd_mask_buffer_[i] = i;
for (auto ii = i; ii < bd_mask_size; ii++) {
if (bd_mask_buffer_ptr_[ii]) {
skipped_bd_mask_buffer_[i] = ii;
break;
}
}
}
} else
assert(!"struct nullptr error");
}
int jit_brgemm_amx_uker_base_t::skipped_bd_mask(int inp_bd) noexcept {
if (brg.brgattr.bd_mask_level != 2)
return inp_bd;
else
return skipped_bd_mask_buffer_[inp_bd];
}
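// The *_offset() helpers below compute byte offsets into A, B, C, D and the
// post-ops buffers for the current iteration. When ununroll_bd_loop is set,
// bd-related offsets are relative to the start of the current bd iteration,
// since the base pointers themselves are advanced by the per-iteration shifts
// (bd_iteration_t::A_shift / C_shift / D_shift).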
size_t jit_brgemm_amx_uker_base_t::A_offset(
const brgemm_iteration_t &bi, int bdb) const noexcept {
const auto bs_offs = (brg.type == brgemm_static_offs)
? brg.brgattr.static_offsets[bi.bsi->idx].offset.A
: 0;
const auto bdb_offs
= ununroll_bd_loop ? bi.bdi->rel_pos(bdb) : bi.bdi->pos(bdb);
return bdb_offs * LDA2_size_ + bs_offs
+ bi.rdi->pos(0) * brg.rd_block * brg.typesize_A;
}
size_t jit_brgemm_amx_uker_base_t::B_offset(
const brgemm_iteration_t &bi, int ldb) const noexcept {
const auto bs_offs = (brg.type == brgemm_static_offs)
? brg.brgattr.static_offsets[bi.bsi->idx].offset.B
: 0;
const auto rdb_B_offset = bi.rdi->pos(0) * brg.rd_block * LDB_size_;
const auto ldb_offs = bi.ldi->pos(ldb) * brg.ld_block;
const auto ldb_B_offset = brg.typesize_B
* ((ldb_offs / brg.LDB) * brg.brgattr.LDB2
+ (ldb_offs % brg.LDB) * brg.rd_step);
return rdb_B_offset + ldb_B_offset + bs_offs;
}
size_t jit_brgemm_amx_uker_base_t::C_offset(const brgemm_iteration_t &bi,
int bdb, int inp_bd, int ldb) const noexcept {
const auto bi_bd_start = get_out_bd(bi.bdi, 0, 0);
const auto bd = get_out_bd(bi.bdi, bdb, inp_bd);
const auto bd_shift = bd - (ununroll_bd_loop ? bi_bd_start : 0);
size_t ldc_elem = (size_t)ldb * brg.ld_block;
size_t bloc_idx = ldc_elem / brg.LDC;
size_t in_block = ldc_elem % brg.LDC;
return (size_t)bd_shift * LDC2_size_M_ + (size_t)bloc_idx * LDC2_size_N_
+ in_block * brg.typesize_C;
}
size_t jit_brgemm_amx_uker_base_t::D_offset(const brgemm_iteration_t &bi,
int bdb, int inp_bd, int ldb) const noexcept {
const auto bi_bd_start = get_out_bd(bi.bdi, 0, 0);
const auto bd = get_out_bd(bi.bdi, bdb, inp_bd);
const auto bd_shift = bd - (ununroll_bd_loop ? bi_bd_start : 0);
return (size_t)bd_shift * LDD_size_ + (size_t)ldb * ld_block_D_size_;
}
size_t jit_brgemm_amx_uker_base_t::lda() const noexcept {
return LDA_size_;
}
size_t jit_brgemm_amx_uker_base_t::ldb() const noexcept {
return LDB_size_ * brg.rd_step;
}
size_t jit_brgemm_amx_uker_base_t::bias_offset(int ldb) const noexcept {
return ldb * ld_block_bias_size_;
}
size_t jit_brgemm_amx_uker_base_t::scales_offset(int ldb) const noexcept {
return brg.is_oc_scale * ldb * ld_block_scales_size_;
}
size_t jit_brgemm_amx_uker_base_t::zp_comp_a_offset(int ldb) const noexcept {
return ldb * ld_block_zp_size_;
}
size_t jit_brgemm_amx_uker_base_t::zp_comp_pad_a_offset(
const brgemm_iteration_t &bi, int bdb, int inp_bd,
int ldb) const noexcept {
const auto bi_bd_start = get_out_bd(bi.bdi, 0, 0);
const auto bd = get_out_bd(bi.bdi, bdb, inp_bd);
const auto bd_shift = bd - (ununroll_bd_loop ? bi_bd_start : 0);
return (size_t)bd_shift * brg.LDB * sizeof(int32_t)
+ (size_t)ldb * ld_block_zp_size_;
}
size_t jit_brgemm_amx_uker_base_t::zp_comp_b_offset(int bd) const noexcept {
return sizeof(int32_t) * bd;
}
size_t jit_brgemm_amx_uker_base_t::zp_c_values_offset(
brgemm_iteration_t &bi, int ldb) const noexcept {
if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
return (bi.ldi->is_tail(ldb)) ? ldb_tail_zp_size_
: bi.ldi->pos(ldb) * ld_block_zp_size_;
}
return 0;
}
bool jit_brgemm_amx_uker_base_t::is_out_bd(
const bd_iteration_t *bdi, int bdb, int inp_bd) const {
const auto bd = bdi->pos(bdb) + inp_bd;
return IMPLICATION(
brg.brgattr.bd_mask_level, bdi->bd_mask[bd - bdi->pos(0)] != 0);
}
int jit_brgemm_amx_uker_base_t::get_out_bd(
const bd_iteration_t *bdi, int bdb, int inp_bd) const {
if (!is_out_bd(bdi, bdb, inp_bd)) return -1;
const auto bd = bdi->pos(bdb) + inp_bd;
if (brg.brgattr.bd_mask_level) {
assert(bdi->adj_bd_mask[bd - bdi->pos(0)] == adj_bd_mask_buffer_[bd]);
return bdi->adj_bd_mask[bd - bdi->pos(0)];
} else
return bd;
}
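// zmm_mask/ymm_mask/xmm_mask return the vector register with the tail opmask
// applied: merge-masking for stores and zero-masking (T_z) for loads.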
Xbyak::Zmm jit_brgemm_amx_uker_base_t::zmm_mask(const Xbyak::Zmm &zmm_in,
bool mask_flag, bool store, Xbyak::Opmask ktail_mask) const {
return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
: zmm_in;
}
Xbyak::Ymm jit_brgemm_amx_uker_base_t::ymm_mask(const Xbyak::Ymm &ymm_in,
bool mask_flag, bool store, Xbyak::Opmask ktail_mask) const {
return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
: ymm_in;
}
Xbyak::Xmm jit_brgemm_amx_uker_base_t::xmm_mask(const Xbyak::Xmm &xmm_in,
bool mask_flag, bool store, Xbyak::Opmask ktail_mask) const {
return mask_flag ? (store ? xmm_in | ktail_mask : xmm_in | ktail_mask | T_z)
: xmm_in;
}
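// cvt2ps loads a vector of `type_in` elements from `op` into zmm_in and
// converts it to f32, applying the tail mask when mask_flag is set.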
void jit_brgemm_amx_uker_base_t::cvt2ps(data_type_t type_in,
const Xbyak::Zmm &zmm_in, const Xbyak::Operand &op, bool mask_flag,
bool store, Xbyak::Opmask ktail_mask) {
const Xbyak::Zmm zmm = zmm_mask(zmm_in, mask_flag, store, ktail_mask);
switch (type_in) {
case data_type::f32:
case data_type::s32: vmovups(zmm, op); break;
case data_type::bf16:
vpmovzxwd(zmm, op);
vpslld(zmm, zmm, 16);
break;
case data_type::f16: vcvtph2ps(zmm, op); break;
case data_type::f8_e5m2:
f8_e5m2_emulator_->vcvt_f8_to_f32(zmm, op);
break;
case data_type::f8_e4m3:
f8_e4m3_emulator_->vcvt_f8_to_f32(zmm, op);
break;
case data_type::s8: vpmovsxbd(zmm, op); break;
case data_type::u8: vpmovzxbd(zmm, op); break;
default: assert(!"unsupported data type");
}
if (types::is_integral_dt(type_in)) vcvtdq2ps(zmm_in, zmm_in);
}
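// read_params loads the batch, scratchpad buffer and zero-point pointers from
// the kernel call arguments and spills the zero-point pointers to the stack
// slots defined above (they share a GPR with other roles).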
void jit_brgemm_amx_uker_base_t::read_params() {
Label label_done;
mov(reg_BS, ptr[param1 + GET_OFF(BS)]);
mov(reg_addr_batch, ptr[param1 + GET_OFF(batch)]);
mov(reg_buf, ptr[param1 + GET_OFF(ptr_buf)]);
if (brg.zp_type_a != brgemm_broadcast_t::none) {
mov(reg_zp_comp_a, ptr[param1 + GET_OFF(a_zp_compensations)]);
mov(ptr[rsp + reg_zp_comp_a_offs_], reg_zp_comp_a);
mov(reg_zp_a_values, ptr[param1 + GET_OFF(zp_a_val)]);
mov(ptr[rsp + reg_zp_a_values_offs_], reg_zp_a_values);
if (brg.req_comp_pads_with_bcast)
mov(reg_zp_comp_pad_a, ptr[param1 + GET_OFF(a_zp_compensations)]);
}
if (brg.zp_type_b != brgemm_broadcast_t::none) {
mov(reg_zp_comp_b, ptr[param1 + GET_OFF(b_zp_compensations)]);
mov(ptr[rsp + reg_zp_comp_b_offs_], reg_zp_comp_b);
}
if (brg.zp_type_c != brgemm_broadcast_t::none) {
mov(reg_zp_c_values, ptr[param1 + GET_OFF(c_zp_values)]);
mov(ptr[rsp + reg_zp_c_values_offs_], reg_zp_c_values);
}
}
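// load_accumulators either loads the C tiles from memory (when
// may_load_accumulators_ is set) or zeroes them with tilezero; with
// interleaved tile stores only the very first iteration issues tilezero.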
void jit_brgemm_amx_uker_base_t::load_accumulators(brgemm_iteration_t &bi) {
size_t ils_shift = 0;
if (may_load_accumulators_) {
mov(reg_stride_ld_block, LDC_size_);
const auto need_ils_shift
= (actual_ils(bi.apply_postops, bi.skip_accumulation)
&& ununroll_bd_loop && bi.ldi->idx == 0);
// If need_ils_shift, a shift has to be added to the C offset because reg_C
// still points to the previous iteration in this case.
ils_shift = need_ils_shift ? bi.bdi->C_shift : 0;
}
for_(int bdb = 0; bdb < bi.bdi->block2(); bdb++)
for (int ldb = 0; ldb < bi.ldi->block2(); ldb++) {
if (may_load_accumulators_) {
auto c_offset = C_offset(bi, bdb, 0, bi.ldi->pos(ldb)) + ils_shift;
tileloadd(Tmm(get_C_tensor(bi, bdb, ldb)),
ptr[reg_C + c_offset + reg_stride_ld_block]);
} else {
// call tilezero on very first iteration
if (!brg.interleave_tilestores_
|| everyone_is(0u, bi.bdi->idx, bi.ldi->idx))
tilezero(Tmm(get_C_tensor(bi, bdb, ldb)));
}
}
}
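// apply_alpha_beta_to_vector computes acc = alpha * acc + beta * C for one
// accumulator vector. When beta == 1 and no int->f32 conversion is required it
// uses a plain vpaddd/vaddps with a memory operand; otherwise the previous
// destination is converted via cvt2ps and accumulated with vfmadd231ps.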
void jit_brgemm_amx_uker_base_t::apply_alpha_beta_to_vector(
int idx, const Address &addr, bool is_ld_tail) {
auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
auto zmm = Zmm(idx);
auto zmm_beta = zmm_tmp_1();
auto zmm_alpha = zmm_tmp_2();
auto zmm_prev_dst = zmm_tmp_3();
const bool apply_alpha = brg.alpha != 1.f;
const bool apply_beta = brg.beta != 0.f;
if (!apply_alpha && !apply_beta) return;
const bool dq2ps_required = brg.is_int8 && (apply_alpha || brg.beta != 1.f);
const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required;
if (apply_beta && !use_vadd_for_beta) {
mov(reg_tmp_gpr, float2int(static_cast<float>(brg.beta)));
vmovq(Xmm(zmm_beta.getIdx()), reg_tmp_gpr);
vbroadcastss(zmm_beta, Xmm(zmm_beta.getIdx()));
}
if (apply_alpha) {
mov(reg_tmp_gpr, float2int(static_cast<float>(brg.alpha)));
vmovq(Xmm(zmm_alpha.getIdx()), reg_tmp_gpr);
vbroadcastss(zmm_alpha, Xmm(zmm_alpha.getIdx()));
}
if (dq2ps_required) vcvtdq2ps(zmm, zmm);
if (apply_alpha) vmulps(zmm, zmm, zmm_alpha);
if (apply_beta) {
if (use_vadd_for_beta) {
auto zmm_masked = zmm | k_mask | T_z;
if (brg.is_int8)
vpaddd(zmm_masked, zmm, addr);
else
vaddps(zmm_masked, zmm, addr);
} else {
cvt2ps(brg.dt_c, zmm_prev_dst, addr, true, false, k_mask);
vfmadd231ps(zmm, zmm_prev_dst, zmm_beta);
}
}
}
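// apply_post_ops_to_range applies the post-ops chain to the accumulator
// registers covering rows [bd_start, bd_finish) of one (bdb, ldb) block. For
// binary post-ops the rhs argument descriptors are filled per accumulator
// register first; the sum post-op is handled through the sum_injector
// callback below.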
void jit_brgemm_amx_uker_base_t::apply_post_ops_to_range(
brgemm_iteration_t &bi, int bd_start, int bd_finish, int bdb, int ldb) {
binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
const auto ldb_pos = bi.ldi->pos(ldb);
const auto is_ld_tail = bi.ldi->is_tail(ldb);
if (brg.with_binary) {
if (handle_binary_po_offset_) {
for (auto bd = bd_start; bd < bd_finish; bd++) {
// We have no way to tell the injector to skip some vectors.
// Therefore, we must set parameters correctly for all registers.
// TODO: Make it possible to specify "skipped" vectors to injector
const auto idx = accm(bd).getIdx();
if (is_ld_tail) rhs_arg_params.vmm_tail_idx_.emplace(idx);
rhs_arg_params.vmm_idx_to_out_reg.emplace(idx, reg_D);
if (!is_out_bd(bi.bdi, bdb, bd)) continue;
const auto d_offset = D_offset(bi, bdb, bd, ldb_pos);
rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
idx, d_offset);
}
}
}
const auto sum_injector = [&] {
const float *p_sum_scale = &brg.sum_scale;
const int32_t *p_sum_zp = &brg.sum_zp;
const bool p_sum_scale_reg_set = *p_sum_scale != 1.f;
const bool p_sum_zp_reg_set = *p_sum_zp != 0;
{
const auto &zmm_sum_zp = zmm_tmp_2();
if (p_sum_zp_reg_set) {
mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
}
if (p_sum_scale_reg_set)
mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
const auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
const auto zmm_prev_dst = Xbyak::Zmm(0);