@@ -305,7 +305,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
305
305
used_vregs = 5 ;
306
306
else if (brg.is_f16_b_non_amx_vnni ())
307
307
used_vregs = 2 ;
308
-
308
+
309
309
if (one_of (brg.dt_b , data_type::nf4) && brg.isa_impl == avx2) {
310
310
used_vregs += 5 ;
311
311
}
@@ -2431,7 +2431,11 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2431
2431
for (int bd = bd_b; bd < bd_e; bd++) {
2432
2432
uni_vbroadcastss (vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof (float )]);
2433
2433
for (int ld = 0 ; ld < ld_block2; ld++) {
2434
- uni_vmovups (load (ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof (float )]);
2434
+ if (brg.wei_decomp_scales_stride == 0 ) {
2435
+ uni_vbroadcastss (load (ld), ptr[reg_local_wei_scales]);
2436
+ } else {
2437
+ uni_vmovups (load (ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof (float )]);
2438
+ }
2435
2439
}
2436
2440
for (int ld = 0 ; ld < ld_block2; ld++) {
2437
2441
auto vmm_accm_aux = vmm_accm_tmp (ld_block2, bd, ld);
@@ -2901,7 +2905,11 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
2901
2905
for (int ld = 0 ; ld < ld_block2; ld++) {
2902
2906
auto vmm_accm_tmp = accm_tmp (ld_block2, 0 , ld);
2903
2907
auto vmm_accm = accm (ld_block2, 0 , ld);
2904
- load_scales (bcst (), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size (brg.wei_decomp_scales_dt )]);
2908
+ if (brg.wei_decomp_scales_stride == 0 ) {
2909
+ load_scales (bcst (), ptr[reg_local_wei_scales]);
2910
+ } else {
2911
+ load_scales (bcst (), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size (brg.wei_decomp_scales_dt )]);
2912
+ }
2905
2913
uni_vfmadd231ps (vmm_accm, vmm_accm_tmp, bcst ());
2906
2914
}
2907
2915
}
@@ -3025,8 +3033,8 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
3025
3033
mov (reg_rdb_loop, brg.rdb );
3026
3034
L_aligned (rdb_loop_label, 64 );
3027
3035
{
3028
- if (brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 ||
3029
- brg.wei_decomp_zero_points_stride != 0 )) {
3036
+ if (( brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 ||
3037
+ brg.wei_decomp_zero_points_stride != 0 )) || brg. with_src_dyn_quant ) {
3030
3038
auto reg_local_ic = reg_aux_D;
3031
3039
auto reg_local_wei_params = reg_bdb_loop;
3032
3040
auto reg_local_ic_group = reg_ldb_loop;
0 commit comments