Skip to content

Commit a870aae

Browse files
[FORK][FEATURE] DQ IP: precompute grouped src sums
1 parent d421730 commit a870aae

12 files changed

+190
-83
lines changed

src/common/memory_tracking.hpp

+1
Original file line number | Diff line number | Diff line change
@@ -305,6 +305,7 @@ enum {
305305
key_decompression_zero_points,
306306
key_src_quantized,
307307
key_src_dequantized_scales,
308+
key_src_grouped_sum,
308309
// These two keys should always be the last ones,
309310
// even though they are not in alphabetical order
310311
key_nested,

src/cpu/x64/brgemm/brgemm.cpp

+18-4
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,8 @@ void brgemm_desc_t::cleanup_dst_md() {
8282
void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
8383
const brgemm_batch_element_t *batch, void *ptr_C, void *scratch,
8484
const brgemm_dynamic_values_t *dynamic_values,
85-
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
85+
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
86+
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
8687
brgemm_kernel_params_t brgemm_p;
8788

8889
brgemm_p.batch = batch;
@@ -105,6 +106,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
105106
brgemm_p.ptr_wei_scales = ptr_wei_scales;
106107
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
107108
brgemm_p.ptr_src_scales = ptr_src_scales;
109+
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
108110
brgemm_p.ic = ic;
109111

110112
assert(brg_kernel);
@@ -116,7 +118,8 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
116118
const void *addr_A, const void *addr_B,
117119
const brgemm_batch_element_t *batch, void *ptr_C, void *scratch,
118120
const brgemm_dynamic_values_t *dynamic_values,
119-
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
121+
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
122+
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
120123
brgemm_kernel_params_t brgemm_p;
121124

122125
brgemm_p.batch = batch;
@@ -133,6 +136,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
133136
brgemm_p.ptr_wei_scales = ptr_wei_scales;
134137
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
135138
brgemm_p.ptr_src_scales = ptr_src_scales;
139+
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
136140
brgemm_p.ic = ic;
137141
if (dynamic_values) {
138142
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
@@ -148,7 +152,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
148152
const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
149153
const brgemm_post_ops_data_t &post_ops_data, void *scratch,
150154
const brgemm_dynamic_values_t *dynamic_values,
151-
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
155+
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
156+
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
152157
brgemm_kernel_params_t brgemm_p;
153158

154159
brgemm_p.batch = batch;
@@ -178,6 +183,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
178183
brgemm_p.ptr_wei_scales = ptr_wei_scales;
179184
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
180185
brgemm_p.ptr_src_scales = ptr_src_scales;
186+
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
181187
brgemm_p.ic = ic;
182188
if (dynamic_values) {
183189
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
@@ -194,7 +200,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
194200
const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
195201
const brgemm_post_ops_data_t &post_ops_data, void *scratch,
196202
const brgemm_dynamic_values_t *dynamic_values,
197-
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
203+
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
204+
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
198205
brgemm_kernel_params_t brgemm_p;
199206

200207
brgemm_p.batch = batch;
@@ -224,6 +231,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
224231
brgemm_p.ptr_wei_scales = ptr_wei_scales;
225232
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
226233
brgemm_p.ptr_src_scales = ptr_src_scales;
234+
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
227235
brgemm_p.ic = ic;
228236
if (dynamic_values) {
229237
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
@@ -318,6 +326,12 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa,
318326

319327
CHECK(brgemm_blocking(brg));
320328

329+
brg->src_sum_group_size = wei_d.dims()[1];
330+
if (brg->with_src_dyn_quant) {
331+
brg->src_sum_group_size = brg->rd_block;
332+
brg->src_grouped_sum_stride = div_up(wei_d.dims()[1], brg->src_sum_group_size);
333+
}
334+
321335
// avx2_vnni_2 kernel with xf16 data type requires blocked weights.
322336
if (brg->isa_impl == avx2_vnni_2 && brg->is_xf16()
323337
&& brg->LDB % brg->ld_block > 0)

src/cpu/x64/brgemm/brgemm.hpp

+4-4
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ void DNNL_API brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
175175
void *scratch = nullptr,
176176
const brgemm_dynamic_values_t *dynamic_values = nullptr,
177177
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
178-
const void *ptr_src_scales = nullptr, size_t ic = 0);
178+
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
179179

180180
/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
181181
///
@@ -205,7 +205,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
205205
void *scratch = nullptr,
206206
const brgemm_dynamic_values_t *dynamic_values = nullptr,
207207
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
208-
const void *ptr_src_scales = nullptr, size_t ic = 0);
208+
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
209209

210210
/// Execute BRGEMM kernel (brgemm_addr version)
211211
///
@@ -234,7 +234,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
234234
const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr,
235235
const brgemm_dynamic_values_t *dynamic_values = nullptr,
236236
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
237-
const void *ptr_src_scales = nullptr, size_t ic = 0);
237+
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
238238

239239
/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
240240
///
@@ -267,7 +267,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
267267
const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr,
268268
const brgemm_dynamic_values_t *dynamic_values = nullptr,
269269
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
270-
const void *ptr_src_scales = nullptr, size_t ic = 0);
270+
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);
271271

