Commit 80175f7

cpu: matmul: optimise blocking heuristics for brgemm matmul
1 parent b62899e commit 80175f7

7 files changed (+82, -71)

.gitignore (+1, -1)
@@ -15,7 +15,7 @@
 # limitations under the License.
 #===============================================================================

-build
+build*
 external
 .vs
 .vscode

src/cpu/aarch64/brgemm/brgemm_types.hpp (+2, -1)
@@ -1,6 +1,6 @@
 /*******************************************************************************
 * Copyright 2020-2023 Intel Corporation
-* Copyright 2023 FUJITSU LIMITED
+* Copyright 2023-2024 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -192,6 +192,7 @@ struct brgemm_t {
     int LDB = 0;
     int LDC = 0;
     int LDD = 0;
+
     // we use two isa_ variables
     // isa_user to store the user provided isa value
     // isa_impl to store actual implementation. This can change until the kernel

src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp (+31, -60)
@@ -766,10 +766,38 @@ void jit_brgemm_kernel_t::read_params() {
 void jit_brgemm_kernel_t::zero_accumulators(int bd_block2, bool is_bdb_tail,
         int ld_block2, bool is_ld_tail, bool skip_accumulation) {
     int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
+    const bool need_to_apply_beta = brg.beta != 0.f;
     for_(int bd = 0; bd < bd_block; bd++)
     for (int ld = 0; ld < ld_block2; ld++) {
         auto zmm = accm(ld_block2, bd, ld);
+        // This part is moved here from the apply_alpha_beta function so that the fadd instruction can be avoided.
+        // It is also required only when K is blocked.
+        if (need_to_apply_beta) {
+            const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
+            const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
+
+            const int offset = C_offset(bd, ld);
+
+            int base_offset = 0;
+            auto x_addr = reg_aux_C;
+
+            if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) {
+                add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0);
+                base_offset = offset;
+                x_addr = reg_tmp_;
+            }
+            LD_MUL_VL(ld1w, zmm.s, k_mask, x_addr, offset - base_offset, 4);
+
+            const bool need_init_beta_vmm = brg.beta != 1.f;
+            auto vmm_beta = z_tail_mask();
+            if (need_init_beta_vmm) {
+                auto wreg_tmp = WReg(reg_tmp_gpr.getIdx());
+                mov_imm(wreg_tmp, float2int(static_cast<float>(brg.beta)));
+                dup(vmm_beta.s, wreg_tmp);
+                fmul(zmm.s, zmm.s, vmm_beta.s);
+            }
+        } else
-        eor(zmm.d, zmm.d, zmm.d);
+            eor(zmm.d, zmm.d, zmm.d);
     }
 }
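Note: the net effect of this hunk is to fold the beta pass into accumulator initialisation. Instead of always zeroing the accumulators and later blending in beta * C inside apply_alpha_beta (an extra ld1w plus fadd/fmla per tile register), the accumulators are seeded with beta * C up front, so the K-loop FMAs accumulate straight on top of it. A minimal scalar sketch of the idea, with illustrative names rather than the kernel's actual registers and helpers:

#include <cstddef>

// Scalar sketch: seed each accumulator with beta * C once, so the K-loop
// FMAs accumulate on top of it and no separate beta pass is needed.
void seed_accumulators(float *acc, const float *C, float beta, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        if (beta == 0.f)
            acc[i] = 0.f;         // old path: eor(zmm, zmm, zmm)
        else if (beta == 1.f)
            acc[i] = C[i];        // ld1w only, no fmul required
        else
            acc[i] = beta * C[i]; // ld1w + fmul with a broadcast beta
    }
}

For beta-only blending this is algebraically identical to the removed tail of apply_alpha_beta, since C_final = A*B + beta * C comes out the same whether beta * C is added at the end or accumulated from the start; the alpha scaling itself stays in apply_alpha_beta, as the next hunk shows.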

@@ -791,57 +819,7 @@ void jit_brgemm_kernel_t::apply_alpha_beta(
         if (apply_alpha) { fmul(vmm.s, vmm.s, vmm_alpha.s); }
     }

-    if (brg.beta == 0.f) return;
-    const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required;
-    const bool need_init_beta_vmm = brg.beta != 1.f;
-    auto vmm_prev_dst = z_tmp_1();
-    auto vmm_beta = z_tail_mask();
-    if (need_init_beta_vmm) {
-        auto wreg_tmp = WReg(reg_tmp_gpr.getIdx());
-        mov_imm(wreg_tmp, float2int(static_cast<float>(brg.beta)));
-        dup(vmm_beta.s, wreg_tmp);
-    }
-
-    int base_offset = 0;
-    auto x_addr = reg_aux_C;
-    for_(int bd = 0; bd < bd_block; bd++)
-    for (int ld = 0; ld < ld_block2; ld++) {
-        const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-        const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
-        auto vmm = accm(ld_block2, bd, ld);
-        if (use_vadd_for_beta) {
-            if (brg.is_int8) {
-                assert(!"unsupported\n");
-            } else {
-                ZRegS z_masked = vmm.s;
-                ZRegS z(vmm.getIdx());
-
-                const int offset = C_offset(bd, ld);
-
-                if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) {
-                    add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0);
-                    base_offset = offset;
-                    x_addr = reg_tmp_;
-                }
-                LD_MUL_VL(ld1w, vmm_prev_dst.s, k_mask, x_addr,
-                        offset - base_offset, 4);
-                if (is_ld_tail) {
-                    movprfx(z_masked, k_mask / T_z, z);
-                    fadd(z_masked, k_mask / T_m, vmm_prev_dst.s);
-                } else {
-                    fadd(z_masked, z_masked, vmm_prev_dst.s);
-                }
-            }
-        } else {
-            add_imm(X_DEFAULT_ADDR, reg_aux_C, C_offset(bd, ld), X_TMP_0);
-            ld1w(vmm_prev_dst.s, k_mask / T_z, ptr(X_DEFAULT_ADDR));
-            if (brg.beta == 1.f) {
-                fadd(vmm.s, vmm.s, vmm_prev_dst.s);
-            } else {
-                fmla(vmm.s, P_ALL_ONE / T_m, vmm_prev_dst.s, vmm_beta.s);
-            }
-        }
-    }
+    // This part is moved to the function zero_accumulators.
 }

 void jit_brgemm_kernel_t::apply_post_ops(
@@ -1464,7 +1442,6 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
     int base_offset = 0;

     for (int rd = 0; rd < rd_loop; rd += brg.rd_step) {
-        int prefetch_count_B = 0;
         for (int ld = 0; ld < ld_block2; ld++) {
             const auto mask = is_ld_tail ? ld_tail_mask : P_ALL_ONE;
             if (brg.dt_b == data_type::f16) {
@@ -1496,13 +1473,7 @@
                 broadcast(bcst(), A_offset(bd, rd),
                         have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
             }
-            if (prefetch_count_B < ld_block2) {
-                add_imm(X_DEFAULT_ADDR, reg_aux_B,
-                        B_offset(prefetch_count_B++, rd)
-                                + brg.LDB * brg.rd_block * brg.typesize_B,
-                        X_TMP_0);
-                prfm(PLDL1KEEP, ptr(X_DEFAULT_ADDR));
-            }
+            // The current prefetch implementation does not improve performance but rather introduces some latency, so it is removed until a more useful implementation is devised.
             for (int ld = 0; ld < ld_block2; ld++) {
                 auto zmm = accm(ld_block2, bd, ld);
                 if (is_emdbd) {

src/cpu/aarch64/matmul/brgemm_matmul.cpp (+1)
@@ -145,6 +145,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
                 ? (dim_t)bgmmc_.wei_k_blk
                 : bgmmc_.LDA;
         const auto kernel_isa = i_M == max_m_ker_idx - 1 ? backup_isa : isa;
+
         CHECK(brgemm_desc_init(&brg, kernel_isa, bgmmc_.brg_type, bgmmc_.src_dt,
                 bgmmc_.wei_dt, false, false, brgemm_row_major, alpha, vbeta,
                 LDA, bgmmc_.LDB, bgmmc_.LDC, vM, vN, vK));

src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp (+2, -1)
@@ -85,7 +85,8 @@ status_t brgemm_matmul_matrix_B_reorder_t::pd_t::init(
     matmul_conf_for_reorder_.K = dims[ndims - 2];
     matmul_conf_for_reorder_.N = dims[ndims - 1];
     matmul_conf_for_reorder_.wei_n_blk = matmul_conf_for_reorder_.N_blk
-            = matmul_conf_for_reorder_.LDB = matmul::get_default_n_block(otag);
+            = matmul_conf_for_reorder_.LDB
+            = matmul::get_default_n_block(otag, matmul_conf_for_reorder_);
     matmul_conf_for_reorder_.N_tail
             = matmul_conf_for_reorder_.N % matmul_conf_for_reorder_.N_blk;
     matmul_conf_for_reorder_.K_blk = 16 * vnni_granularity;

src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp (+43, -7)
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2021-2023 Intel Corporation
+* Copyright 2023-2024 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -47,7 +48,8 @@ using namespace dnnl::impl::utils;
 using namespace data_type;
 using namespace format_tag;

-int get_default_n_block(format_tag_t matrix_b_tag) {
+int get_default_n_block(
+        format_tag_t matrix_b_tag, brgemm_matmul_conf_t &bgmmc) {
     // Note: consider using weights mem_descriptor 'inner_blks' to
     // return B's inner block for non-default cases.
     switch (matrix_b_tag) {
@@ -75,7 +77,23 @@ int get_default_n_block(format_tag_t matrix_b_tag) {
         case BA16a16b:
         case BA16a16b2a:
         case BA16a16b4a: return 16;
-        default: return 64;
+        default: {
+            if (bgmmc.N == 16 || bgmmc.N == 32 || bgmmc.N == 64) return bgmmc.N;
+            if (!mayiuse(sve_512)) {
+                if (bgmmc.N <= 16)
+                    return 16;
+                else {
+                    // It is observed that for M, K > 512 an N block of 64 works better, provided that thread distribution is not hindered.
+                    if (bgmmc.N / 64 >= bgmmc.nthr && bgmmc.K > 512
+                            && bgmmc.M > 512)
+                        return 64;
+                    else
+                        return 32;
+                }
+
+            } else
+                return 64;
+        }
     }
 }
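Note: restated as a standalone function for clarity (a hedged sketch; the free-standing signature and the has_sve_512 flag are illustrative stand-ins for bgmmc and mayiuse(sve_512), while the thresholds are the ones from the diff):

// Sketch of the new default N-block choice for matrix B.
int default_n_block(long N, long M, long K, int nthr, bool has_sve_512) {
    if (N == 16 || N == 32 || N == 64) return N; // exact fit, no N tail
    if (has_sve_512) return 64;                  // previous default kept
    if (N <= 16) return 16;
    // Prefer a 64-wide block for large M and K, but only when N still
    // splits into at least nthr chunks of 64 so threads stay busy.
    if (N / 64 >= nthr && K > 512 && M > 512) return 64;
    return 32;
}

For example, N = 256 with M = K = 1024 on 4 threads yields 64 (256 / 64 = 4 chunks, one per thread); the same shape on 8 threads yields 32, trading block width for thread distribution.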

@@ -178,7 +196,7 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_B_tag(

     if (B_any_layout) {
         const int default_n_block = init_n_tag
-                ? get_default_n_block(format_tag::undef)
+                ? get_default_n_block(format_tag::undef, bgmmc)
                 : bgmmc.N_blk;
         bgmmc.wei_tag = blocked_B_layouts_allowed
                 ? this->pick_blocked_B_layout(default_n_block)
@@ -580,14 +598,17 @@ float compute_blocking_heuristic_sve_256(brgemm_matmul_conf_t &bgmmc,
     const int nthr = bgmmc.nthr;

     const int max_m_blk = nstl::min(/*64*/ 256, matmul.M);
-    int min_m_blk = nstl::min(32, matmul.M); // max_m_blk
+    // It is found that for 2D shapes a min_m_blk of 128 works better than 32 for most shapes.
+    int min_m = (matmul.batch > 1) ? 32 : 128;
+    int min_m_blk = nstl::min(min_m, matmul.M); // max_m_blk

     int n_blk = bgmmc.N_blk;
     const int n_chunks = div_up(matmul.N, n_blk);
     const int max_n_chunks = bgmmc.use_buffer_a ? 16 : 1;
     const int n_chunks_start = nstl::min(max_n_chunks, n_chunks);

-    int default_k_blk = 1024;
+    // It is found that for M < 512 a k_blk of 128 works better than 1024 for most shapes.
+    int default_k_blk = (matmul.M >= 512) ? 1024 : 128;
     int k_blk = nstl::min(matmul.K, default_k_blk);
     int start_nthr_k = 1;
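Note: as a worked example of the new k_blk default, a shape with M = 256 and K = 2048 previously got k_blk = min(2048, 1024) = 1024, i.e. two blocks along K; it now gets k_blk = min(2048, 128) = 128, i.e. sixteen blocks. The commit only records the empirical observation that the smaller K block wins for M < 512; these numbers just illustrate the arithmetic.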

@@ -597,7 +618,22 @@ float compute_blocking_heuristic_sve_256(brgemm_matmul_conf_t &bgmmc,
     const bool low_parallel_work = static_cast<size_t>(nthr) > max_parallel;
     if (low_parallel_work) {

-        min_m_blk = nstl::min(matmul.M, 16);
+        int best_m_blk = 0;
+        float scr = 0, best_scr = 16 * nthr;
+        for (int i = 16; i >= 4; i--) {
+            scr = 0.7 * (matmul.M % i)
+                    + 0.3 * std::abs(nthr - ((float)matmul.M / (float)i));
+            if (scr < best_scr) {
+                best_scr = scr;
+                best_m_blk = i;
+            }
+        }
+        min_m_blk = nstl::min(matmul.M, best_m_blk);
+        // Here min_m_blk is set based on the M value and the number of threads. Decreasing the m_blk size
+        // increases the number of m blocks, which might make better use of the threads. But it is found
+        // that m_blk being a factor of M is more important than maximum thread utilisation. Therefore
+        // that term is given more weight (0.7) in the scoring. This was experimentally verified to
+        // be the best heuristic across multiple shapes.

         bool low_spatial_work = matmul.M <= 40;
         if (low_spatial_work) {
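
Note: the score prefers an m_blk that divides M evenly (weight 0.7 on the remainder M % i) while keeping the resulting block count close to the thread count (weight 0.3 on |nthr - M / i|), searching i from 16 down to 4 and keeping the lowest score. A self-contained sketch of the same loop with a worked example (the pick_min_m_blk harness around it is illustrative):

#include <cmath>
#include <cstdio>

// Sketch of the scoring loop above: lower is better. 0.7 * (M % i)
// penalises a ragged tail block; 0.3 * |nthr - M / i| penalises a block
// count far from the thread count.
int pick_min_m_blk(int M, int nthr) {
    int best_m_blk = 0;
    float best_scr = 16.f * nthr; // same sentinel as the kernel code
    for (int i = 16; i >= 4; i--) {
        float scr = 0.7f * (M % i)
                + 0.3f * std::abs(nthr - (float)M / (float)i);
        if (scr < best_scr) {
            best_scr = scr;
            best_m_blk = i;
        }
    }
    return best_m_blk;
}

int main() {
    // M = 96, nthr = 8: i = 12 divides 96 exactly and gives 96 / 12 = 8
    // blocks, one per thread, so both terms are zero and 12 beats the
    // earlier candidate i = 16 (score 0.3 * |8 - 6| = 0.6).
    std::printf("%d\n", pick_min_m_blk(96, 8)); // prints 12
}
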
@@ -834,7 +870,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,

     VCHECK_BG(attr.set_default_formats(&dst_md), VERBOSE_UNSUPPORTED_TAG);

-    bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag);
+    bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag, bgmmc);

     bgmmc.blocked_B = bm_conf_utils.get_blocked_B();
     bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b();

src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp (+2, -1)
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2021-2023 Intel Corporation
+* Copyright 2023-2024 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -312,7 +313,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
 void init_scratchpad(memory_tracking::registrar_t &scratchpad,
         const brgemm_matmul_conf_t &bgmmc);

-int get_default_n_block(format_tag_t matrix_b_tag);
+int get_default_n_block(format_tag_t, brgemm_matmul_conf_t &bgmmc);

 } // namespace matmul
 } // namespace aarch64
