Skip to content

Commit a892566

Browse files
committed
cpu: aarch64: add brgemm bwd data support for block size 8 and 16
1 parent 95ac968 commit a892566

5 files changed

+323
-64
lines changed

src/cpu/aarch64/jit_brgemm_conv.cpp

+54-59
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2021-2023 Intel Corporation
3-
* Copyright 2024 FUJITSU LIMITED
3+
* Copyright 2024-2025 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -43,8 +43,8 @@ using namespace jit_uni_brgemm_conv_comp_pad_kernel;
4343
#define ndims_pick(v5, v4, v3) \
4444
((ndims == 5) ? (v5) : (ndims == 4) ? (v4) : (ndims == 3) ? (v3) : 0)
4545

46-
template <cpu_isa_t isa, bool use_inversion>
47-
void brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init_batch(int icc,
46+
template <cpu_isa_t isa>
47+
void brgemm_convolution_fwd_t<isa>::pd_t::init_batch(int icc,
4848
const char *src_base, const char *wei_base, int n_ic_blocks,
4949
int ic_block_s, int iid_b, int iih_b, int iiw_b,
5050
const dim_t *const __restrict kw_top_vpads,
@@ -117,8 +117,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init_batch(int icc,
117117
}
118118
}
119119

120-
template <cpu_isa_t isa, bool use_inversion>
121-
inline void brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::get_A_B(int icc,
120+
template <cpu_isa_t isa>
121+
inline void brgemm_convolution_fwd_t<isa>::pd_t::get_A_B(int icc,
122122
const char *src_base, const char *wei_base, int ic_block_s, int iid_b,
123123
int iih_b, int iiw_b, int kd_b, int kh_b, const void *&ptrA,
124124
const void *&ptrB) const {
@@ -147,10 +147,9 @@ inline void brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::get_A_B(int icc,
147147
ptrB = wei_base_kh + wei_kw * wei_kw_offset;
148148
}
149149

150-
template <cpu_isa_t isa, bool use_inversion>
151-
status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::add_brg_descriptor(
152-
int vM, int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b,
153-
int kh_e) {
150+
template <cpu_isa_t isa>
151+
status_t brgemm_convolution_fwd_t<isa>::pd_t::add_brg_descriptor(int vM,
152+
int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b, int kh_e) {
154153

155154
const auto src_type = src_md(0)->data_type;
156155
const auto wei_type = weights_md(0)->data_type;
@@ -287,9 +286,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::add_brg_descriptor(
287286
return status::success;
288287
}
289288

290-
template <cpu_isa_t isa, bool use_inversion>
291-
status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
292-
engine_t *engine) {
289+
template <cpu_isa_t isa>
290+
status_t brgemm_convolution_fwd_t<isa>::pd_t::init(engine_t *engine) {
293291
using namespace data_type;
294292
using namespace utils;
295293
brgemm_descriptors_
@@ -306,7 +304,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
306304
// executing 'use_inversion == true' as FWD. This can only work if the
307305
// diff_src_desc and diff_dst_desc are defined in the aforementioned.
308306
const convolution_desc_t &cd = *desc();
309-
if (use_inversion
307+
if (cd.use_inversion
310308
&& one_of(true, types::is_zero_md(&cd.diff_src_desc),
311309
types::is_zero_md(&cd.diff_dst_desc)))
312310
return status::unimplemented;
@@ -336,6 +334,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
336334
// For exec_base it makes sense to use unrolled kernel only if
337335
// there is no padding by width.
338336
// 2. For exec_trans block by kw is always KW
337+
// 3. 'false' is used intentionally to disable the condition, ensuring that
338+
// the assert fails only when jcp_.use_uker is true, regardless of exec_type.
339339
assert(IMPLICATION(jcp_.use_uker,
340340
false && one_of(jcp_.exec_type, exec_base, exec_trans)));
341341
assert(IMPLICATION(jcp_.use_interleave_stores, jcp_.use_uker));
@@ -535,13 +535,12 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
535535
return status::success;
536536
}
537537

538-
template <cpu_isa_t isa, bool use_inversion>
539-
brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_convolution_fwd_t(
540-
const pd_t *apd)
538+
template <cpu_isa_t isa>
539+
brgemm_convolution_fwd_t<isa>::brgemm_convolution_fwd_t(const pd_t *apd)
541540
: primitive_t(apd), bias_d(pd()->weights_md(1)) {}
542541

543-
template <cpu_isa_t isa, bool use_inversion>
544-
void brgemm_convolution_fwd_t<isa, use_inversion>::get_kw_range(
542+
template <cpu_isa_t isa>
543+
void brgemm_convolution_fwd_t<isa>::get_kw_range(
545544
int ow, int &kw_s, int &kw_full_s, int &kw_full_f, int &kw_f) const {
546545
// This function needed for exec_base only
547546
const auto _pd = pd();
@@ -570,8 +569,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::get_kw_range(
570569
if (kw_full_f == -1) kw_full_s = kw_full_f = kw_f;
571570
}
572571

573-
template <cpu_isa_t isa, bool use_inversion>
574-
inline void brgemm_convolution_fwd_t<isa, use_inversion>::get_ow_range(
572+
template <cpu_isa_t isa>
573+
inline void brgemm_convolution_fwd_t<isa>::get_ow_range(
575574
int ow, int kw, int &ow_s, int &ow_f) const {
576575
// This function needed for exec_base only
577576
const auto _pd = pd();
@@ -602,9 +601,9 @@ inline void brgemm_convolution_fwd_t<isa, use_inversion>::get_ow_range(
602601
ow_f = nstl::min(nstl::max(ow_f, ow_s), ow + M);
603602
}
604603

605-
template <cpu_isa_t isa, bool use_inversion>
606-
status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_brg_kernel(int M,
607-
int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b, int kh_e) {
604+
template <cpu_isa_t isa>
605+
status_t brgemm_convolution_fwd_t<isa>::add_brg_kernel(int M, int i_N, int i_K,
606+
int i_init, int kd_b, int kd_e, int kh_b, int kh_e) {
608607
if (M <= 0) return status::success;
609608
const auto _pd = pd();
610609
const auto &jcp = _pd->jcp_;
@@ -623,8 +622,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_brg_kernel(int M,
623622
return status::success;
624623
}
625624

626-
template <cpu_isa_t isa, bool use_inversion>
627-
status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernel(
625+
template <cpu_isa_t isa>
626+
status_t brgemm_convolution_fwd_t<isa>::add_po_kernel(
628627
brgemm_t *bcfg, int ker_idx, bool is_init) {
629628
if (!bcfg) return status::success;
630629
const auto _pd = pd();
@@ -641,8 +640,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernel(
641640
return status::success;
642641
}
643642

644-
template <cpu_isa_t isa, bool use_inversion>
645-
void brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernels(
643+
template <cpu_isa_t isa>
644+
void brgemm_convolution_fwd_t<isa>::add_po_kernels(
646645
int i_N, int init_bcast_dim, int po_bcast_dim) {
647646
const auto _pd = pd();
648647
const auto &jcp = _pd->jcp_;
@@ -676,10 +675,10 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernels(
676675
}
677676
}
678677
}
679-
template <cpu_isa_t isa, bool use_inversion>
680-
int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_ker_idx(
681-
const int kd_b, const int kd_e, const int kh_b, const int kh_e,
682-
const int kw_b, const int kw_e) const {
678+
template <cpu_isa_t isa>
679+
int brgemm_convolution_fwd_t<isa>::get_comp_ker_idx(const int kd_b,
680+
const int kd_e, const int kh_b, const int kh_e, const int kw_b,
681+
const int kw_e) const {
683682
const auto _pd = pd();
684683
const auto &jcp = _pd->jcp_;
685684

@@ -696,11 +695,10 @@ int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_ker_idx(
696695
return -1;
697696
}
698697

699-
template <cpu_isa_t isa, bool use_inversion>
700-
inline int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_offset(
701-
const int g, const int ocb, const int ow, const int kd_b,
702-
const int kd_e, const int kh_b, const int kh_e, const int kw_b,
703-
const int kw_e) const {
698+
template <cpu_isa_t isa>
699+
inline int brgemm_convolution_fwd_t<isa>::get_comp_offset(const int g,
700+
const int ocb, const int ow, const int kd_b, const int kd_e,
701+
const int kh_b, const int kh_e, const int kw_b, const int kw_e) const {
704702
const auto _pd = pd();
705703
const auto &jcp = _pd->jcp_;
706704

@@ -714,8 +712,8 @@ inline int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_offset(
714712
: (g * jcp.nb_oc + ocb) * jcp.oc_block;
715713
}
716714

717-
template <cpu_isa_t isa, bool use_inversion>
718-
status_t brgemm_convolution_fwd_t<isa, use_inversion>::init(engine_t *engine) {
715+
template <cpu_isa_t isa>
716+
status_t brgemm_convolution_fwd_t<isa>::init(engine_t *engine) {
719717

720718
const auto _pd = pd();
721719
const auto &jcp = _pd->jcp_;
@@ -1054,8 +1052,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::init(engine_t *engine) {
10541052

10551053
return status::success;
10561054
}
1057-
template <cpu_isa_t isa, bool use_inversion>
1058-
struct brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_thread_ctx_t {
1055+
template <cpu_isa_t isa>
1056+
struct brgemm_convolution_fwd_t<isa>::brgemm_thread_ctx_t {
10591057
brgemm_thread_ctx_t(brgemm_exec_ctx_t &brgemm_ctx_, int ithr_,
10601058
brgemm_batch_element_t *__restrict brg_batch_, char *c_buffer_,
10611059
char *wsp_tile_)
@@ -1082,9 +1080,8 @@ struct brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_thread_ctx_t {
10821080
const float *dst_scales {nullptr};
10831081
};
10841082

1085-
template <cpu_isa_t isa, bool use_inversion>
1086-
status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(
1087-
const exec_ctx_t &ctx) const {
1083+
template <cpu_isa_t isa>
1084+
status_t brgemm_convolution_fwd_t<isa>::execute(const exec_ctx_t &ctx) const {
10881085
const auto _pd = pd();
10891086
const auto &jcp = _pd->jcp_;
10901087

@@ -1266,8 +1263,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(
12661263
return status::success;
12671264
}
12681265

1269-
template <cpu_isa_t isa, bool use_inversion>
1270-
status_t brgemm_convolution_fwd_t<isa, use_inversion>::cal_compensation(
1266+
template <cpu_isa_t isa>
1267+
status_t brgemm_convolution_fwd_t<isa>::cal_compensation(
12711268
const char *__restrict weights, int32_t *src_zp_buffer,
12721269
int32_t *s8s8_comp_buffer) const {
12731270
const auto _pd = pd();
@@ -1332,8 +1329,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::cal_compensation(
13321329
return status::success;
13331330
}
13341331

1335-
template <cpu_isa_t isa, bool use_inversion>
1336-
void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
1332+
template <cpu_isa_t isa>
1333+
void brgemm_convolution_fwd_t<isa>::perform_outwork(
13371334
const brgemm_thread_ctx_t &btc, char *dst_base, const char *bias_w,
13381335
int ow, int g_oc, bool is_oc_tail, int ker_ow_s, int ker_ow_f, int kd_l,
13391336
int kh_l, bool maybe_do_init, bool do_postwork,
@@ -1417,8 +1414,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
14171414
}
14181415
}
14191416

1420-
template <cpu_isa_t isa, bool use_inversion>
1421-
inline void brgemm_convolution_fwd_t<isa, use_inversion>::call_brgemm_kernel(
1417+
template <cpu_isa_t isa>
1418+
inline void brgemm_convolution_fwd_t<isa>::call_brgemm_kernel(
14221419
const brgemm_thread_ctx_t &btc, const brgemm_kernel_t *brg_ker,
14231420
int batch_size, char *ptr_C, char *ptr_D, const char *bias_w, int g_oc,
14241421
bool do_postops, int comp_ker_offs, bool do_only_comp) const {
@@ -1467,8 +1464,8 @@ inline void brgemm_convolution_fwd_t<isa, use_inversion>::call_brgemm_kernel(
14671464
ptr_C, static_cast<void *>(btc.wsp_tile));
14681465
}
14691466

1470-
template <cpu_isa_t isa, bool use_inversion>
1471-
void brgemm_convolution_fwd_t<isa, use_inversion>::maybe_conv_inp(int ithr,
1467+
template <cpu_isa_t isa>
1468+
void brgemm_convolution_fwd_t<isa>::maybe_conv_inp(int ithr,
14721469
const char *__restrict src, char *__restrict inp_buffer,
14731470
uint8_t *__restrict inp_buffer_mask, int g, int n, int icc, int odb,
14741471
int ohb, int owb, int last_g, int last_n, int last_icc, int last_odb,
@@ -1648,9 +1645,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::maybe_conv_inp(int ithr,
16481645
char *ptr_D; \
16491646
int kd_b(0), kd_e(0), kh_b(0), kh_e(0), k_l(0), iiw_b(0);
16501647

1651-
template <cpu_isa_t isa, bool use_inversion>
1652-
void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
1653-
brgemm_thread_ctx_t &btc) const {
1648+
template <cpu_isa_t isa>
1649+
void brgemm_convolution_fwd_t<isa>::ker_base(brgemm_thread_ctx_t &btc) const {
16541650

16551651
const auto _pd = pd();
16561652
const auto &jcp = _pd->jcp_;
@@ -1799,8 +1795,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
17991795
}
18001796
}
18011797

1802-
template <cpu_isa_t isa, bool use_inversion>
1803-
void brgemm_convolution_fwd_t<isa, use_inversion>::ker_trans(
1798+
template <cpu_isa_t isa>
1799+
void brgemm_convolution_fwd_t<isa>::ker_trans(
18041800
brgemm_thread_ctx_t &btc, char *inp_buffer) const {
18051801

18061802
const auto _pd = pd();
@@ -1924,9 +1920,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_trans(
19241920
}
19251921
}
19261922

1927-
template <cpu_isa_t isa, bool use_inversion>
1928-
void brgemm_convolution_fwd_t<isa, use_inversion>::ker_vpad(
1929-
brgemm_thread_ctx_t &btc) const {
1923+
template <cpu_isa_t isa>
1924+
void brgemm_convolution_fwd_t<isa>::ker_vpad(brgemm_thread_ctx_t &btc) const {
19301925

19311926
const auto _pd = pd();
19321927
const auto &jcp = _pd->jcp_;

src/cpu/aarch64/jit_brgemm_conv.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2021-2023 Intel Corporation
3-
* Copyright 2024 FUJITSU LIMITED
3+
* Copyright 2024-2025 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -41,7 +41,7 @@ namespace impl {
4141
namespace cpu {
4242
namespace aarch64 {
4343

44-
template <cpu_isa_t isa, bool use_inversion = false>
44+
template <cpu_isa_t isa>
4545
struct brgemm_convolution_fwd_t : public primitive_t {
4646

4747
struct brgemm_thread_ctx_t;
@@ -117,7 +117,7 @@ struct brgemm_convolution_fwd_t : public primitive_t {
117117
}
118118

119119
inline int maybe_invert(int k, int K) const {
120-
return use_inversion ? K - 1 - k : k;
120+
return desc()->use_inversion ? K - 1 - k : k;
121121
};
122122

123123
void init_batch(int icc, const char *src_base, const char *wei_base,
@@ -210,7 +210,7 @@ struct brgemm_convolution_fwd_t : public primitive_t {
210210
}
211211

212212
inline int maybe_invert_range(int k, int k_inv, int K) const {
213-
return use_inversion ? K - k_inv : k;
213+
return pd()->desc()->use_inversion ? K - k_inv : k;
214214
};
215215

216216
void get_kw_range(

0 commit comments

Comments
 (0)