Skip to content

Commit 157fb81

Browse files
[FORK][FEATURE] DQ IP: performance enhancements
- allocate aux accums regs on stack - precompute grouped src sums - optimize pointer arithmetic - reduce aux vecs count required for the microkernel
1 parent 1789b1e commit 157fb81

12 files changed

+418
-202
lines changed

src/common/memory_tracking.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ enum {
305305
key_decompression_zero_points,
306306
key_src_quantized,
307307
key_src_dequantized_scales,
308+
key_src_grouped_sum,
308309
// These two keys should always be the last ones,
309310
// even though they are not in alphabetical order
310311
key_nested,

src/cpu/x64/brgemm/brgemm.cpp

+18-4
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ void brgemm_desc_t::cleanup_dst_md() {
8282
void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
8383
const brgemm_batch_element_t *batch, void *ptr_C, void *scratch,
8484
const brgemm_dynamic_values_t *dynamic_values,
85-
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
85+
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
86+
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
8687
brgemm_kernel_params_t brgemm_p;
8788

8889
brgemm_p.batch = batch;
@@ -105,6 +106,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
105106
brgemm_p.ptr_wei_scales = ptr_wei_scales;
106107
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
107108
brgemm_p.ptr_src_scales = ptr_src_scales;
109+
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
108110
brgemm_p.ic = ic;
109111

110112
assert(brg_kernel);
@@ -116,7 +118,8 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
116118
const void *addr_A, const void *addr_B,
117119
const brgemm_batch_element_t *batch, void *ptr_C, void *scratch,
118120
const brgemm_dynamic_values_t *dynamic_values,
119-
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
121+
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
122+
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
120123
brgemm_kernel_params_t brgemm_p;
121124

122125
brgemm_p.batch = batch;
@@ -133,6 +136,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
133136
brgemm_p.ptr_wei_scales = ptr_wei_scales;
134137
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
135138
brgemm_p.ptr_src_scales = ptr_src_scales;
139+
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
136140
brgemm_p.ic = ic;
137141
if (dynamic_values) {
138142
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
@@ -148,7 +152,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
148152
const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
149153
const brgemm_post_ops_data_t &post_ops_data, void *scratch,
150154
const brgemm_dynamic_values_t *dynamic_values,
151-
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
155+
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
156+
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
152157
brgemm_kernel_params_t brgemm_p;
153158

154159
brgemm_p.batch = batch;
@@ -178,6 +183,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
178183
brgemm_p.ptr_wei_scales = ptr_wei_scales;
179184
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
180185
brgemm_p.ptr_src_scales = ptr_src_scales;
186+
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
181187
brgemm_p.ic = ic;
182188
if (dynamic_values) {
183189
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
@@ -194,7 +200,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
194200
const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
195201
const brgemm_post_ops_data_t &post_ops_data, void *scratch,
196202
const brgemm_dynamic_values_t *dynamic_values,
197-
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
203+
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
204+
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
198205
brgemm_kernel_params_t brgemm_p;
199206

200207
brgemm_p.batch = batch;
@@ -224,6 +231,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
224231
brgemm_p.ptr_wei_scales = ptr_wei_scales;
225232
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
226233
brgemm_p.ptr_src_scales = ptr_src_scales;
234+
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
227235
brgemm_p.ic = ic;
228236
if (dynamic_values) {
229237
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
@@ -318,6 +326,12 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa,
318326

319327
CHECK(brgemm_blocking(brg));
320328

329+
brg->src_sum_group_size = wei_d.dims()[1];
330+
if (brg->with_src_dyn_quant) {
331+
brg->src_sum_group_size = brg->rd_block;
332+
brg->src_grouped_sum_stride = div_up(wei_d.dims()[1], brg->src_sum_group_size);
333+
}
334+
321335
// avx2_vnni_2 kernel with xf16 data type requires blocked weights.
322336
if (brg->isa_impl == avx2_vnni_2 && brg->is_xf16()
323337
&& brg->LDB % brg->ld_block > 0)

src/cpu/x64/brgemm/brgemm.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ void DNNL_API brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
175175
void *scratch = nullptr,
176176
const brgemm_dynamic_values_t *dynamic_values = nullptr,
177177
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
178-
const void *ptr_src_scales = nullptr, size_t ic = 0);
178+
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
179179

180180
/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
181181
///
@@ -205,7 +205,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
205205
void *scratch = nullptr,
206206
const brgemm_dynamic_values_t *dynamic_values = nullptr,
207207
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
208-
const void *ptr_src_scales = nullptr, size_t ic = 0);
208+
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
209209

210210
/// Execute BRGEMM kernel (brgemm_addr version)
211211
///
@@ -234,7 +234,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
234234
const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr,
235235
const brgemm_dynamic_values_t *dynamic_values = nullptr,
236236
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
237-
const void *ptr_src_scales = nullptr, size_t ic = 0);
237+
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
238238

239239
/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
240240
///
@@ -267,7 +267,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
267267
const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr,
268268
const brgemm_dynamic_values_t *dynamic_values = nullptr,
269269
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
270-
const void *ptr_src_scales = nullptr, size_t ic = 0);
270+
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
271271

272272
/// AMX utilities: Creates a palette based on BRGEMM descriptor
273273
///

src/cpu/x64/brgemm/brgemm_types.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,8 @@ struct brgemm_desc_t {
321321
bool with_src_dyn_quant = false;
322322
int src_scales_group_size = 0;
323323
int src_scales_stride = 0;
324+
int src_sum_group_size = 0;
325+
int src_grouped_sum_stride = 0;
324326

325327
bool is_row_major() const {
326328
assert(layout != brgemm_layout_undef);
@@ -500,6 +502,7 @@ struct brgemm_kernel_params_t {
500502
const void *ptr_wei_scales = nullptr;
501503
const void *ptr_wei_zero_points = nullptr;
502504
const void *ptr_src_scales = nullptr;
505+
const void *ptr_src_grouped_sum = nullptr;
503506
size_t ic;
504507
dim_t dynamic_LDA = 0;
505508
dim_t dynamic_LDB = 0;

src/cpu/x64/brgemm/brgemm_utils.cpp

+11-8
Original file line numberDiff line numberDiff line change
@@ -230,14 +230,10 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
230230
if (one_of(brg->dt_b, data_type::nf4) && brg->isa_impl == avx2) max_bcast_block -= 5;
231231
if (one_of(brg->dt_b, data_type::f4_e2m1) && brg->isa_impl == avx2) max_bcast_block -= 2;
232232
if (one_of(brg->dt_b, data_type::nf4, data_type::f4_e2m1) && brg->isa_impl != avx2) max_bcast_block -= 1;
233-
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0) max_bcast_block -= 1;
234-
if (brg->with_src_dyn_quant) max_bcast_block -= 2;
235-
if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2;
233+
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0 && !brg->with_src_dyn_quant) max_bcast_block -= 1;
234+
if (brg->with_src_dyn_quant) max_bcast_block -= 1;
236235

237236
max_bcast_block /= adj_ld_block2;
238-
if (brg->with_src_dyn_quant) {
239-
max_bcast_block /= 2;
240-
}
241237

242238
return max_bcast_block;
243239
}
@@ -301,15 +297,22 @@ status_t brgemm_blocking(brgemm_desc_t *brg) {
301297
= (brg->is_f16 && brg->isa_impl == avx512_core_fp16)
302298
? 1
303299
: data_type_vnni_granularity(brg->dt_a);
300+
304301
int rd_unroll = one_of(brg->dt_b, data_type::nf4, data_type::u4, data_type::s4, data_type::f4_e2m1) ? 32 : 4;
305-
if (brg->with_grouped_wei_decomp) {
302+
if (brg->with_grouped_wei_decomp && !brg->with_src_dyn_quant) {
306303
auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size);
307304
min_group_size = nstl::min(min_group_size, brg->src_scales_group_size);
308305
rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity);
309306
rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity);
307+
brg->rd_block = rd_unroll * vnni_granularity;
308+
} else if (brg->with_src_dyn_quant) {
309+
brg->rd_block = brg->src_scales_group_size;
310+
auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size);
311+
brg->rd_block = nstl::min(brg->rd_block, min_group_size);
312+
} else {
313+
brg->rd_block = rd_unroll * vnni_granularity;
310314
}
311315

312-
brg->rd_block = rd_unroll * vnni_granularity;
313316
brg->rdb = brg->reduce_dim / brg->rd_block;
314317
brg->rdb_tail = brg->reduce_dim % brg->rd_block;
315318

0 commit comments

Comments
 (0)