forked from uxlfoundation/oneDNN
-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathbrgemm.hpp
300 lines (283 loc) · 13.5 KB
/
brgemm.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
/*******************************************************************************
* Copyright 2020-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef CPU_X64_BRGEMM_BRGEMM_HPP
#define CPU_X64_BRGEMM_BRGEMM_HPP
#include "cpu/x64/brgemm/brgemm_types.hpp"
namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
/// Initializes a BRGEMM descriptor.
///
/// @param brg Output BRGEMM descriptor.
/// @param isa Target ISA of the BRGEMM kernel.
///     If isa is equal to 'isa_undef', the maximum ISA supported by the
///     current hardware is used for BRGEMM kernel generation.
/// @param type Type of batch.
/// @param dt_a Data type of matrix A; can be
///     AVX512: f32, u8 (row-major layout), s8 (column-major layout), bf16, f16
///     AMX: u8, s8, bf16, f16
/// @param dt_b Data type of matrix B; can be
///     AVX512: f32, s8 (row-major layout), u8 (column-major layout), bf16, f16
///     AMX: u8, s8, bf16, f16
/// @note
///     The data type of matrix C depends on the data types of matrices A
///     and B: if A and B have integer u8/s8 data types, C has int32 data
///     type; if A and B have bf16, f16 or f32 data types, C has f32 data type.
/// @param transA Specifies the form of A used in the matrix multiplication:
///     'false' - A is not transposed, 'true' - A is transposed.
/// @param transB Specifies the form of B used in the matrix multiplication:
///     'false' - B is not transposed, 'true' - B is transposed.
/// @param layout Specifies whether two-dimensional array storage is row-major
///     (brgemm_row_major) or column-major (brgemm_col_major).
/// @param alpha Specifies the scalar alpha.
/// @param beta Specifies the scalar beta.
/// @param LDA Specifies the leading dimension of matrix A.
///     LDA must be at least max(1, K).
/// @param LDB Specifies the leading dimension of matrix B.
///     LDB must be at least max(1, N).
/// @param LDC Specifies the leading dimension of matrix C.
///     LDC must be at least max(1, N).
/// @param M Specifies the number of rows of matrix A and of matrix C.
/// @param N Specifies the number of columns of matrix B and of matrix C.
/// @param K Specifies the number of columns of matrix A and the number of
///     rows of matrix B.
/// @param strides Strides between the matrices in the batch. Optional,
///     defaults to nullptr. NOTE(review): presumably only consulted for the
///     stride-based batch kind (brgemm_strd) and may stay nullptr otherwise;
///     confirm in the implementation.
/// @param is_weights_decompression Defaults to false. NOTE(review):
///     presumably requests on-the-fly decompression of matrix B (weights);
///     confirm in the implementation.
/// @param is_src_dynamic_quantization Defaults to false. NOTE(review):
///     presumably requests dynamic (runtime) quantization of matrix A
///     (source); confirm in the implementation.
/// @param wei_md Weights memory descriptor. Optional, defaults to nullptr.
///     NOTE(review): presumably only used when is_weights_decompression is
///     true; confirm.
/// @param attr Primitive attributes. Optional, defaults to nullptr.
///     NOTE(review): presumably carries the decompression / quantization
///     parameters (scales, zero points); confirm.
/// @returns status::success on success, an error status otherwise.
///
status_t DNNL_API brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa,
        brgemm_batch_kind_t type, impl::data_type_t dt_a,
        impl::data_type_t dt_b, bool transA, bool transB,
        brgemm_layout_t layout, float alpha, float beta, dim_t LDA, dim_t LDB,
        dim_t LDC, dim_t M, dim_t N, dim_t K,
        const brgemm_strides_t *strides = nullptr,
        bool is_weights_decompression = false,
        bool is_src_dynamic_quantization = false,
        const memory_desc_t *wei_md = nullptr,
        const primitive_attr_t *attr = nullptr);
/// Initializes a BRGEMM descriptor with matrix B given as a diagonal matrix
/// represented in packed vector format.
///
/// @param brg Output BRGEMM descriptor.
/// @param isa Target ISA of the BRGEMM kernel.
///     If isa is equal to 'isa_undef', the maximum ISA supported by the
///     current hardware is used for BRGEMM kernel generation.
/// @param type Type of batch.
/// @param dt_a Data type of matrix A; can be: f32, u8, bf16, f16.
/// @param dt_b Data type of vector B; can be: f32, s8, bf16, f16.
/// @note
///     The data type of matrix C depends on the data types of matrix A and
///     vector B: if A and B have integer u8/s8 data types, C has int32 data
///     type; if A and B have bf16, f16 or f32 data types, C has f32 data type.
/// @param transA Specifies the form of A used in the matrix multiplication:
///     'false' - A is not transposed, 'true' - A is transposed.
/// @param layout Specifies whether two-dimensional array storage is row-major
///     (brgemm_row_major) or column-major (brgemm_col_major).
/// @param alpha Specifies the scalar alpha.
/// @param beta Specifies the scalar beta.
/// @param LDA Specifies the leading dimension of matrix A.
///     LDA must be at least max(1, N).
/// @param LDC Specifies the leading dimension of matrix C.
///     LDC must be at least max(1, N).
/// @param M Specifies the number of rows of matrices A and C.
/// @param N Specifies the number of columns of matrices A and C.
/// @param strides Strides between the matrices in the batch. Optional,
///     defaults to nullptr. NOTE(review): presumably only consulted for the
///     stride-based batch kind (brgemm_strd); confirm in the implementation.
/// @returns status::success on success, an error status otherwise.
///
status_t DNNL_API brdgmm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa,
        brgemm_batch_kind_t type, impl::data_type_t dt_a,
        impl::data_type_t dt_b, bool transA, brgemm_layout_t layout,
        float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
        const brgemm_strides_t *strides = nullptr);
/// Adds post-operations to a BRGEMM descriptor.
///
/// @param brg BRGEMM descriptor to update.
/// @param attr Primitive attributes (can be nullptr) specifying the
///     post-operations.
/// @param dst_md Specifies the memory descriptor of the destination tensor,
///     needed for binary post-ops to determine the broadcast type, as well as
///     to determine the dst data type.
/// @param LDD Specifies the leading dimension of matrix D.
///     LDD must be at least max(1, N).
///     TODO: why can't LDD be obtained from dst_md directly?
/// @param dt_bias Specifies the data type of the bias.
///     Can be u8, s8, s32, bf16, f16 or f32. Defaults to data_type::undef.
///     NOTE(review): presumably undef means "no bias"; confirm.
/// @param is_weights_decompression Defaults to false. NOTE(review):
///     presumably should be set when the descriptor was initialized with
///     weights decompression enabled; confirm in the implementation.
/// @returns status::success on success, an error status otherwise.
///
status_t DNNL_API brgemm_desc_set_postops(brgemm_desc_t *brg,
        const primitive_attr_t *attr, const memory_desc_t *dst_md, dim_t LDD,
        impl::data_type_t dt_bias = impl::data_type::undef,
        bool is_weights_decompression = false);
/// Adds BRGEMM attributes to a BRGEMM descriptor.
///
/// @note
///     brgemm_init_tiles expects the descriptor to be completely set up, so
///     for non-empty attributes this call must precede brgemm_init_tiles.
///     NOTE(review): presumably it should also precede brgemm_kernel_create;
///     confirm.
///
/// @param brg BRGEMM descriptor to update.
/// @param brgattr Specifies kernel attributes and hints: virtual padding,
///     maximum batch size, kernel loop order etc.
/// @returns status::success on success, an error status otherwise.
///
status_t DNNL_API brgemm_desc_set_attr(
brgemm_desc_t *brg, const brgemm_attr_t &brgattr);
/// Generates a BRGEMM kernel based on a descriptor.
///
/// @param brg_kernel Output BRGEMM kernel: on success, *brg_kernel is set to
///     the generated kernel. NOTE(review): presumably owned by the caller and
///     released with brgemm_kernel_destroy; confirm.
/// @param brg BRGEMM descriptor, presumably configured via brgemm_desc_init
///     or brdgmm_desc_init plus the optional brgemm_desc_set_* calls.
/// @returns status::success on success, an error status otherwise.
///
status_t DNNL_API brgemm_kernel_create(
brgemm_kernel_t **brg_kernel, const brgemm_desc_t &brg);
/// Destroys a BRGEMM kernel.
///
/// @param brg_kernel BRGEMM kernel to destroy, presumably one obtained from
///     brgemm_kernel_create. NOTE(review): whether nullptr is accepted is not
///     visible from this header; confirm before relying on it.
/// @returns status::success on success, an error status otherwise.
///
status_t DNNL_API brgemm_kernel_destroy(brgemm_kernel_t *brg_kernel);
/// Executes a BRGEMM kernel (brgemm_addr version).
///
/// @note
///     Only the BRGEMM kernel is executed, even if post-ops were added to the
///     BRGEMM descriptor.
///
/// @note
///     In row-major mode matrix B (matrix A for column-major) is expected to
///     be in a VNNI-friendly format, which requires 4 consecutive elements of
///     the K dimension for the int8 data type, 2 elements for the bfloat16
///     data type and has no requirements for the f32 and f16 data types.
///
/// @param brg_kernel BRGEMM kernel.
/// @param bs Specifies the size of the batch.
/// @param batch Array of batch elements containing pointers to matrices A and
///     B and virtual padding for matrices A.
/// @param ptr_C Pointer to destination matrix C.
/// @param scratch Scratchpad memory needed in several scenarios:
///     * Where: AMX+ hardware; When: always; For: buffer for tiles store.
///     * In the remaining scenarios it is not used.
/// @param dynamic_values Optional, defaults to nullptr. NOTE(review):
///     presumably supplies runtime values (e.g. leading dimensions) for
///     parameters left dynamic in the descriptor; confirm.
/// @param ptr_wei_scales Optional, defaults to nullptr. NOTE(review):
///     presumably weights scales used for weights decompression; confirm.
/// @param ptr_wei_zero_points Optional, defaults to nullptr. NOTE(review):
///     presumably weights zero points used for weights decompression;
///     confirm.
/// @param ptr_src_scales Optional, defaults to nullptr. NOTE(review):
///     presumably source scales used for dynamic source quantization;
///     confirm.
/// @param ptr_src_grouped_sum Optional, defaults to nullptr. NOTE(review):
///     presumably per-group sums of the quantized source; confirm.
/// @param ic Defaults to 0. NOTE(review): presumably the offset along the
///     reduction (input-channel) dimension used to pick the matching
///     scale / zero-point group; confirm.
///
void DNNL_API brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
        const brgemm_batch_element_t *batch, void *ptr_C,
        void *scratch = nullptr,
        const brgemm_dynamic_values_t *dynamic_values = nullptr,
        const void *ptr_wei_scales = nullptr,
        const void *ptr_wei_zero_points = nullptr,
        const void *ptr_src_scales = nullptr,
        const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
/// Executes a BRGEMM kernel (brgemm_offs and brgemm_strd version).
///
/// @note
///     Only the BRGEMM kernel is executed, even if post-ops were added to the
///     BRGEMM descriptor.
///
/// @note
///     In row-major mode matrix B (matrix A for column-major) is expected to
///     be in a VNNI-friendly format, which requires 4 consecutive elements of
///     the K dimension for the int8 data type, 2 elements for the bfloat16
///     data type and has no requirements for the f32 and f16 data types.
///
/// @note
///     NOTE(review): unlike the other execute entry points in this header,
///     this overload is not marked DNNL_API; confirm whether that is
///     intentional.
///
/// @param brg_kernel BRGEMM kernel.
/// @param bs Specifies the size of the batch.
/// @param addr_A Pointer to the first matrix A in the batch.
/// @param addr_B Pointer to the first matrix B in the batch.
/// @param batch Array of batch elements containing offsets to matrices A and
///     B and virtual padding for matrix A. This parameter is ignored when
///     using fixed offsets.
/// @param ptr_C Pointer to destination matrix C.
/// @param scratch Scratchpad memory needed in several scenarios:
///     * Where: AMX+ hardware; When: always; For: buffer for tiles store.
///     * In the remaining scenarios it is not used.
/// @param dynamic_values Optional, defaults to nullptr. NOTE(review):
///     presumably supplies runtime values (e.g. leading dimensions) for
///     parameters left dynamic in the descriptor; confirm.
/// @param ptr_wei_scales Optional, defaults to nullptr. NOTE(review):
///     presumably weights scales used for weights decompression; confirm.
/// @param ptr_wei_zero_points Optional, defaults to nullptr. NOTE(review):
///     presumably weights zero points used for weights decompression;
///     confirm.
/// @param ptr_src_scales Optional, defaults to nullptr. NOTE(review):
///     presumably source scales used for dynamic source quantization;
///     confirm.
/// @param ptr_src_grouped_sum Optional, defaults to nullptr. NOTE(review):
///     presumably per-group sums of the quantized source; confirm.
/// @param ic Defaults to 0. NOTE(review): presumably the offset along the
///     reduction (input-channel) dimension used to pick the matching
///     scale / zero-point group; confirm.
///
void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
        const void *addr_A, const void *addr_B,
        const brgemm_batch_element_t *batch, void *ptr_C,
        void *scratch = nullptr,
        const brgemm_dynamic_values_t *dynamic_values = nullptr,
        const void *ptr_wei_scales = nullptr,
        const void *ptr_wei_zero_points = nullptr,
        const void *ptr_src_scales = nullptr,
        const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
/// Executes a BRGEMM kernel followed by its post-operations
/// (brgemm_addr version).
///
/// @note
///     Both the BRGEMM kernel and the post-operations are executed.
///
/// @note
///     In row-major mode matrix B (matrix A for column-major) is expected to
///     be in a VNNI-friendly format, which requires 4 consecutive elements of
///     the K dimension for the int8 data type, 2 elements for the bfloat16
///     data type and has no requirements for the f32 and f16 data types.
///
/// @param brg_kernel BRGEMM kernel.
/// @param bs Specifies the size of the batch.
/// @param batch Array of batch elements containing pointers to matrices A and
///     B and virtual padding for matrices A.
/// @param ptr_C Pointer to matrix C.
/// @param ptr_D Pointer to destination matrix D.
/// @param post_ops_data Specifies tensors and data used in the post-processing
///     phase.
/// @param scratch Scratchpad memory needed in several scenarios:
///     * Where: AMX+ hardware; When: always; For: buffer for tiles store.
///     * Where: pre-VNNI hardware; When: s8s8 kernel; For: compensation buffer.
///     * In the remaining scenarios it is not used.
/// @param dynamic_values Optional, defaults to nullptr. NOTE(review):
///     presumably supplies runtime values (e.g. leading dimensions) for
///     parameters left dynamic in the descriptor; confirm.
/// @param ptr_wei_scales Optional, defaults to nullptr. NOTE(review):
///     presumably weights scales used for weights decompression; confirm.
/// @param ptr_wei_zero_points Optional, defaults to nullptr. NOTE(review):
///     presumably weights zero points used for weights decompression;
///     confirm.
/// @param ptr_src_scales Optional, defaults to nullptr. NOTE(review):
///     presumably source scales used for dynamic source quantization;
///     confirm.
/// @param ptr_src_grouped_sum Optional, defaults to nullptr. NOTE(review):
///     presumably per-group sums of the quantized source; confirm.
/// @param ic Defaults to 0. NOTE(review): presumably the offset along the
///     reduction (input-channel) dimension used to pick the matching
///     scale / zero-point group; confirm.
///
void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
        int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
        const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr,
        const brgemm_dynamic_values_t *dynamic_values = nullptr,
        const void *ptr_wei_scales = nullptr,
        const void *ptr_wei_zero_points = nullptr,
        const void *ptr_src_scales = nullptr,
        const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
/// Executes a BRGEMM kernel followed by its post-operations
/// (brgemm_offs and brgemm_strd version).
///
/// @note
///     Both the BRGEMM kernel and the post-operations are executed.
///
/// @note
///     In row-major mode matrix B (matrix A for column-major) is expected to
///     be in a VNNI-friendly format, which requires 4 consecutive elements of
///     the K dimension for the int8 data type, 2 elements for the bfloat16
///     data type and has no requirements for the f32 and f16 data types.
///
/// @param brg_kernel BRGEMM kernel.
/// @param bs Specifies the size of the batch.
/// @param addr_A Pointer to the first matrix A in the batch.
/// @param addr_B Pointer to the first matrix B in the batch.
/// @param batch Array of batch elements containing offsets to matrices A and
///     B and virtual padding for matrices A. This parameter is ignored when
///     using fixed offsets.
/// @param ptr_C Pointer to matrix C (the destination of the post-operations
///     is matrix D, see ptr_D).
/// @param ptr_D Pointer to destination matrix D.
/// @param post_ops_data Specifies tensors and data used in the post-processing
///     phase.
/// @param scratch Scratchpad memory needed in several scenarios:
///     * Where: AMX+ hardware; When: always; For: buffer for tiles store.
///     * Where: pre-VNNI hardware; When: s8s8 kernel; For: compensation buffer.
///     * In the remaining scenarios it is not used.
/// @param dynamic_values Optional, defaults to nullptr. NOTE(review):
///     presumably supplies runtime values (e.g. leading dimensions) for
///     parameters left dynamic in the descriptor; confirm.
/// @param ptr_wei_scales Optional, defaults to nullptr. NOTE(review):
///     presumably weights scales used for weights decompression; confirm.
/// @param ptr_wei_zero_points Optional, defaults to nullptr. NOTE(review):
///     presumably weights zero points used for weights decompression;
///     confirm.
/// @param ptr_src_scales Optional, defaults to nullptr. NOTE(review):
///     presumably source scales used for dynamic source quantization;
///     confirm.
/// @param ptr_src_grouped_sum Optional, defaults to nullptr. NOTE(review):
///     presumably per-group sums of the quantized source; confirm.
/// @param ic Defaults to 0. NOTE(review): presumably the offset along the
///     reduction (input-channel) dimension used to pick the matching
///     scale / zero-point group; confirm.
///
void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
        int bs, const void *addr_A, const void *addr_B,
        const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
        const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr,
        const brgemm_dynamic_values_t *dynamic_values = nullptr,
        const void *ptr_wei_scales = nullptr,
        const void *ptr_wei_zero_points = nullptr,
        const void *ptr_src_scales = nullptr,
        const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
/// AMX utilities: creates a palette based on a BRGEMM descriptor.
///
/// @note
///     This call expects the brgemm_desc_t object to be completely set up,
///     thus it must be used after the `brgemm_desc_set_attr` call for
///     non-empty attributes.
///
/// @note
///     The caller is expected to subsequently configure AMX tiles by calling
///     amx_tile_configure(palette).
///
/// @param brg Input BRGEMM descriptor.
/// @param palette Output 64-byte array initialized with the tile
///     configuration if the returned status is status::success. When any
///     other status is returned, `palette` is not initialized and must not
///     be used. The array bound is not checked by the compiler (the
///     parameter decays to `char *`), so callers must pass a buffer of at
///     least 64 bytes.
/// @returns status::success if `palette` was initialized, another status
///     otherwise.
///
/// TODO: replace `char[64]` with a proper type that can express whether it
/// was properly initialized and whether it's empty. The current API is broken
/// in the sense that multiple different scenarios are considered equal:
/// whether it's not AMX, or blocking is completely broken or unsupported.
status_t DNNL_API brgemm_init_tiles(const brgemm_desc_t &brg, char palette[64]);
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl
#endif
//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s