272272
/// AMX utilities: Creates a palette based on BRGEMM descriptor
273273
///

src/cpu/x64/brgemm/brgemm_types.hpp

+3
Original file line number | Diff line number | Diff line change
@@ -321,6 +321,8 @@ struct brgemm_desc_t {
321321
bool with_src_dyn_quant = false;
322322
int src_scales_group_size = 0;
323323
int src_scales_stride = 0;
324+
int src_sum_group_size = 0;
325+
int src_grouped_sum_stride = 0;
324326

325327
bool is_row_major() const {
326328
assert(layout != brgemm_layout_undef);
@@ -500,6 +502,7 @@ struct brgemm_kernel_params_t {
500502
const void *ptr_wei_scales = nullptr;
501503
const void *ptr_wei_zero_points = nullptr;
502504
const void *ptr_src_scales = nullptr;
505+
const void *ptr_src_grouped_sum = nullptr;
503506
size_t ic;
504507
dim_t dynamic_LDA = 0;
505508
dim_t dynamic_LDB = 0;

src/cpu/x64/brgemm/brgemm_utils.cpp

+11-3
Original file line number | Diff line number | Diff line change
@@ -231,7 +231,7 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
231231
if (one_of(brg->dt_b, data_type::f4_e2m1) && brg->isa_impl == avx2) max_bcast_block -= 2;
232232
if (one_of(brg->dt_b, data_type::nf4, data_type::f4_e2m1) && brg->isa_impl != avx2) max_bcast_block -= 1;
233233
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0) max_bcast_block -= 1;
234-
if (brg->with_src_dyn_quant) max_bcast_block -= 2;
234+
if (brg->with_src_dyn_quant) max_bcast_block -= 1;
235235
if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2;
236236

237237
max_bcast_block /= adj_ld_block2;
@@ -298,15 +298,23 @@ status_t brgemm_blocking(brgemm_desc_t *brg) {
298298
= (brg->is_f16 && brg->isa_impl == avx512_core_fp16)
299299
? 1
300300
: data_type_vnni_granularity(brg->dt_a);
301+
301302
int rd_unroll = one_of(brg->dt_b, data_type::nf4, data_type::u4, data_type::s4, data_type::f4_e2m1) ? 32 : 4;
302-
if (brg->with_grouped_wei_decomp) {
303+
if (brg->with_grouped_wei_decomp && !brg->with_src_dyn_quant) {
303304
auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size);
304305
min_group_size = nstl::min(min_group_size, brg->src_scales_group_size);
305306
rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity);
306307
rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity);
308+
brg->rd_block = rd_unroll * vnni_granularity;
309+
} else if (brg->with_src_dyn_quant) {
310+
brg->rd_block = 32;
311+
auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size);
312+
min_group_size = nstl::min(min_group_size, brg->src_scales_group_size);
313+
brg->rd_block = nstl::min(brg->rd_block, min_group_size);
314+
} else {
315+
brg->rd_block = rd_unroll * vnni_granularity;
307316
}
308317

309-
brg->rd_block = rd_unroll * vnni_granularity;
310318
brg->rdb = brg->reduce_dim / brg->rd_block;
311319
brg->rdb_tail = brg->reduce_dim % brg->rd_block;
312320

src/cpu/x64/brgemm/jit_brgemm_kernel.cpp

+48-48
Original file line number | Diff line number | Diff line change
@@ -203,6 +203,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
203203
const reg64_t reg_aux_wei_zp = reg_rdb_loop;
204204
const reg64_t reg_ic = reg_rdb_loop;
205205
const reg64_t reg_src_scales = reg_rdb_loop;
206+
const reg64_t reg_src_grouped_sum = reg_rdb_loop;
206207
const reg64_t reg_tmp_read_values = reg_rdb_loop;
207208

