Skip to content

Commit d1eaa3c

Browse files
cpu: aarch64: Expand ARM SVE support for matrix multiplication (#1818)
Co-authored-by: Shreyas-fuj <shreyas.shankar@fujitsu.com>
1 parent 24a914c commit d1eaa3c

27 files changed

+9355
-45
lines changed

src/cpu/aarch64/brgemm/brgemm.cpp

+566
Large diffs are not rendered by default.

src/cpu/aarch64/brgemm/brgemm.hpp

+234
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
/*******************************************************************************
2+
* Copyright 2020-2023 Intel Corporation
3+
* Copyright 2023 FUJITSU LIMITED
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*******************************************************************************/
17+
#ifndef CPU_AARCH64_BRGEMM_BRGEMM_HPP
18+
#define CPU_AARCH64_BRGEMM_BRGEMM_HPP
19+
20+
#include "cpu/aarch64/brgemm/brgemm_types.hpp"
21+
22+
namespace dnnl {
23+
namespace impl {
24+
namespace cpu {
25+
namespace aarch64 {
26+
/// Initializes a BRGEMM descriptor
27+
///
28+
/// @param brg Output BRGEMM descriptor
29+
/// @param isa Target ISA of BRGEMM kernel
30+
/// If isa is equal to 'isa_undef' maximum supported ISA on current
31+
/// hardware will be used for BRGEMM kernel generation
32+
/// @param type Type of batch
33+
/// @param dt_a Data type of A matrix, can be
34+
/// SVE_512: f32
35+
/// @param dt_b Data type of B matrix
36+
/// SVE_512: f32
37+
/// @note
38+
/// Data type of matrix C is f32 data type
39+
/// @param transA Specifies the form of A used in the matrix multiplication
40+
/// 'false' - A is not transposed, 'true' - A is transposed
41+
/// @param transB Specifies the form of B used in the matrix multiplication
42+
/// 'false' - B is not transposed, 'true' - B is transposed
43+
/// @param layout Specifies whether two-dimensional array storage is row-major
44+
/// (brgemm_row_major) or column-major (brgemm_col_major).
45+
/// @param alpha Specifies the scalar alpha
46+
/// @param beta Specifies the scalar beta
47+
/// @param LDA Specifies the leading dimension of matrix A.
48+
/// LDA must be at least max(1, K)
49+
/// @param LDB Specifies the leading dimension of matrix B.
50+
/// LDB must be at least max(1, N)
51+
/// @param LDC Specifies the leading dimension of matrix C.
52+
/// LDC must be at least max(1, N)
53+
/// @param M Specifies the number of rows of the matrix A and of the matrix C.
54+
/// @param N Specifies the number of columns of the matrix B and
55+
/// the number of columns of the matrix C
56+
/// @param K Specifies the number of columns of the matrix A and
57+
/// the number of rows of the matrix B
58+
/// @param strides Strides between the matrices in the batch. Can be nullptr.
59+
///
60+
status_t DNNL_API brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
61+
brgemm_batch_kind_t type, impl::data_type_t dt_a,
62+
impl::data_type_t dt_b, bool transA, bool transB,
63+
brgemm_layout_t layout, float alpha, float beta, dim_t LDA, dim_t LDB,
64+
dim_t LDC, dim_t M, dim_t N, dim_t K,
65+
const brgemm_strides_t *strides = nullptr);
66+
67+
/// Initializes a BRGEMM descriptor with B matrix as a diagonal matrix
68+
/// represented in packed vector format.
69+
///
70+
/// @param brg Output BRGEMM descriptor
71+
/// @param isa Target ISA of BRGEMM kernel
72+
/// If isa is equal to 'isa_undef' maximum supported ISA on current
73+
/// hardware will be used for BRGEMM kernel generation
74+
/// @param type Type of batch
75+
/// @param dt_a Data type of A matrix can be: f32
76+
/// @param dt_b Data type of B vector can be: f32
77+
/// @note
78+
/// Data type of matrix C f32 data type
79+
/// @param transA Specifies the form of A used in the matrix multiplication
80+
/// 'false' - A is not transposed, 'true' - A is transposed
81+
/// @param layout Specifies whether two-dimensional array storage is row-major
82+
/// (brgemm_row_major) or column-major (brgemm_col_major).
83+
/// @param alpha Specifies the scalar alpha
84+
/// @param beta Specifies the scalar beta
85+
/// @param LDA Specifies the leading dimension of matrix A.
86+
/// LDA must be at least max(1, N)
87+
/// @param LDC Specifies the leading dimension of matrix C.
88+
/// LDC must be at least max(1, N)
89+
/// @param M Specifies the number of rows of the matrix A and C.
90+
/// @param N Specifies the number of columns of the matrix A and C.
91+
///
92+
status_t DNNL_API brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa,
93+
brgemm_batch_kind_t type, impl::data_type_t dt_a,
94+
impl::data_type_t dt_b, bool transA, brgemm_layout_t layout,
95+
float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
96+
const brgemm_strides_t *strides = nullptr);
97+
98+
/// Adds post-operations to BRGEMM descriptor
99+
///
100+
/// @param brg Output BRGEMM descriptor
101+
/// @param attr Primitive attributes (can be nullptr). Specifies post-ops
102+
/// operations
103+
/// @param dst_md Specifies the memory descriptor of the destination tensor,
104+
/// needed for binary postops to determine broadcast type, as well as to
105+
/// determine dst data type.
106+
/// @param LDD Specifies the leading dimension of matrix D
107+
/// LDD must be at least max(1, N)
108+
/// @param dt_bias Specifies the data type Bias
109+
/// Can be u8, s8, s32, bf16 or fp32
110+
///
111+
status_t DNNL_API brgemm_desc_set_postops(brgemm_t *brg,
112+
const primitive_attr_t *attr, const memory_desc_t *dst_md, int LDD,
113+
impl::data_type_t dt_bias = impl::data_type::undef);
114+
115+
/// Adds BRGEMM attributes to BRGEMM descriptor
116+
///
117+
/// @param brg Output BRGEMM descriptor
118+
/// @param brgattr Specifies kernel attributes and hints: virtual padding,
119+
/// maximum batch size, kernel loop order etc.
120+
///
121+
status_t DNNL_API brgemm_desc_set_attr(
122+
brgemm_t *brg, const brgemm_attr_t &brgattr);
123+
124+
/// Generates a BRGEMM kernel based on descriptor
125+
///
126+
/// @param brg_kernel Output BRGEMM kernel
127+
/// @param brg BRGEMM descriptor
128+
///
129+
status_t DNNL_API brgemm_kernel_create(
130+
brgemm_kernel_t **brg_kernel, const brgemm_t &brg);
131+
132+
/// Destroys a BRGEMM kernel
133+
///
134+
/// @param brg_kernel BRGEMM kernel
135+
///
136+
status_t DNNL_API brgemm_kernel_destroy(brgemm_kernel_t *brg_kernel);
137+
138+
/// Execute BRGEMM kernel (brgemm_addr version)
139+
///
140+
/// @note
141+
/// Only BRGEMM kernel will be executed even if post-ops are added to BRGEMM
142+
/// descriptor
143+
///
144+
/// @param brg_kernel BRGEMM kernel
145+
/// @param bs Specifies the size of batch
146+
/// @param batch Array of batch elements containing pointers to matrices
147+
/// A,B and virtual padding for matrices A
148+
/// @param ptr_C Pointer to destination matrix C
149+
/// @param scratch Scratchpad memory needed in several scenarios
150+
///
151+
void DNNL_API brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
152+
const brgemm_batch_element_t *batch, void *ptr_C,
153+
void *scratch = nullptr);
154+
155+
/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
156+
///
157+
/// @note
158+
/// Only BRGEMM kernel will be executed even if post-ops are added to BRGEMM
159+
/// descriptor
160+
///
161+
/// @note
162+
/// See the second note for `brgemm_kernel_execute` API.
163+
///
164+
/// @param brg_kernel BRGEMM kernel
165+
/// @param bs Specifies the size of batch
166+
/// @param addr_A Pointer to first matrix A in the batch
167+
/// @param addr_B Pointer to first matrix B in the batch
168+
/// @param batch Array of batch elements containing offsets to matrices A,B
169+
/// and virtual padding for matrix A. This parameter is ignored when
170+
/// using fixed offsets.
171+
/// @param ptr_C Pointer to destination matrix C
172+
/// @param scratch Scratchpad memory needed in several scenarios
173+
///
174+
void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
175+
const void *addr_A, const void *addr_B,
176+
const brgemm_batch_element_t *batch, void *ptr_C,
177+
void *scratch = nullptr);
178+
179+
/// Execute BRGEMM kernel (brgemm_addr version)
180+
///
181+
/// @note
182+
/// BRGEMM kernel and post-operations will be executed
183+
///
184+
/// @note
185+
/// See the second note for `brgemm_kernel_execute` API.
186+
///
187+
/// @param brg_kernel BRGEMM kernel
188+
/// @param bs Specifies the size of batch
189+
/// @param batch Array of batch elements containing pointers to matrices A,B
190+
/// and virtual padding for matrices A
191+
/// @param ptr_C Pointer to matrix C
192+
/// @param ptr_D Pointer to destination matrix D
193+
/// @param post_ops_data Specifies tensors and data used in post processing
194+
/// phase
195+
/// @param scratch Scratchpad memory needed in several scenarios
196+
///
197+
void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
198+
int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
199+
const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr);
200+
201+
/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
202+
///
203+
/// @note
204+
/// BRGEMM kernel and post-operations will be executed
205+
///
206+
/// @note
207+
/// See the second note for `brgemm_kernel_execute` API.
208+
///
209+
/// @param brg_kernel BRGEMM kernel
210+
/// @param bs Specifies the size of batch
211+
/// @param addr_A Pointer to first matrix A in the batch
212+
/// @param addr_B Pointer to first matrix B in the batch
213+
/// @param batch Array of batch elements containing offsets to matrices A,B
214+
/// and virtual padding for matrices A. This parameter is ignored when
215+
/// using fixed offsets.
216+
/// @param ptr_C Pointer to destination matrix C
217+
/// @param ptr_D Pointer to destination matrix D
218+
/// @param post_ops_data Specifies tensors and data used in post processing
219+
/// phase
220+
/// @param scratch Scratchpad memory needed in several scenarios
221+
///
222+
void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
223+
const void *addr_A, const void *addr_B,
224+
const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
225+
const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr);
226+
227+
} // namespace aarch64
228+
} // namespace cpu
229+
} // namespace impl
230+
} // namespace dnnl
231+
232+
#endif
233+
234+
//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*******************************************************************************
2+
* Copyright 2023 Intel Corporation
3+
* Copyright 2024 FUJITSU LIMITED
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*******************************************************************************/
17+
18+
#include "cpu/aarch64/brgemm/brgemm_containers.hpp"
19+
#include "cpu/aarch64/brgemm/jit_brdgmm_kernel.hpp"
20+
21+
namespace dnnl {
22+
namespace impl {
23+
namespace cpu {
24+
namespace aarch64 {
25+
26+
using namespace dnnl::impl::utils;
27+
28+
namespace brgemm_containers {
29+
30+
#ifdef BRGEMM_KERNEL_GLOBAL_STORAGE
31+
std::set<std::shared_ptr<brgemm_kernel_t>,
32+
decltype(brgemm_kernel_container_t::brgemm_kernel_cmp) *>
33+
brgemm_kernel_container_t::set_
34+
= std::set<std::shared_ptr<brgemm_kernel_t>,
35+
decltype(brgemm_kernel_container_t::brgemm_kernel_cmp) *>(
36+
brgemm_kernel_container_t::brgemm_kernel_cmp);
37+
#endif
38+
39+
bool brgemm_desc_container_t::insert(int idx, brgemm_t &brg,
40+
const std::vector<char> &bd_mask,
41+
const std::vector<brgemm_batch_element_t> &static_offsets) {
42+
bd_mask_list_.push_back(bd_mask);
43+
brg.brgattr.bd_mask = bd_mask_list_.back().data();
44+
45+
static_offsets_list_.push_back(static_offsets);
46+
brg.brgattr.static_offsets = static_offsets_list_.back().data();
47+
48+
const auto ret = set_.insert(brg);
49+
refs_[idx] = &(*ret.first);
50+
// if there was no insertion then clean bd_mask and static_offsets
51+
if (!ret.second) {
52+
bd_mask_list_.pop_back();
53+
static_offsets_list_.pop_back();
54+
}
55+
return ret.second;
56+
}
57+
58+
bool brgemm_kernel_container_t::brgemm_kernel_cmp(
59+
const std::shared_ptr<brgemm_kernel_t> &lhs,
60+
const std::shared_ptr<brgemm_kernel_t> &rhs) {
61+
const auto lsz = lhs->get_jit_generator()->getSize();
62+
const auto rsz = rhs->get_jit_generator()->getSize();
63+
if (lsz != rsz) return (lsz < rsz);
64+
const auto lcode = lhs->get_jit_generator()->CodeGenerator::getCode();
65+
const auto rcode = rhs->get_jit_generator()->CodeGenerator::getCode();
66+
return (std::memcmp(lcode, rcode, lsz) < 0);
67+
}
68+
69+
status_t brgemm_kernel_container_t::insert(int idx, const brgemm_t *brg) {
70+
// Use two level hashing of brgemm kernels:
71+
// 1. Try to find entry in local brgemm_map_ using brgemm descriptor as a
72+
// key (we can check if brgemm descriptor is unique inside brgemm primitive)
73+
// 2. Only if we do not find entry in local brgemm_map_ then try to find
74+
// entry in kernel storage using kernel code as key
75+
const auto brgemm_it = brgemm_map_.find(brg);
76+
if (brgemm_it == brgemm_map_.end()) {
77+
brgemm_kernel_t *brg_kernel = nullptr;
78+
status_t s = brgemm_kernel_create(&brg_kernel, *brg);
79+
if (s != status::success) {
80+
delete brg_kernel;
81+
return s;
82+
}
83+
std::shared_ptr<brgemm_kernel_t> sptr(brg_kernel);
84+
lock_write();
85+
const auto kernel_ret = set_.insert(sptr);
86+
refs_[idx] = kernel_ret.first->get();
87+
unlock_write();
88+
const auto brgemm_ret = brgemm_map_.insert({brg, refs_[idx]});
89+
if (!brgemm_ret.second) return status::runtime_error;
90+
} else {
91+
refs_[idx] = brgemm_it->second;
92+
}
93+
return status::success;
94+
}
95+
96+
} // namespace brgemm_containers
97+
} // namespace aarch64
98+
} // namespace cpu
99+
} // namespace impl
100+
} // namespace dnnl
101+
102+
//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s

0 commit comments

Comments
 (0)