@@ -87,15 +87,6 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
const auto wei_zero_points_d = ctx.memory_mdw(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
int wei_scales_oc_stride = wei_scales_d.dims()[0] > 1 ? 1 : 0;
int wei_zero_points_oc_stride = wei_zero_points_d.dims()[0] > 1 ? 1 : 0;
- int wei_scales_ic_group_size, wei_zero_points_ic_group_size;
- if (jbgp.with_grouped_weights_decompression) {
- int wei_scales_ic_group_num = wei_scales_d.dims()[1];
- int wei_zero_points_ic_group_num = wei_zero_points_d.dims()[1];
- wei_scales_ic_group_size = wei_scales_ic_group_num ? div_up(jbgp.ic, wei_scales_ic_group_num) : jbgp.ic;
- wei_zero_points_ic_group_size = wei_zero_points_ic_group_num ? div_up(jbgp.ic, wei_zero_points_ic_group_num) : jbgp.ic;
- } else {
- wei_scales_ic_group_size = wei_zero_points_ic_group_size = jbgp.ic;
- }

const float *oscales = nullptr;
if (jbgp.weights_decompression) {
@@ -170,8 +161,6 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
const auto wei_ic_stride
= types::data_type_size(jbgp.wei_dt) * weights_d.off_v(ic_dims);

- int typesize_scale = one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
-
const auto ker = [&](int ithr_oc_mb, int nthr_oc_mb, int ithr_ic, int osb,
int osb_s, int ocb, int ocb_s, int icc, int icc_s,
bool copy_buffer_a, int &prev_ker_idx) {
@@ -269,7 +258,7 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
int brg_ker_idx = brgemm_inner_product_utils::get_brg_kernel_index(
is_bs_tail, kernel_init, is_os_tail, is_oc_tail, false);
auto brg_kernel = brg_kernels_[brg_ker_idx].get();
- const int ic_blocks_per_batch = jbgp.K / jbgp.ic_block;
+ const int ic_blocks_per_batch = div_up(jbgp.K, jbgp.ic_block);
const dim_t wei_cur_ocb
= get_blk_off(weights_d, jbgp.wei_dt, cur_ocb, 0);

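Aside (illustration only, not part of the patch): the `+` line above switches ic_blocks_per_batch from truncating division to div_up, so a reduction dimension K that is not a multiple of ic_block still gets a batch slot for its partial tail block. A minimal standalone sketch of the difference, with hypothetical sizes and a local stand-in for oneDNN's div_up() helper:

#include <cstdio>

// Local stand-in for oneDNN's div_up() utility: integer division rounded up.
static int div_up(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int K = 100, ic_block = 64; // hypothetical sizes; K is not a multiple of ic_block
    printf("old: K / ic_block        = %d\n", K / ic_block);         // 1: the 36-channel tail block is dropped
    printf("new: div_up(K, ic_block) = %d\n", div_up(K, ic_block));  // 2: the tail block gets its own batch entry
    return 0;
}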
@@ -290,7 +279,7 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
ic + b * jbgp.K));
addr_batch[b].ptr.A = A_ptr;
const dim_t wei_offset = (wei_cur_ocb
- + wei_ic_stride * (icb + b * ic_blocks_per_batch)) / typesize_scale;
+ + wei_ic_stride * (icb + b * ic_blocks_per_batch));
if (jbgp.weights_compressed) {
using comp_tile_len_type = int;
const comp_tile_len_type *compressed_tile_lengths_ptr
@@ -311,30 +300,35 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
(*brg_decomp_kernel_)(&dcomp_params);
addr_batch[b].ptr.B = decomp_buf;
} else if (jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t::prepack) {
- auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt);
+ int typesize_scale = one_of(jbgp.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
+ auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt) / typesize_scale;
auto weights_ptr = reinterpret_cast<const uint8_t *>(&weights[w_off]);

const size_t decomp_buf_per_thr = jbgp.ic_block * jbgp.nb_ic_blocking * jbgp.oc_block * types::data_type_size(jbgp.wei_dt);
auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr + wei_ic_stride * b * ic_blocks_per_batch;

- const int ic_internal_block = is_amx ? 2 : 1;
- auto wei_zero_points_ptr = wei_zero_points + oc;
- auto wei_scales_ptr = wei_scales + oc;
+ const int ic_internal_block = is_amx || one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
+ auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc;
+ auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc;

if (jbgp.with_grouped_weights_decompression) {
weights_decompression_runtime_params_t rt_params = {};
auto ic_size = jbgp.ic_block * ic_blocks_per_batch / ic_internal_block;
- auto wei_scales_ic_group_size_local = wei_scales_ic_group_size / ic_internal_block;
- auto wei_zero_points_ic_group_size_local = wei_zero_points_ic_group_size / ic_internal_block;
+ auto wei_scales_ic_group_size_local = jbgp.wei_scales_ic_group_size / ic_internal_block;
+ auto wei_zero_points_ic_group_size_local = jbgp.wei_zero_points_ic_group_size / ic_internal_block;
auto group_size = nstl::min(wei_scales_ic_group_size_local, wei_zero_points_ic_group_size_local);
auto group_ic_blocks = div_up(ic_size, group_size);
+ auto start_group_scales = ic / jbgp.wei_scales_ic_group_size;
+ auto start_group_zero_points = ic / jbgp.wei_zero_points_ic_group_size;
for (int icb_idx = 0; icb_idx < group_ic_blocks; icb_idx++) {
auto ic_idx = icb_idx * group_size;
+ auto scales_idx = ic_idx / wei_scales_ic_group_size_local + start_group_scales;
+ auto zero_points_idx = ic_idx / wei_zero_points_ic_group_size_local + start_group_zero_points;

- rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt);
+ rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt) / typesize_scale;
rt_params.decomp_buffer_ptr = decomp_buf + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.wei_dt);
- rt_params.scales_ptr = wei_scales_ptr + (ic_idx * wei_scales_d.dims()[0]) / wei_scales_ic_group_size_local;
- rt_params.zero_points_ptr = wei_zero_points_ptr + (ic_idx * wei_zero_points_d.dims()[0]) / wei_zero_points_ic_group_size_local;
+ rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims()[0];
+ rt_params.zero_points_ptr = wei_zero_points_ptr + zero_points_idx * wei_zero_points_d.dims()[0];
rt_params.ic_size = nstl::min(group_size, ic_size - icb_idx * group_size);
(*brg_weights_decomp_kernel_)(&rt_params);
}
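Aside (illustration only, not from the patch): the new start_group_* / scales_idx / zero_points_idx arithmetic above picks, for each decompressed ic sub-block, the scale and zero-point group covering its first channel. A self-contained model of that indexing under assumed group sizes; all names below are local to this sketch, not oneDNN API:

#include <algorithm>
#include <cstdio>

static int div_up(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Hypothetical configuration: a chunk of 128 channels starting at ic = 128,
    // scale groups of 64 channels, zero-point groups of 128 channels, and an
    // internal interleave factor of 2 (as used for AMX / 4-bit weights above).
    const int ic = 128, ic_block = 64, ic_blocks_per_batch = 2;
    const int wei_scales_ic_group_size = 64, wei_zero_points_ic_group_size = 128;
    const int ic_internal_block = 2;

    const int ic_size = ic_block * ic_blocks_per_batch / ic_internal_block;
    const int scales_group_local = wei_scales_ic_group_size / ic_internal_block;
    const int zp_group_local = wei_zero_points_ic_group_size / ic_internal_block;
    const int group_size = std::min(scales_group_local, zp_group_local);
    const int group_ic_blocks = div_up(ic_size, group_size);

    // Group index of the chunk's first channel (mirrors start_group_*).
    const int start_group_scales = ic / wei_scales_ic_group_size;
    const int start_group_zero_points = ic / wei_zero_points_ic_group_size;

    for (int icb_idx = 0; icb_idx < group_ic_blocks; icb_idx++) {
        const int ic_idx = icb_idx * group_size;
        printf("sub-block %d -> scales group %d, zero-point group %d\n", icb_idx,
                ic_idx / scales_group_local + start_group_scales,
                ic_idx / zp_group_local + start_group_zero_points);
    }
    return 0;
}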
@@ -350,15 +344,16 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(

addr_batch[b].ptr.B = decomp_buf;
} else {
- addr_batch[b].ptr.B = weights + wei_offset;
+ int typesize_scale = one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
+ addr_batch[b].ptr.B = weights + wei_offset / typesize_scale;
}
}

int wei_scales_offset = 0;
int wei_zero_points_offset = 0;
if (jbgp.weights_decompression) {
- wei_scales_offset = (ic / wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc;
- wei_zero_points_offset = (ic / wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc;
+ wei_scales_offset = (ic / jbgp.wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc;
+ wei_zero_points_offset = (ic / jbgp.wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc;
}

auto ptr_D = dst + dst_off;
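Aside (illustration only): typesize_scale above compensates for 4-bit weight types (nf4/s4/u4), where two elements share one byte, so an element-granular offset must be halved before it can be used as a byte offset into the packed buffer. A minimal sketch of that conversion; the packing order shown is only an assumption for the example:

#include <cstddef>
#include <cstdint>

// Turn an element offset into a byte offset for a weight buffer that may hold
// 4-bit values packed two per byte (the assumption behind typesize_scale == 2).
static size_t weights_byte_offset(size_t elem_offset, bool is_4bit_type) {
    const int typesize_scale = is_4bit_type ? 2 : 1; // elements per byte
    return elem_offset / typesize_scale;
}

int main() {
    uint8_t packed[4] = {0x21, 0x43, 0x65, 0x87}; // eight hypothetical 4-bit values, low nibble first
    // Element 5 lives in byte 2 (its high nibble) when two values share a byte.
    return packed[weights_byte_offset(5, /*is_4bit_type=*/true)] == 0x65 ? 0 : 1;
}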
@@ -382,10 +377,10 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(

brgemm_kernel_execute_postops(brg_kernel, gemm_batch,
addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data,
- scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic);
+ scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0);
} else {
brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch,
- (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic);
+ (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0);
}
}

@@ -403,33 +398,38 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
+ get_blk_off(src_d, jbgp.src_dt, n,
ic + ic_block * jbgp.ic_block);
const dim_t wei_offset
- = (wei_cur_ocb + wei_ic_stride * (icb + ic_block)) / typesize_scale;
+ = (wei_cur_ocb + wei_ic_stride * (icb + ic_block));

if (jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t::prepack) {
- auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt);
+ int typesize_scale = one_of(jbgp.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
+ auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt) / typesize_scale;
auto weights_ptr = reinterpret_cast<const uint8_t *>(&weights[w_off]);

const size_t decomp_buf_per_thr = jbgp.ic_block * jbgp.nb_ic_blocking * jbgp.oc_block * types::data_type_size(jbgp.wei_dt);
auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr;

- const int ic_internal_block = is_amx ? 2 : 1;
- auto wei_zero_points_ptr = wei_zero_points + oc;
- auto wei_scales_ptr = wei_scales + oc;
+ const int ic_internal_block = is_amx || one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
+ auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc;
+ auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc;

if (jbgp.with_grouped_weights_decompression) {
+ weights_decompression_runtime_params_t rt_params = {};
auto ic_size = (jbgp.ic - (ic + ic_block * jbgp.ic_block)) / ic_internal_block;
- auto wei_scales_ic_group_size_local = wei_scales_ic_group_size / ic_internal_block;
- auto wei_zero_points_ic_group_size_local = wei_zero_points_ic_group_size / ic_internal_block;
+ auto wei_scales_ic_group_size_local = jbgp.wei_scales_ic_group_size / ic_internal_block;
+ auto wei_zero_points_ic_group_size_local = jbgp.wei_zero_points_ic_group_size / ic_internal_block;
auto group_size = nstl::min(wei_scales_ic_group_size_local, wei_zero_points_ic_group_size_local);
auto group_ic_blocks = div_up(ic_size, group_size);
- weights_decompression_runtime_params_t rt_params = {};
+ auto start_group_scales = ic / jbgp.wei_scales_ic_group_size;
+ auto start_group_zero_points = ic / jbgp.wei_zero_points_ic_group_size;
for (int icb_idx = 0; icb_idx < group_ic_blocks; icb_idx++) {
auto ic_idx = icb_idx * group_size;
+ auto scales_idx = ic_idx / wei_scales_ic_group_size_local + start_group_scales;
+ auto zero_points_idx = ic_idx / wei_zero_points_ic_group_size_local + start_group_zero_points;

- rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt);
+ rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt) / typesize_scale;
rt_params.decomp_buffer_ptr = decomp_buf + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.wei_dt);
- rt_params.scales_ptr = wei_scales_ptr + (ic_idx * wei_scales_d.dims()[0]) / wei_scales_ic_group_size_local;
- rt_params.zero_points_ptr = wei_zero_points_ptr + (ic_idx * wei_zero_points_d.dims()[0]) / wei_zero_points_ic_group_size_local;
+ rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims()[0];
+ rt_params.zero_points_ptr = wei_zero_points_ptr + zero_points_idx * wei_zero_points_d.dims()[0];
rt_params.ic_size = nstl::min(group_size, ic_size - icb_idx * group_size);
(*brg_weights_decomp_kernel_)(&rt_params);
}
@@ -445,14 +445,15 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(

addr_batch[0].ptr.B = decomp_buf;
} else {
- addr_batch[0].ptr.B = weights + wei_offset;
+ int typesize_scale = one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
+ addr_batch[0].ptr.B = weights + wei_offset / typesize_scale;
}

int wei_scales_offset = 0;
int wei_zero_points_offset = 0;
if (jbgp.weights_decompression) {
- wei_scales_offset = (ic / wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc;
- wei_zero_points_offset = (ic / wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc;
+ wei_scales_offset = (ic / jbgp.wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc;
+ wei_zero_points_offset = (ic / jbgp.wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc;
}

auto brg_kernel_ic_tail = brg_kernels_[brg_ker_ic_tail_idx].get();
@@ -474,10 +475,10 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
nullptr, false, 1, false, false, dst_scales};

brgemm_kernel_execute_postops(brg_kernel_ic_tail, 1, addr_batch,
- (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic);
+ (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0);
} else {
brgemm_kernel_execute(brg_kernel_ic_tail, 1, addr_batch,
- (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic);
+ (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0);
}
}
};