208209
const reg64_t reg_aux_scales = reg_aux_B;
@@ -280,12 +281,13 @@ struct jit_brgemm_kernel_t : public jit_generator {
280281
constexpr static int reg_src_scales_offs_ = 336;
281282
constexpr static int reg_aux_src_scales_offs_ = 344;
282283
constexpr static int reg_aux2_src_scales_offs_ = 352;
283-
// constexpr static int stack_space_needed_ = 360;
284+
constexpr static int reg_src_grouped_sum_offs_ = 360;
285+
constexpr static int reg_aux_src_grouped_sum_offs_ = 368;
286+
constexpr static int reg_aux2_src_grouped_sum_offs_ = 376;
284287
// these are used for FP8 as temporary push/pop spaces
285-
constexpr static int reg_val_tmp_1_ = 368;
286-
constexpr static int reg_val_tmp_2_ = 376;
287-
constexpr static int stack_space_needed_ = 384;
288-
// regsiters for dynamic quant
288+
constexpr static int reg_val_tmp_1_ = 384;
289+
constexpr static int reg_val_tmp_2_ = 392;
290+
constexpr static int stack_space_needed_ = 400;
289291

290292

291293
bool is_ldb_loop_ = false;
@@ -323,7 +325,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
323325
}
324326

325327
if (brg.with_src_dyn_quant) {
326-
used_vregs += 2;
328+
used_vregs += 1;
327329
}
328330

329331
if (brg.with_src_dyn_quant && brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) {
@@ -971,6 +973,12 @@ void jit_brgemm_kernel_t<Wmm>::copy_post_ops_stack_values_to_aux(
971973
mov(reg_src_scales, ptr[rsp + reg_src_scales_offs_]);
972974
mov(ptr[rsp + reg_aux_src_scales_offs_], reg_src_scales);
973975
mov(ptr[rsp + reg_aux2_src_scales_offs_], reg_src_scales);
976+
977+
if (brg.with_wei_decomp_zero_points) {
978+
mov(reg_src_grouped_sum, ptr[rsp + reg_src_grouped_sum_offs_]);
979+
mov(ptr[rsp + reg_aux_src_grouped_sum_offs_], reg_src_grouped_sum);
980+
mov(ptr[rsp + reg_aux2_src_grouped_sum_offs_], reg_src_grouped_sum);
981+
}
974982
}
975983
if (brg.zp_type_b != brgemm_broadcast_t::none) {
976984
mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]);
@@ -1048,6 +1056,9 @@ void jit_brgemm_kernel_t<Wmm>::read_params() {
10481056
if (brg.with_src_dyn_quant) {
10491057
mov(reg_src_scales, ptr[param1 + GET_OFF(ptr_src_scales)]);
10501058
mov(ptr[rsp + reg_src_scales_offs_], reg_src_scales);
1059+
1060+
mov(reg_src_grouped_sum, ptr[param1 + GET_OFF(ptr_src_grouped_sum)]);
1061+
mov(ptr[rsp + reg_src_grouped_sum_offs_], reg_src_grouped_sum);
10511062
}
10521063

10531064
if (brg.zp_type_c != brgemm_broadcast_t::none) {
@@ -2296,21 +2307,10 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
22962307
};
22972308

22982309
auto vmm_zero_point = [&](int ld) {
2299-
int idx = isa_num_vregs(brg.isa_impl) - 3 - ld;
2310+
int idx = isa_num_vregs(brg.isa_impl) - 2 - ld;
23002311
return Vmm(idx);
23012312
};
23022313

2303-
static const int8_t negative_one[64] = {
2304-
-1, -1, -1, -1, -1, -1, -1, -1,
2305-
-1, -1, -1, -1, -1, -1, -1, -1,
2306-
-1, -1, -1, -1, -1, -1, -1, -1,
2307-
-1, -1, -1, -1, -1, -1, -1, -1,
2308-
-1, -1, -1, -1, -1, -1, -1, -1,
2309-
-1, -1, -1, -1, -1, -1, -1, -1,
2310-
-1, -1, -1, -1, -1, -1, -1, -1,
2311-
-1, -1, -1, -1, -1, -1, -1, -1
2312-
};
2313-
23142314
static const int8_t mask_low_half[64] = {
23152315
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
23162316
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
@@ -2328,33 +2328,18 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
23282328
if (brg.with_wei_decomp_zero_points) {
23292329
mov(reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_]);
23302330
if (brg.wei_decomp_zero_points_stride == 0) {
2331-
auto reg_ptr_8 = Reg8(reg_ptr.getIdx());
2332-
mov(reg_ptr_8, ptr[reg_local_wei_zp]);
2333-
uni_vpbroadcastb(vmm_zero_point(0), reg_ptr_8);
2331+
auto reg_ptr_32 = Reg32(reg_ptr.getIdx());
2332+
movzx(reg_ptr_32, ptr[reg_local_wei_zp]);
2333+
uni_vmovq(Xmm(vmm_zero_point(0).getIdx()), reg_ptr);
2334+
uni_vbroadcastss(vmm_zero_point(0), Xmm(vmm_zero_point(0).getIdx()));
23342335
} else {
2335-
static const int8_t index_table[64] = {
2336-
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C,
2337-
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C,
2338-
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C,
2339-
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C
2340-
};
2341-
2342-
auto vmm_indexes = Vmm(isa_num_vregs(brg.isa_impl) - 1);
2343-
mov(reg_ptr, (size_t)index_table);
2344-
uni_vmovups(vmm_indexes, ptr[reg_ptr]);
2345-
23462336
for (int ld = 0; ld < ld_block2; ld++) {
23472337
uni_vpmovzxbd(vmm_zero_point(ld), ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_zero_points_dt)]);
2348-
vpshufb(vmm_zero_point(ld), vmm_zero_point(ld), vmm_indexes);
23492338
}
23502339
}
23512340
}
23522341

