1
+ /* ******************************************************************************
2
+ * Copyright 2020-2023 Intel Corporation
3
+ * Copyright 2023 FUJITSU LIMITED
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ *******************************************************************************/
17
+ #ifndef CPU_AARCH64_BRGEMM_BRGEMM_HPP
18
+ #define CPU_AARCH64_BRGEMM_BRGEMM_HPP
19
+
20
+ #include " cpu/aarch64/brgemm/brgemm_types.hpp"
21
+
22
+ namespace dnnl {
23
+ namespace impl {
24
+ namespace cpu {
25
+ namespace aarch64 {
26
+ // / Initializes a BRGEMM descriptor
27
+ // /
28
+ // / @param brg Output BRGEMM descriptor
29
+ // / @param isa Target ISA of BRGEMM kernel
30
+ // / If isa is equal to 'isa_undef' maximum supported ISA on current
31
+ // / hardware will be used for BRGEMM kernel generation
32
+ // / @param type Type of batch
33
+ // / @param dt_a Data type of A matrix, can be
34
+ // / SVE_512: f32
35
+ // / @param dt_b Data type of B matrix
36
+ // / SVE_512: f32
37
+ // / @note
38
+ // / Data type of matrix C is f32 data type
39
+ // / @param transA Specifies the form of A used in the matrix multiplication
40
+ // / 'false' - A is not transposed, 'true' - A is transposed
41
+ // / @param transB Specifies the form of B used in the matrix multiplication
42
+ // / 'false' - B is not transposed, 'true' - B is transposed
43
+ // / @param layout Specifies whether two-dimensional array storage is row-major
44
+ // / (brgemm_row_major) or column-major (brgemm_col_major).
45
+ // / @param alpha Specifies the scalar alpha
46
+ // / @param beta Specifies the scalar beta
47
+ // / @param LDA Specifies the leading dimension of matrix A.
48
+ // / LDA must be at least max(1, K)
49
+ // / @param LDB Specifies the leading dimension of matrix B.
50
+ // / LDB must be at least max(1, N)
51
+ // / @param LDC Specifies the leading dimension of matrix C.
52
+ // / LDC must be at least max(1, N)
53
+ // / @param M Specifies the number of rows of the matrix A and of the matrix C.
54
+ // / @param N Specifies the number of columns of the matrix B and
55
+ // / the number of columns of the matrix C
56
+ // / @param K Specifies the number of columns of the matrix A and
57
+ // / the number of rows of the matrix B
58
+ // / @param strides Strides between the matrices in the batch. Can be nullptr.
59
+ // /
60
+ status_t DNNL_API brgemm_desc_init (brgemm_t *brg, cpu_isa_t isa,
61
+ brgemm_batch_kind_t type, impl::data_type_t dt_a,
62
+ impl::data_type_t dt_b, bool transA, bool transB,
63
+ brgemm_layout_t layout, float alpha, float beta, dim_t LDA, dim_t LDB,
64
+ dim_t LDC, dim_t M, dim_t N, dim_t K,
65
+ const brgemm_strides_t *strides = nullptr );
66
+
67
+ // / Initializes a BRGEMM descriptor with B matrix as a diagonal matrix
68
+ // / represented in packed vector format.
69
+ // /
70
+ // / @param brg Output BRGEMM descriptor
71
+ // / @param isa Target ISA of BRGEMM kernel
72
+ // / If isa is equal to 'isa_undef' maximum supported ISA on current
73
+ // / hardware will be used for BRGEMM kernel generation
74
+ // / @param type Type of batch
75
+ // / @param dt_a Data type of A matrix can be: f32
76
+ // / @param dt_b Data type of B vector can be: f32
77
+ // / @note
78
+ // / Data type of matrix C f32 data type
79
+ // / @param transA Specifies the form of A used in the matrix multiplication
80
+ // / 'false' - A is not transposed, 'true' - A is transposed
81
+ // / @param layout Specifies whether two-dimensional array storage is row-major
82
+ // / (brgemm_row_major) or column-major (brgemm_col_major).
83
+ // / @param alpha Specifies the scalar alpha
84
+ // / @param beta Specifies the scalar beta
85
+ // / @param LDA Specifies the leading dimension of matrix A.
86
+ // / LDA must be at least max(1, N)
87
+ // / @param LDC Specifies the leading dimension of matrix C.
88
+ // / LDC must be at least max(1, N)
89
+ // / @param M Specifies the number of rows of the matrix A and C.
90
+ // / @param N Specifies the number of columns of the matrix A and C.
91
+ // /
92
+ status_t DNNL_API brdgmm_desc_init (brgemm_t *brg, cpu_isa_t isa,
93
+ brgemm_batch_kind_t type, impl::data_type_t dt_a,
94
+ impl::data_type_t dt_b, bool transA, brgemm_layout_t layout,
95
+ float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
96
+ const brgemm_strides_t *strides = nullptr );
97
+
98
+ // / Adds post-operations to BRGEMM descriptor
99
+ // /
100
+ // / @param brg Output BRGEMM descriptor
101
+ // / @param attr Primitive attributes (can be nullptr). Specifies post-ops
102
+ // / operations
103
+ // / @param dst_md Specifies the memory descriptor of the destination tensor,
104
+ // / needed for binary postops to determine broadcast type, as well as to
105
+ // / determine dst data type.
106
+ // / @param LDD Specifies the leading dimension of matrix D
107
+ // / LDD must be at least max(1, N)
108
+ // / @param dt_bias Specifies the data type Bias
109
+ // / Can be u8, s8, s32, bf16 or fp32
110
+ // /
111
+ status_t DNNL_API brgemm_desc_set_postops (brgemm_t *brg,
112
+ const primitive_attr_t *attr, const memory_desc_t *dst_md, int LDD,
113
+ impl::data_type_t dt_bias = impl::data_type::undef);
114
+
115
+ // / Adds BRGEMM attributes to BRGEMM descriptor
116
+ // /
117
+ // / @param brg Output BRGEMM descriptor
118
+ // / @param brgattr Specifies kernel attributes and hints: virtual padding,
119
+ // / maximum batch size, kernel loop order etc.
120
+ // /
121
+ status_t DNNL_API brgemm_desc_set_attr (
122
+ brgemm_t *brg, const brgemm_attr_t &brgattr);
123
+
124
+ // / Generates a BRGEMM kernel based on descriptor
125
+ // /
126
+ // / @param brg_kernel Output BRGEMM kernel
127
+ // / @param brg BRGEMM descriptor
128
+ // /
129
+ status_t DNNL_API brgemm_kernel_create (
130
+ brgemm_kernel_t **brg_kernel, const brgemm_t &brg);
131
+
132
+ // / Destroys a BRGEMM kernel
133
+ // /
134
+ // / @param brg_kernel BRGEMM kernel
135
+ // /
136
+ status_t DNNL_API brgemm_kernel_destroy (brgemm_kernel_t *brg_kernel);
137
+
138
+ // / Execute BRGEMM kernel (brgemm_addr version)
139
+ // /
140
+ // / @note
141
+ // / Only BRGEMM kernel will be executed even if post-ops are added to BRGEMM
142
+ // / descriptor
143
+ // /
144
+ // / @param brg_kernel BRGEMM kernel
145
+ // / @param bs Specifies the size of batch
146
+ // / @param batch Array of batch elements containing pointers to matrices
147
+ // / A,B and virtual padding for matrices A
148
+ // / @param ptr_C Pointer to destination matrix C
149
+ // / @param scratch Scratchpad memory needed in several scenarios
150
+ // /
151
+ void DNNL_API brgemm_kernel_execute (const brgemm_kernel_t *brg_kernel, int bs,
152
+ const brgemm_batch_element_t *batch, void *ptr_C,
153
+ void *scratch = nullptr );
154
+
155
+ // / Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
156
+ // /
157
+ // / @note
158
+ // / Only BRGEMM kernel will be executed even if post-ops are added to BRGEMM
159
+ // / descriptor
160
+ // /
161
+ // / @note
162
+ // / See the second note for `brgemm_kernel_execute` API.
163
+ // /
164
+ // / @param brg_kernel BRGEMM kernel
165
+ // / @param bs Specifies the size of batch
166
+ // / @param addr_A Pointer to first matrix A in the batch
167
+ // / @param addr_B Pointer to first matrix B in the batch
168
+ // / @param batch Array of batch elements containing offsets to matrices A,B
169
+ // / and virtual padding for matrix A. This parameter is ignored when
170
+ // / using fixed offsets.
171
+ // / @param ptr_C Pointer to destination matrix C
172
+ // / @param scratch Scratchpad memory needed in several scenarios
173
+ // /
174
+ void brgemm_kernel_execute (const brgemm_kernel_t *brg_kernel, int bs,
175
+ const void *addr_A, const void *addr_B,
176
+ const brgemm_batch_element_t *batch, void *ptr_C,
177
+ void *scratch = nullptr );
178
+
179
+ // / Execute BRGEMM kernel (brgemm_addr version)
180
+ // /
181
+ // / @note
182
+ // / BRGEMM kernel and post-operations will be executed
183
+ // /
184
+ // / @note
185
+ // / See the second note for `brgemm_kernel_execute` API.
186
+ // /
187
+ // / @param brg_kernel BRGEMM kernel
188
+ // / @param bs Specifies the size of batch
189
+ // / @param batch Array of batch elements containing pointers to matrices A,B
190
+ // / and virtual padding for matrices A
191
+ // / @param ptr_C Pointer to matrix C
192
+ // / @param ptr_D Pointer to destination matrix D
193
+ // / @param post_ops_data Specifies tensors and data used in post processing
194
+ // / phase
195
+ // / @param scratch Scratchpad memory needed in several scenarios
196
+ // /
197
+ void DNNL_API brgemm_kernel_execute_postops (const brgemm_kernel_t *brg_kernel,
198
+ int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
199
+ const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr );
200
+
201
+ // / Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
202
+ // /
203
+ // / @note
204
+ // / BRGEMM kernel and post-operations will be executed
205
+ // /
206
+ // / @note
207
+ // / See the second note for `brgemm_kernel_execute` API.
208
+ // /
209
+ // / @param brg_kernel BRGEMM kernel
210
+ // / @param bs Specifies the size of batch
211
+ // / @param addr_A Pointer to first matrix A in the batch
212
+ // / @param addr_B Pointer to first matrix B in the batch
213
+ // / @param batch Array of batch elements containing offsets to matrices A,B
214
+ // / and virtual padding for matrices A. This parameter is ignored when
215
+ // / using fixed offsets.
216
+ // / @param ptr_C Pointer to destination matrix C
217
+ // / @param ptr_D Pointer to destination matrix D
218
+ // / @param post_ops_data Specifies tensors and data used in post processing
219
+ // / phase
220
+ // / @param scratch Scratchpad memory needed in several scenarios
221
+ // /
222
+ void brgemm_kernel_execute_postops (const brgemm_kernel_t *brg_kernel, int bs,
223
+ const void *addr_A, const void *addr_B,
224
+ const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
225
+ const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr );
226
+
227
+ } // namespace aarch64
228
+ } // namespace cpu
229
+ } // namespace impl
230
+ } // namespace dnnl
231
+
232
+ #endif
233
+
234
+ // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
0 commit comments