Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cpu: aarch64: add brgemm bwd data support for block size 8 and 16 #2865

Merged
merged 1 commit into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 54 additions & 59 deletions src/cpu/aarch64/jit_brgemm_conv.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
* Copyright 2024-2025 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,8 +43,8 @@ using namespace jit_uni_brgemm_conv_comp_pad_kernel;
#define ndims_pick(v5, v4, v3) \
((ndims == 5) ? (v5) : (ndims == 4) ? (v4) : (ndims == 3) ? (v3) : 0)

template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init_batch(int icc,
template <cpu_isa_t isa>
void brgemm_convolution_fwd_t<isa>::pd_t::init_batch(int icc,
const char *src_base, const char *wei_base, int n_ic_blocks,
int ic_block_s, int iid_b, int iih_b, int iiw_b,
const dim_t *const __restrict kw_top_vpads,
Expand Down Expand Up @@ -117,8 +117,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init_batch(int icc,
}
}

template <cpu_isa_t isa, bool use_inversion>
inline void brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::get_A_B(int icc,
template <cpu_isa_t isa>
inline void brgemm_convolution_fwd_t<isa>::pd_t::get_A_B(int icc,
const char *src_base, const char *wei_base, int ic_block_s, int iid_b,
int iih_b, int iiw_b, int kd_b, int kh_b, const void *&ptrA,
const void *&ptrB) const {
Expand Down Expand Up @@ -147,10 +147,9 @@ inline void brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::get_A_B(int icc,
ptrB = wei_base_kh + wei_kw * wei_kw_offset;
}

template <cpu_isa_t isa, bool use_inversion>
status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::add_brg_descriptor(
int vM, int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b,
int kh_e) {
template <cpu_isa_t isa>
status_t brgemm_convolution_fwd_t<isa>::pd_t::add_brg_descriptor(int vM,
int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b, int kh_e) {

const auto src_type = src_md(0)->data_type;
const auto wei_type = weights_md(0)->data_type;
Expand Down Expand Up @@ -287,9 +286,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::add_brg_descriptor(
return status::success;
}

template <cpu_isa_t isa, bool use_inversion>
status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
engine_t *engine) {
template <cpu_isa_t isa>
status_t brgemm_convolution_fwd_t<isa>::pd_t::init(engine_t *engine) {
using namespace data_type;
using namespace utils;
brgemm_descriptors_
Expand All @@ -306,7 +304,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
// executing 'use_inversion == true' as FWD. This can only work if the
// diff_src_desc and diff_dst_desc are defined in the aforementioned.
const convolution_desc_t &cd = *desc();
if (use_inversion
if (cd.use_inversion
&& one_of(true, types::is_zero_md(&cd.diff_src_desc),
types::is_zero_md(&cd.diff_dst_desc)))
return status::unimplemented;
Expand Down Expand Up @@ -336,6 +334,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
// For exec_base it makes sense to use unrolled kernel only if
// there is no padding by width.
// 2. For exec_trans block by kw is always KW
// 3. 'false' is used intentionally to disable the condition, ensuring that
// the assert fails only when jcp_.use_uker is true, regardless of exec_type.
assert(IMPLICATION(jcp_.use_uker,
false && one_of(jcp_.exec_type, exec_base, exec_trans)));
assert(IMPLICATION(jcp_.use_interleave_stores, jcp_.use_uker));
Expand Down Expand Up @@ -535,13 +535,12 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
return status::success;
}

template <cpu_isa_t isa, bool use_inversion>
brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_convolution_fwd_t(
const pd_t *apd)
template <cpu_isa_t isa>
brgemm_convolution_fwd_t<isa>::brgemm_convolution_fwd_t(const pd_t *apd)
: primitive_t(apd), bias_d(pd()->weights_md(1)) {}

template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::get_kw_range(
template <cpu_isa_t isa>
void brgemm_convolution_fwd_t<isa>::get_kw_range(
int ow, int &kw_s, int &kw_full_s, int &kw_full_f, int &kw_f) const {
// This function needed for exec_base only
const auto _pd = pd();
Expand Down Expand Up @@ -570,8 +569,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::get_kw_range(
if (kw_full_f == -1) kw_full_s = kw_full_f = kw_f;
}

template <cpu_isa_t isa, bool use_inversion>
inline void brgemm_convolution_fwd_t<isa, use_inversion>::get_ow_range(
template <cpu_isa_t isa>
inline void brgemm_convolution_fwd_t<isa>::get_ow_range(
int ow, int kw, int &ow_s, int &ow_f) const {
// This function needed for exec_base only
const auto _pd = pd();
Expand Down Expand Up @@ -602,9 +601,9 @@ inline void brgemm_convolution_fwd_t<isa, use_inversion>::get_ow_range(
ow_f = nstl::min(nstl::max(ow_f, ow_s), ow + M);
}

template <cpu_isa_t isa, bool use_inversion>
status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_brg_kernel(int M,
int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b, int kh_e) {
template <cpu_isa_t isa>
status_t brgemm_convolution_fwd_t<isa>::add_brg_kernel(int M, int i_N, int i_K,
int i_init, int kd_b, int kd_e, int kh_b, int kh_e) {
if (M <= 0) return status::success;
const auto _pd = pd();
const auto &jcp = _pd->jcp_;
Expand All @@ -623,8 +622,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_brg_kernel(int M,
return status::success;
}

template <cpu_isa_t isa, bool use_inversion>
status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernel(
template <cpu_isa_t isa>
status_t brgemm_convolution_fwd_t<isa>::add_po_kernel(
brgemm_t *bcfg, int ker_idx, bool is_init) {
if (!bcfg) return status::success;
const auto _pd = pd();
Expand All @@ -641,8 +640,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernel(
return status::success;
}

template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernels(
template <cpu_isa_t isa>
void brgemm_convolution_fwd_t<isa>::add_po_kernels(
int i_N, int init_bcast_dim, int po_bcast_dim) {
const auto _pd = pd();
const auto &jcp = _pd->jcp_;
Expand Down Expand Up @@ -676,10 +675,10 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernels(
}
}
}
template <cpu_isa_t isa, bool use_inversion>
int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_ker_idx(
const int kd_b, const int kd_e, const int kh_b, const int kh_e,
const int kw_b, const int kw_e) const {
template <cpu_isa_t isa>
int brgemm_convolution_fwd_t<isa>::get_comp_ker_idx(const int kd_b,
const int kd_e, const int kh_b, const int kh_e, const int kw_b,
const int kw_e) const {
const auto _pd = pd();
const auto &jcp = _pd->jcp_;

Expand All @@ -696,11 +695,10 @@ int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_ker_idx(
return -1;
}

template <cpu_isa_t isa, bool use_inversion>
inline int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_offset(
const int g, const int ocb, const int ow, const int kd_b,
const int kd_e, const int kh_b, const int kh_e, const int kw_b,
const int kw_e) const {
template <cpu_isa_t isa>
inline int brgemm_convolution_fwd_t<isa>::get_comp_offset(const int g,
const int ocb, const int ow, const int kd_b, const int kd_e,
const int kh_b, const int kh_e, const int kw_b, const int kw_e) const {
const auto _pd = pd();
const auto &jcp = _pd->jcp_;

Expand All @@ -714,8 +712,8 @@ inline int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_offset(
: (g * jcp.nb_oc + ocb) * jcp.oc_block;
}

template <cpu_isa_t isa, bool use_inversion>
status_t brgemm_convolution_fwd_t<isa, use_inversion>::init(engine_t *engine) {
template <cpu_isa_t isa>
status_t brgemm_convolution_fwd_t<isa>::init(engine_t *engine) {

const auto _pd = pd();
const auto &jcp = _pd->jcp_;
Expand Down Expand Up @@ -1054,8 +1052,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::init(engine_t *engine) {

return status::success;
}
template <cpu_isa_t isa, bool use_inversion>
struct brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_thread_ctx_t {
template <cpu_isa_t isa>
struct brgemm_convolution_fwd_t<isa>::brgemm_thread_ctx_t {
brgemm_thread_ctx_t(brgemm_exec_ctx_t &brgemm_ctx_, int ithr_,
brgemm_batch_element_t *__restrict brg_batch_, char *c_buffer_,
char *wsp_tile_)
Expand All @@ -1082,9 +1080,8 @@ struct brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_thread_ctx_t {
const float *dst_scales {nullptr};
};

template <cpu_isa_t isa, bool use_inversion>
status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(
const exec_ctx_t &ctx) const {
template <cpu_isa_t isa>
status_t brgemm_convolution_fwd_t<isa>::execute(const exec_ctx_t &ctx) const {
const auto _pd = pd();
const auto &jcp = _pd->jcp_;

Expand Down Expand Up @@ -1266,8 +1263,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(
return status::success;
}

template <cpu_isa_t isa, bool use_inversion>
status_t brgemm_convolution_fwd_t<isa, use_inversion>::cal_compensation(
template <cpu_isa_t isa>
status_t brgemm_convolution_fwd_t<isa>::cal_compensation(
const char *__restrict weights, int32_t *src_zp_buffer,
int32_t *s8s8_comp_buffer) const {
const auto _pd = pd();
Expand Down Expand Up @@ -1332,8 +1329,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::cal_compensation(
return status::success;
}

template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
template <cpu_isa_t isa>
void brgemm_convolution_fwd_t<isa>::perform_outwork(
const brgemm_thread_ctx_t &btc, char *dst_base, const char *bias_w,
int ow, int g_oc, bool is_oc_tail, int ker_ow_s, int ker_ow_f, int kd_l,
int kh_l, bool maybe_do_init, bool do_postwork,
Expand Down Expand Up @@ -1417,8 +1414,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
}
}

template <cpu_isa_t isa, bool use_inversion>
inline void brgemm_convolution_fwd_t<isa, use_inversion>::call_brgemm_kernel(
template <cpu_isa_t isa>
inline void brgemm_convolution_fwd_t<isa>::call_brgemm_kernel(
const brgemm_thread_ctx_t &btc, const brgemm_kernel_t *brg_ker,
int batch_size, char *ptr_C, char *ptr_D, const char *bias_w, int g_oc,
bool do_postops, int comp_ker_offs, bool do_only_comp) const {
Expand Down Expand Up @@ -1467,8 +1464,8 @@ inline void brgemm_convolution_fwd_t<isa, use_inversion>::call_brgemm_kernel(
ptr_C, static_cast<void *>(btc.wsp_tile));
}

template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::maybe_conv_inp(int ithr,
template <cpu_isa_t isa>
void brgemm_convolution_fwd_t<isa>::maybe_conv_inp(int ithr,
const char *__restrict src, char *__restrict inp_buffer,
uint8_t *__restrict inp_buffer_mask, int g, int n, int icc, int odb,
int ohb, int owb, int last_g, int last_n, int last_icc, int last_odb,
Expand Down Expand Up @@ -1648,9 +1645,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::maybe_conv_inp(int ithr,
char *ptr_D; \
int kd_b(0), kd_e(0), kh_b(0), kh_e(0), k_l(0), iiw_b(0);

template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
brgemm_thread_ctx_t &btc) const {
template <cpu_isa_t isa>
void brgemm_convolution_fwd_t<isa>::ker_base(brgemm_thread_ctx_t &btc) const {

const auto _pd = pd();
const auto &jcp = _pd->jcp_;
Expand Down Expand Up @@ -1799,8 +1795,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
}
}

template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::ker_trans(
template <cpu_isa_t isa>
void brgemm_convolution_fwd_t<isa>::ker_trans(
brgemm_thread_ctx_t &btc, char *inp_buffer) const {

const auto _pd = pd();
Expand Down Expand Up @@ -1924,9 +1920,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_trans(
}
}

template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::ker_vpad(
brgemm_thread_ctx_t &btc) const {
template <cpu_isa_t isa>
void brgemm_convolution_fwd_t<isa>::ker_vpad(brgemm_thread_ctx_t &btc) const {

const auto _pd = pd();
const auto &jcp = _pd->jcp_;
Expand Down
8 changes: 4 additions & 4 deletions src/cpu/aarch64/jit_brgemm_conv.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
* Copyright 2024-2025 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,7 +41,7 @@ namespace impl {
namespace cpu {
namespace aarch64 {

template <cpu_isa_t isa, bool use_inversion = false>
template <cpu_isa_t isa>
struct brgemm_convolution_fwd_t : public primitive_t {

struct brgemm_thread_ctx_t;
Expand Down Expand Up @@ -117,7 +117,7 @@ struct brgemm_convolution_fwd_t : public primitive_t {
}

inline int maybe_invert(int k, int K) const {
return use_inversion ? K - 1 - k : k;
return desc()->use_inversion ? K - 1 - k : k;
};

void init_batch(int icc, const char *src_base, const char *wei_base,
Expand Down Expand Up @@ -210,7 +210,7 @@ struct brgemm_convolution_fwd_t : public primitive_t {
}

inline int maybe_invert_range(int k, int k_inv, int K) const {
return use_inversion ? K - k_inv : k;
return pd()->desc()->use_inversion ? K - k_inv : k;
};

void get_kw_range(
Expand Down
Loading
Loading