2353-
auto vmm_neg_one = Vmm(isa_num_vregs(brg.isa_impl) - 1);
2354-
mov(reg_ptr, (size_t)negative_one);
2355-
uni_vmovups(vmm_neg_one, ptr[reg_ptr]);
2356-
2357-
auto vmm_mask_low_half = Vmm(isa_num_vregs(brg.isa_impl) - 2);
2342+
auto vmm_mask_low_half = Vmm(isa_num_vregs(brg.isa_impl) - 1);
23582343
mov(reg_ptr, (size_t)mask_low_half);
23592344
uni_vmovups(vmm_mask_low_half, ptr[reg_ptr]);
23602345

@@ -2409,22 +2394,28 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
24092394
auto vmm = accm(ld_block2, bd, ld);
24102395
vpdpbusd(vmm, load(ld), bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
24112396
}
2412-
if (brg.with_wei_decomp_zero_points) {
2413-
uni_vpxor(bcst(), bcst(), vmm_neg_one);
2414-
uni_vpsubb(bcst(), bcst(), vmm_neg_one);
2415-
for (int ld = 0; ld < ld_block2; ld++) {
2416-
auto vmm = accm(ld_block2, bd, ld);
2417-
Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld);
2418-
vpdpbusd(vmm, vmm_zp, bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
2419-
}
2420-
}
24212397
}
24222398
}
24232399

24242400
auto reg_local_src_scales = reg_local_wei_zp;
2401+
auto reg_local_src_grouped_sum = reg_local_wei_zp;
24252402
auto vmm_src_scales = bcst();
2426-
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);
2403+
auto vmm_src_grouped_sum = bcst();
24272404

2405+
if (brg.with_wei_decomp_zero_points) {
2406+
mov(reg_local_src_grouped_sum, ptr[rsp + reg_aux2_src_grouped_sum_offs_ + accums_stack_space]);
2407+
for (int bd = bd_b; bd < bd_e; bd++) {
2408+
for (int ld = 0; ld < ld_block2; ld++) {
2409+
auto vmm_accm = accm(ld_block2, bd, ld);
2410+
Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld);
2411+
uni_vbroadcastss(vmm_src_grouped_sum, ptr[reg_local_src_grouped_sum + bd * brg.src_grouped_sum_stride * sizeof(int32_t)]);
2412+
uni_vpmulld(vmm_src_grouped_sum, vmm_src_grouped_sum, vmm_zp);
2413+
uni_vpsubd(vmm_accm, vmm_accm, vmm_src_grouped_sum);
2414+
}
2415+
}
2416+
}
2417+
2418+
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);
24282419
for (int bd = bd_b; bd < bd_e; bd++) {
24292420
uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]);
24302421
for (int ld = 0; ld < ld_block2; ld++) {
@@ -3073,6 +3064,11 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
30733064
if (brg.with_src_dyn_quant) {
30743065
ic_group_shift(reg_aux_src_scales_offs_, reg_aux2_src_scales_offs_,
30753066
brg.src_scales_group_size, sizeof(float));
3067+
3068+
if (brg.with_wei_decomp_zero_points) {
3069+
ic_group_shift(reg_aux_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_,
3070+
brg.src_sum_group_size, sizeof(int32_t));
3071+
}
30763072
}
30773073

30783074
mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
@@ -3306,6 +3302,10 @@ void jit_brgemm_kernel_t<Wmm>::bdb_loop() {
33063302
mov(reg_src_scales, ptr[rsp + reg_src_scales_offs_]);
33073303
add(reg_src_scales, bd_block2 * brg.bd_block * brg.src_scales_stride * sizeof(float));
33083304
mov(ptr[rsp + reg_src_scales_offs_], reg_src_scales);
3305+
3306+
mov(reg_src_grouped_sum, ptr[rsp + reg_src_grouped_sum_offs_]);
3307+
add(reg_src_grouped_sum, bd_block2 * brg.bd_block * brg.src_grouped_sum_stride * sizeof(int32_t));
3308+
mov(ptr[rsp + reg_src_grouped_sum_offs_], reg_src_grouped_sum);
33093309
}
33103310

33113311
advance_bd_block2_post_op_regs(bd_block2);

0 commit comments

Comments
 (0)