
Commit 45ce1c8

cpu: aarch64: optimising memory/thread utilization in BRGEMM Matmul (#2103)
1 parent 3d1e89a commit 45ce1c8

4 files changed (+77, -70 lines)


src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp (+30, -61)
@@ -767,10 +767,38 @@ void jit_brgemm_kernel_t::read_params() {
 void jit_brgemm_kernel_t::zero_accumulators(int bd_block2, bool is_bdb_tail,
         int ld_block2, bool is_ld_tail, bool skip_accumulation) {
     int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
+    const bool need_to_apply_beta = brg.beta != 0.f;
     for_(int bd = 0; bd < bd_block; bd++)
     for (int ld = 0; ld < ld_block2; ld++) {
         auto zmm = accm(ld_block2, bd, ld);
-        eor(zmm.d, zmm.d, zmm.d);
+        // This part is moved here from the apply_alpha_beta function so that the fadd instruction can be avoided.
+        // It is also required only when K is blocked.
+        if (need_to_apply_beta) {
+            const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
+            const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
+
+            const int offset = C_offset(bd, ld);
+
+            int base_offset = 0;
+            auto x_addr = reg_aux_C;
+
+            if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) {
+                add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0);
+                base_offset = offset;
+                x_addr = reg_tmp_;
+            }
+            LD_MUL_VL(ld1w, zmm.s, k_mask, x_addr, offset - base_offset, 4);
+
+            const bool need_init_beta_vmm = brg.beta != 1.f;
+            auto vmm_beta = z_tail_mask();
+            if (need_init_beta_vmm) {
+                auto wreg_tmp = WReg(reg_tmp_gpr.getIdx());
+                mov_imm(wreg_tmp, float2int(static_cast<float>(brg.beta)));
+                dup(vmm_beta.s, wreg_tmp);
+                fmul(zmm.s, zmm.s, vmm_beta.s);
+            }
+        } else
+            eor(zmm.d, zmm.d, zmm.d);
     }
 }
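
In scalar terms, zero_accumulators now seeds each accumulator tile with beta * C (or with C unchanged when beta == 1) instead of always zeroing it, so the later apply_alpha_beta pass no longer needs to re-load C and fadd it in. A minimal scalar sketch of that semantics (the function and variable names are illustrative, not from the kernel):

    // Sketch only: scalar model of the new accumulator initialisation.
    //   acc  - accumulator tile the BRGEMM microkernel will update
    //   c    - existing contents of the C tile in memory
    //   beta - brg.beta
    void init_accumulators_sketch(float *acc, const float *c, int len, float beta) {
        for (int i = 0; i < len; ++i) {
            if (beta != 0.f)
                acc[i] = (beta != 1.f) ? beta * c[i] : c[i]; // load (and scale) C up front
            else
                acc[i] = 0.f; // previous behaviour: plain zeroing (eor)
        }
        // The K-blocked GEMM then accumulates A*B on top of acc, so apply_alpha_beta
        // no longer has to re-load C and add beta * C afterwards.
    }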

@@ -791,58 +819,6 @@ void jit_brgemm_kernel_t::apply_alpha_beta(
         if (dq2ps_required) { scvtf(vmm.s, P_ALL_ONE / T_m, vmm.s); }
         if (apply_alpha) { fmul(vmm.s, vmm.s, vmm_alpha.s); }
     }
-
-    if (brg.beta == 0.f) return;
-    const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required;
-    const bool need_init_beta_vmm = brg.beta != 1.f;
-    auto vmm_prev_dst = z_tmp_1();
-    auto vmm_beta = z_tail_mask();
-    if (need_init_beta_vmm) {
-        auto wreg_tmp = WReg(reg_tmp_gpr.getIdx());
-        mov_imm(wreg_tmp, float2int(static_cast<float>(brg.beta)));
-        dup(vmm_beta.s, wreg_tmp);
-    }
-
-    int base_offset = 0;
-    auto x_addr = reg_aux_C;
-    for_(int bd = 0; bd < bd_block; bd++)
-    for (int ld = 0; ld < ld_block2; ld++) {
-        const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-        const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
-        auto vmm = accm(ld_block2, bd, ld);
-        if (use_vadd_for_beta) {
-            if (brg.is_int8) {
-                assert(!"unsupported\n");
-            } else {
-                ZRegS z_masked = vmm.s;
-                ZRegS z(vmm.getIdx());
-
-                const int offset = C_offset(bd, ld);
-
-                if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) {
-                    add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0);
-                    base_offset = offset;
-                    x_addr = reg_tmp_;
-                }
-                LD_MUL_VL(ld1w, vmm_prev_dst.s, k_mask, x_addr,
-                        offset - base_offset, 4);
-                if (is_ld_tail) {
-                    movprfx(z_masked, k_mask / T_z, z);
-                    fadd(z_masked, k_mask / T_m, vmm_prev_dst.s);
-                } else {
-                    fadd(z_masked, z_masked, vmm_prev_dst.s);
-                }
-            }
-        } else {
-            add_imm(X_DEFAULT_ADDR, reg_aux_C, C_offset(bd, ld), X_TMP_0);
-            ld1w(vmm_prev_dst.s, k_mask / T_z, ptr(X_DEFAULT_ADDR));
-            if (brg.beta == 1.f) {
-                fadd(vmm.s, vmm.s, vmm_prev_dst.s);
-            } else {
-                fmla(vmm.s, P_ALL_ONE / T_m, vmm_prev_dst.s, vmm_beta.s);
-            }
-        }
-    }
 }

 void jit_brgemm_kernel_t::apply_post_ops(
@@ -1465,7 +1441,6 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
     int base_offset = 0;

     for (int rd = 0; rd < rd_loop; rd += brg.rd_step) {
-        int prefetch_count_B = 0;
         for (int ld = 0; ld < ld_block2; ld++) {
             const auto mask = is_ld_tail ? ld_tail_mask : P_ALL_ONE;
             if (brg.dt_b == data_type::f16) {
@@ -1497,13 +1472,7 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
                 broadcast(bcst(), A_offset(bd, rd),
                         have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
             }
-            if (prefetch_count_B < ld_block2) {
-                add_imm(X_DEFAULT_ADDR, reg_aux_B,
-                        B_offset(prefetch_count_B++, rd)
-                                + brg.LDB * brg.rd_block * brg.typesize_B,
-                        X_TMP_0);
-                prfm(PLDL1KEEP, ptr(X_DEFAULT_ADDR));
-            }
+            // The current prefetch implementation does not give any performance gain and instead introduces some latency, so it is removed until a more useful implementation is devised.
             for (int ld = 0; ld < ld_block2; ld++) {
                 auto zmm = accm(ld_block2, bd, ld);
                 if (is_emdbd) {

src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp (+2, -1)
@@ -85,7 +85,8 @@ status_t brgemm_matmul_matrix_B_reorder_t::pd_t::init(
     matmul_conf_for_reorder_.K = dims[ndims - 2];
     matmul_conf_for_reorder_.N = dims[ndims - 1];
     matmul_conf_for_reorder_.wei_n_blk = matmul_conf_for_reorder_.N_blk
-            = matmul_conf_for_reorder_.LDB = matmul::get_default_n_block(otag);
+            = matmul_conf_for_reorder_.LDB
+            = matmul::get_default_n_block(otag, matmul_conf_for_reorder_);
     matmul_conf_for_reorder_.N_tail
             = matmul_conf_for_reorder_.N % matmul_conf_for_reorder_.N_blk;
     matmul_conf_for_reorder_.K_blk = 16 * vnni_granularity;

src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp (+43, -7)
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2021-2023 Intel Corporation
+* Copyright 2023-2024 FUJITSU LIMITED
 * Copyright 2024 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -48,7 +49,8 @@ using namespace dnnl::impl::utils;
 using namespace data_type;
 using namespace format_tag;

-int get_default_n_block(format_tag_t matrix_b_tag) {
+int get_default_n_block(
+        format_tag_t matrix_b_tag, brgemm_matmul_conf_t &bgmmc) {
     // Note: consider using weights mem_descriptor 'inner_blks' to
     // return B's inner block for non-default cases.
     switch (matrix_b_tag) {
@@ -76,7 +78,23 @@ int get_default_n_block(format_tag_t matrix_b_tag) {
         case BA16a16b:
         case BA16a16b2a:
         case BA16a16b4a: return 16;
-        default: return 64;
+        default: {
+            if (bgmmc.N == 16 || bgmmc.N == 32 || bgmmc.N == 64) return bgmmc.N;
+            if (!mayiuse(sve_512)) {
+                if (bgmmc.N <= 16)
+                    return 16;
+                else {
+                    // It is observed that for M, K > 512, an N block of 64 works better, provided that thread distribution is not hindered.
+                    if (bgmmc.N / 64 >= bgmmc.nthr && bgmmc.K > 512
+                            && bgmmc.M > 512)
+                        return 64;
+                    else
+                        return 32;
+                }
+
+            } else
+                return 64;
+        }
     }
 }
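
For shapes that fall through to the default case, the N block is now picked from the problem shape and the thread count rather than always being 64. A standalone sketch of that default branch (the helper name and the has_sve512 flag are illustrative, not oneDNN API):

    // Sketch of the new default N-block choice (not the library function itself).
    int default_n_block_sketch(int N, int M, int K, int nthr, bool has_sve512) {
        if (N == 16 || N == 32 || N == 64) return N; // exact fit: use N as the block
        if (has_sve512) return 64;                   // SVE-512 keeps the old default
        if (N <= 16) return 16;
        // A 64-wide block only pays off for large M/K, and only when enough
        // N blocks remain to keep every thread busy.
        const bool big_mk = M > 512 && K > 512;
        const bool enough_blocks = N / 64 >= nthr;
        return (big_mk && enough_blocks) ? 64 : 32;
    }
    // e.g. N = 96, M = K = 1024: nthr = 1 -> 64, nthr = 8 -> 32.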

@@ -178,7 +196,7 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_B_tag(

     if (B_any_layout) {
         const int default_n_block = init_n_tag
-                ? get_default_n_block(format_tag::undef)
+                ? get_default_n_block(format_tag::undef, bgmmc)
                 : bgmmc.N_blk;
         bgmmc.wei_tag = blocked_B_layouts_allowed
                 ? this->pick_blocked_B_layout(default_n_block)
@@ -576,14 +594,17 @@ float compute_blocking_heuristic_sve_256(brgemm_matmul_conf_t &bgmmc,
     const int nthr = bgmmc.nthr;

     const int max_m_blk = nstl::min(/*64*/ 256, matmul.M);
-    int min_m_blk = nstl::min(32, matmul.M); // max_m_blk
+    // It is found that for 2D shapes, min_m_blk = 128 works better than 32 for most shapes.
+    int min_m = (matmul.batch > 1) ? 32 : 128;
+    int min_m_blk = nstl::min(min_m, matmul.M); // max_m_blk

     int n_blk = bgmmc.N_blk;
     const int n_chunks = div_up(matmul.N, n_blk);
     const int max_n_chunks = bgmmc.use_buffer_a ? 16 : 1;
     const int n_chunks_start = nstl::min(max_n_chunks, n_chunks);

-    int default_k_blk = 1024;
+    // It is found that for M < 512, a k_blk of 128 works better than 1024 for most shapes.
+    int default_k_blk = (matmul.M >= 512) ? 1024 : 128;
     int k_blk = nstl::min(matmul.K, default_k_blk);
     int start_nthr_k = 1;
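
Taken together, a single (batch == 1) matmul now starts from min_m_blk = min(128, M) and k_blk = min(K, M >= 512 ? 1024 : 128). A small sketch of the resulting defaults for a couple of hypothetical shapes (names are illustrative, not the library's heuristic itself):

    #include <algorithm>

    // Sketch of the new blocking defaults.
    struct BlockingDefaults { int min_m_blk, k_blk; };

    BlockingDefaults pick_defaults_sketch(int batch, int M, int K) {
        const int min_m = (batch > 1) ? 32 : 128;          // 2D shapes favour larger M blocks
        const int default_k_blk = (M >= 512) ? 1024 : 128; // small M favours smaller K blocks
        return {std::min(min_m, M), std::min(K, default_k_blk)};
    }
    // e.g. batch = 1, M = 256,  K = 2048 -> {128, 128}
    //      batch = 1, M = 1024, K = 2048 -> {128, 1024}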

@@ -593,7 +614,22 @@
     const bool low_parallel_work = static_cast<size_t>(nthr) > max_parallel;
     if (low_parallel_work) {

-        min_m_blk = nstl::min(matmul.M, 16);
+        int best_m_blk = 0;
+        float scr = 0, best_scr = 16 * nthr;
+        for (int i = 16; i >= 4; i--) {
+            scr = 0.7 * (matmul.M % i)
+                    + 0.3 * std::abs(nthr - ((float)matmul.M / (float)i));
+            if (scr < best_scr) {
+                best_scr = scr;
+                best_m_blk = i;
+            }
+        }
+        min_m_blk = nstl::min(matmul.M, best_m_blk);
+        // Here min_m_blk is set based on the M value and the number of threads. Decreasing the
+        // m_blk size increases the number of M blocks, which might make better use of the threads.
+        // But it is found that m_blk being a factor of M matters more than maximum thread
+        // utilisation, so that term is given more weight (0.7) in the scoring. This was
+        // experimentally verified to be the best heuristic across multiple shapes.

         bool low_spatial_work = matmul.M <= 40;
         if (low_spatial_work) {
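
The scoring loop above prefers an m_blk that divides M evenly (weight 0.7) over one that yields roughly nthr M blocks (weight 0.3). A standalone sketch with a worked example (function and variable names are illustrative):

    #include <cmath>

    // Sketch of the m_blk scoring used when parallel work is low; lower score wins.
    // 0.7 * (M % i) rewards block sizes that divide M evenly;
    // 0.3 * |nthr - M/i| rewards producing roughly nthr M blocks.
    int pick_m_blk_sketch(int M, int nthr) {
        int best_m_blk = 0;
        float best_scr = 16.f * nthr; // same initial bound as the kernel code
        for (int i = 16; i >= 4; --i) {
            const float scr = 0.7f * (M % i)
                    + 0.3f * std::abs(nthr - (float)M / (float)i);
            if (scr < best_scr) {
                best_scr = scr;
                best_m_blk = i;
            }
        }
        return best_m_blk;
    }
    // e.g. M = 100, nthr = 16: i = 5 scores 0.3 * |16 - 20| = 1.2 and beats
    // i = 16 (0.7 * 4 + 0.3 * |16 - 6.25| ~ 5.7), so the factor of M wins.
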
@@ -829,7 +865,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,

     VCHECK_BG(attr.set_default_formats(&dst_md), VERBOSE_UNSUPPORTED_TAG);

-    bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag);
+    bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag, bgmmc);

     bgmmc.blocked_B = bm_conf_utils.get_blocked_B();
     bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b();

src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp (+2, -1)
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2021-2023 Intel Corporation
+* Copyright 2023-2024 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -312,7 +313,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
 void init_scratchpad(memory_tracking::registrar_t &scratchpad,
         const brgemm_matmul_conf_t &bgmmc);

-int get_default_n_block(format_tag_t matrix_b_tag);
+int get_default_n_block(format_tag_t, brgemm_matmul_conf_t &bgmmc);

 } // namespace matmul
 } // namespace aarch64
