@@ -320,17 +320,13 @@ struct jit_brgemm_kernel_t : public jit_generator {
320
320
used_vregs += 1 ;
321
321
}
322
322
323
- if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0 ) {
323
+ if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0 && !brg. with_src_dyn_quant ) {
324
324
used_vregs += 1 ;
325
325
}
326
326
327
327
if (brg.with_src_dyn_quant ) {
328
328
used_vregs += 1 ;
329
329
}
330
-
331
- if (brg.with_src_dyn_quant && brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0 ) {
332
- used_vregs += brg.ld_block2 ;
333
- }
334
330
return isa_num_vregs (brg.isa_impl ) - used_vregs;
335
331
}
336
332
@@ -2306,11 +2302,6 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2306
2302
if (brg.req_s8s8_compensation ) uni_vpaddb (v1, v1, vmm_inp_shift ());
2307
2303
};
2308
2304
2309
- auto vmm_zero_point = [&](int ld) {
2310
- int idx = isa_num_vregs (brg.isa_impl ) - 2 - ld;
2311
- return Vmm (idx);
2312
- };
2313
-
2314
2305
static const int8_t mask_low_half[64 ] = {
2315
2306
0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F ,
2316
2307
0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F ,
@@ -2321,30 +2312,11 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2321
2312
mov (ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
2322
2313
mov (ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
2323
2314
2324
- auto reg_local_wei_scales = reg_bdb_loop;
2325
- auto reg_local_wei_zp = reg_ldb_loop;
2326
- auto reg_ptr = reg_local_wei_scales;
2327
-
2328
- if (brg.with_wei_decomp_zero_points ) {
2329
- mov (reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_]);
2330
- if (brg.wei_decomp_zero_points_stride == 0 ) {
2331
- auto reg_ptr_32 = Reg32 (reg_ptr.getIdx ());
2332
- movzx (reg_ptr_32, ptr[reg_local_wei_zp]);
2333
- uni_vmovq (Xmm (vmm_zero_point (0 ).getIdx ()), reg_ptr);
2334
- uni_vbroadcastss (vmm_zero_point (0 ), Xmm (vmm_zero_point (0 ).getIdx ()));
2335
- } else {
2336
- for (int ld = 0 ; ld < ld_block2; ld++) {
2337
- uni_vpmovzxbd (vmm_zero_point (ld), ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size (brg.wei_decomp_zero_points_dt )]);
2338
- }
2339
- }
2340
- }
2341
-
2315
+ auto reg_ptr = reg_bdb_loop;
2342
2316
auto vmm_mask_low_half = Vmm (isa_num_vregs (brg.isa_impl ) - 1 );
2343
2317
mov (reg_ptr, (size_t )mask_low_half);
2344
2318
uni_vmovups (vmm_mask_low_half, ptr[reg_ptr]);
2345
2319
2346
- mov (reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]);
2347
-
2348
2320
const int vec_size = vreg_traits<Vmm>::vlen;
2349
2321
auto accums_stack_space = bd_e * ld_block2 * vec_size;
2350
2322
sub (rsp, accums_stack_space);
@@ -2397,42 +2369,68 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2397
2369
}
2398
2370
}
2399
2371
2400
- auto reg_local_src_scales = reg_local_wei_zp;
2401
- auto reg_local_src_grouped_sum = reg_local_wei_zp;
2402
- auto vmm_src_scales = bcst ();
2372
+ auto vmm_zero_point = [&](int ld) {
2373
+ return load (ld);
2374
+ };
2375
+
2376
+ auto reg_local_wei_zp = reg_ldb_loop;
2377
+ auto reg_local_src_grouped_sum = reg_bdb_loop;
2378
+ auto vmm_tmp = Vmm (isa_num_vregs (brg.isa_impl ) - 1 );
2403
2379
auto vmm_src_grouped_sum = bcst ();
2404
2380
2405
2381
if (brg.with_wei_decomp_zero_points ) {
2382
+ mov (reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_ + accums_stack_space]);
2383
+ if (brg.wei_decomp_zero_points_stride == 0 ) {
2384
+ Vmm vmm_zp = vmm_zero_point (0 );
2385
+ auto reg_ptr_32 = Reg32 (reg_ptr.getIdx ());
2386
+ movzx (reg_ptr_32, ptr[reg_local_wei_zp]);
2387
+ uni_vmovq (Xmm (vmm_zp.getIdx ()), reg_ptr);
2388
+ uni_vbroadcastss (vmm_zp, Xmm (vmm_zp.getIdx ()));
2389
+ }
2390
+
2406
2391
mov (reg_local_src_grouped_sum, ptr[rsp + reg_aux2_src_grouped_sum_offs_ + accums_stack_space]);
2407
2392
for (int bd = bd_b; bd < bd_e; bd++) {
2393
+ uni_vbroadcastss (vmm_src_grouped_sum, ptr[reg_local_src_grouped_sum + bd * brg.src_grouped_sum_stride * sizeof (int32_t )]);
2408
2394
for (int ld = 0 ; ld < ld_block2; ld++) {
2409
- auto vmm_accm = accm (ld_block2, bd, ld);
2410
2395
Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point (0 ) : vmm_zero_point (ld);
2411
- uni_vbroadcastss (vmm_src_grouped_sum, ptr[reg_local_src_grouped_sum + bd * brg.src_grouped_sum_stride * sizeof (int32_t )]);
2412
- uni_vpmulld (vmm_src_grouped_sum, vmm_src_grouped_sum, vmm_zp);
2413
- uni_vpsubd (vmm_accm, vmm_accm, vmm_src_grouped_sum);
2396
+ if (bd == bd_b && brg.wei_decomp_zero_points_stride != 0 ) {
2397
+ uni_vpmovzxbd (vmm_zp, ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size (brg.wei_decomp_zero_points_dt )]);
2398
+ }
2399
+
2400
+ auto vmm_accm = accm (ld_block2, bd, ld);
2401
+ uni_vpmulld (vmm_tmp, vmm_src_grouped_sum, vmm_zp);
2402
+ uni_vpsubd (vmm_accm, vmm_accm, vmm_tmp);
2414
2403
}
2415
2404
}
2416
2405
}
2417
2406
2407
+ auto wei_scale = [&](int ld) {
2408
+ return load (ld);
2409
+ };
2410
+
2411
+ auto reg_local_src_scales = reg_ldb_loop;
2412
+ auto reg_local_wei_scales = reg_bdb_loop;
2413
+ auto vmm_src_scales = bcst ();
2414
+
2415
+ mov (reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_ + accums_stack_space]);
2418
2416
mov (reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);
2417
+ if (brg.wei_decomp_scales_stride == 0 ) {
2418
+ uni_vbroadcastss (wei_scale (0 ), ptr[reg_local_wei_scales]);
2419
+ }
2420
+
2419
2421
for (int bd = bd_b; bd < bd_e; bd++) {
2420
2422
uni_vbroadcastss (vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof (float )]);
2421
2423
for (int ld = 0 ; ld < ld_block2; ld++) {
2422
- if (brg.wei_decomp_scales_stride == 0 ) {
2423
- uni_vbroadcastss (load (ld), ptr[reg_local_wei_scales]);
2424
- } else {
2425
- uni_vmovups (load (ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof (float )]);
2424
+ auto vmm_wei_scale = brg.wei_decomp_scales_stride == 0 ? wei_scale (0 ) : wei_scale (ld);
2425
+ if (bd == bd_b && brg.wei_decomp_scales_stride != 0 ) {
2426
+ uni_vmovups (vmm_wei_scale, ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof (float )]);
2426
2427
}
2427
- }
2428
- for (int ld = 0 ; ld < ld_block2; ld++) {
2429
- auto vmm_accm = accm (ld_block2, bd, ld);
2430
2428
2429
+ auto vmm_accm = accm (ld_block2, bd, ld);
2431
2430
uni_vcvtdq2ps (vmm_accm, vmm_accm);
2432
- uni_vmulps (vmm_accm, vmm_accm, vmm_src_scales);
2433
- uni_vmulps (load (ld), vmm_accm, load (ld));
2431
+ uni_vmulps (vmm_tmp, vmm_accm, vmm_src_scales);
2434
2432
uni_vmovups (vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]);
2435
- uni_vaddps (vmm_accm, vmm_accm, load (ld) );
2433
+ uni_vfmadd231ps (vmm_accm, vmm_tmp, vmm_wei_scale );
2436
2434
}
2437
2435
}
2438
2436
0 commit comments