Skip to content

Commit 77250c1

Browse files
[FORK][FEATURE] DQ IP: optimize pointer arithmetic
1 parent a870aae commit 77250c1

File tree

2 files changed

+163
-58
lines changed

2 files changed

+163
-58
lines changed

src/cpu/x64/brgemm/jit_brgemm_kernel.cpp

+147-57
Original file line numberDiff line numberDiff line change
@@ -2997,6 +2997,95 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
29972997
int ld_block2, int ldb_loop_length, bool is_reg_tail, bool is_ld_tail,
29982998
bool check_top_vpad, bool check_bottom_vpad, int rows_for_rd_tail,
29992999
bool skip_accumulation) {
3000+
auto ic_group_shift_generic = [&]() {
3001+
if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg.wei_decomp_zero_points_stride != 0))
3002+
|| brg.with_src_dyn_quant) {
3003+
auto reg_local_ic = reg_aux_D;
3004+
auto reg_local_wei_params = reg_bdb_loop;
3005+
auto reg_local_ic_group = reg_ldb_loop;
3006+
3007+
auto ic_group_shift = [&](int src_offs, int dst_offs, int group_size, int stride) {
3008+
mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
3009+
mov(reg_local_ic_group, group_size);
3010+
xor_(rdx, rdx);
3011+
idiv(reg_local_ic_group);
3012+
imul(reg_local_ic, reg_local_ic, stride);
3013+
3014+
mov(reg_local_wei_params, ptr[rsp + src_offs]);
3015+
add(reg_local_wei_params, reg_local_ic);
3016+
mov(ptr[rsp + dst_offs], reg_local_wei_params);
3017+
};
3018+
3019+
mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
3020+
mov(ptr[rsp + reg_aux2_D_offs_], reg_aux_D);
3021+
mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
3022+
mov(ptr[rsp + reg_reg_a_offset_offs_], reg_a_offset); // preserve rdx for idiv
3023+
3024+
if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) {
3025+
ic_group_shift(reg_aux_wei_scales_offs_, reg_aux2_wei_scales_offs_,
3026+
brg.wei_decomp_scales_group_size, brg.wei_decomp_scales_stride * types::data_type_size(brg.wei_decomp_scales_dt));
3027+
}
3028+
3029+
if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) {
3030+
ic_group_shift(reg_aux_wei_zero_points_offs_, reg_aux2_wei_zero_points_offs_,
3031+
brg.wei_decomp_zero_points_group_size, brg.wei_decomp_zero_points_stride * types::data_type_size(brg.wei_decomp_zero_points_dt));
3032+
}
3033+
3034+
if (brg.with_src_dyn_quant) {
3035+
ic_group_shift(reg_aux_src_scales_offs_, reg_aux2_src_scales_offs_,
3036+
brg.src_scales_group_size, sizeof(float));
3037+
3038+
if (brg.with_wei_decomp_zero_points) {
3039+
ic_group_shift(reg_aux_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_,
3040+
brg.src_sum_group_size, sizeof(int32_t));
3041+
}
3042+
}
3043+
3044+
mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
3045+
add(reg_local_ic, brg.rd_block);
3046+
mov(ptr[rsp + reg_aux_ic_offs_], reg_local_ic);
3047+
3048+
mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
3049+
mov(reg_aux_D, ptr[rsp + reg_aux2_D_offs_]);
3050+
mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
3051+
mov(reg_a_offset, ptr[rsp + reg_reg_a_offset_offs_]);
3052+
}
3053+
};
3054+
3055+
auto ic_group_shift_opt = [&](int rb) {
3056+
mov(ptr[rsp + reg_bdb_loop_offs_], reg_rdb_loop);
3057+
auto reg_ptr = reg_rdb_loop;
3058+
3059+
auto ic_group_shift = [&](int src_offs, int dst_offs, int group_size, int stride) {
3060+
if ((rb + 1) * brg.rd_block % group_size == 0) {
3061+
mov(reg_ptr, ptr[rsp + src_offs]);
3062+
add(reg_ptr, stride);
3063+
mov(ptr[rsp + dst_offs], reg_ptr);
3064+
}
3065+
};
3066+
3067+
if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) {
3068+
ic_group_shift(reg_aux2_wei_scales_offs_, reg_aux2_wei_scales_offs_,
3069+
brg.wei_decomp_scales_group_size, brg.wei_decomp_scales_stride * types::data_type_size(brg.wei_decomp_scales_dt));
3070+
}
3071+
3072+
if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) {
3073+
ic_group_shift(reg_aux2_wei_zero_points_offs_, reg_aux2_wei_zero_points_offs_,
3074+
brg.wei_decomp_zero_points_group_size, brg.wei_decomp_zero_points_stride * types::data_type_size(brg.wei_decomp_zero_points_dt));
3075+
}
3076+
3077+
if (brg.with_src_dyn_quant) {
3078+
ic_group_shift(reg_aux2_src_scales_offs_, reg_aux2_src_scales_offs_,
3079+
brg.src_scales_group_size, sizeof(float));
3080+
3081+
if (brg.with_wei_decomp_zero_points) {
3082+
ic_group_shift(reg_aux2_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_,
3083+
brg.src_sum_group_size, sizeof(int32_t));
3084+
}
3085+
}
3086+
3087+
mov(reg_rdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
3088+
};
30003089

30013090
Label ldb_loop_label;
30023091
Label BS_loop_label;
@@ -3023,75 +3112,76 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
30233112
gemm_microkernel_amx(
30243113
bd_block2, is_bdb_tail, ld_block2, is_rd_tail, is_ld_tail);
30253114
} else {
3026-
if (brg.rdb > 0) {
3027-
Label rdb_loop_label;
3028-
mov(reg_rdb_loop, brg.rdb);
3029-
L_aligned(rdb_loop_label, 64);
3030-
{
3031-
if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg.wei_decomp_zero_points_stride != 0))
3032-
|| brg.with_src_dyn_quant) {
3033-
auto reg_local_ic = reg_aux_D;
3034-
auto reg_local_wei_params = reg_bdb_loop;
3035-
auto reg_local_ic_group = reg_ldb_loop;
3036-
3037-
auto ic_group_shift = [&](int src_offs, int dst_offs, int group_size, int stride) {
3038-
mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
3039-
mov(reg_local_ic_group, group_size);
3040-
xor_(rdx, rdx);
3041-
idiv(reg_local_ic_group);
3042-
imul(reg_local_ic, reg_local_ic, stride);
3043-
3044-
mov(reg_local_wei_params, ptr[rsp + src_offs]);
3045-
add(reg_local_wei_params, reg_local_ic);
3046-
mov(ptr[rsp + dst_offs], reg_local_wei_params);
3047-
};
3048-
3049-
mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
3050-
mov(ptr[rsp + reg_aux2_D_offs_], reg_aux_D);
3051-
mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
3052-
mov(ptr[rsp + reg_reg_a_offset_offs_], reg_a_offset); // preserve rdx for idiv
3053-
3054-
if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) {
3055-
ic_group_shift(reg_aux_wei_scales_offs_, reg_aux2_wei_scales_offs_,
3056-
brg.wei_decomp_scales_group_size, brg.wei_decomp_scales_stride * types::data_type_size(brg.wei_decomp_scales_dt));
3057-
}
3115+
ic_group_shift_generic();
3116+
3117+
if (brg.with_src_dyn_quant) {
3118+
auto rdb_group = brg.rd_block;
3119+
auto rd_size = brg.rdb * brg.rd_block + brg.rdb_tail;
3120+
if (brg.wei_decomp_scales_group_size < rd_size)
3121+
rdb_group = nstl::max(rdb_group, brg.wei_decomp_scales_group_size);
3122+
if (brg.wei_decomp_zero_points_group_size < rd_size)
3123+
rdb_group = nstl::max(rdb_group, brg.wei_decomp_zero_points_group_size);
3124+
if (brg.with_src_dyn_quant) {
3125+
rdb_group = nstl::max(rdb_group, brg.src_scales_group_size);
3126+
if (brg.with_wei_decomp_zero_points) {
3127+
rdb_group = nstl::max(rdb_group, brg.src_sum_group_size);
3128+
}
3129+
}
3130+
rdb_group = rdb_group / brg.rd_block;
3131+
auto rbd_blocks = brg.rdb / rdb_group;
30583132

3059-
if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) {
3060-
ic_group_shift(reg_aux_wei_zero_points_offs_, reg_aux2_wei_zero_points_offs_,
3061-
brg.wei_decomp_zero_points_group_size, brg.wei_decomp_zero_points_stride * types::data_type_size(brg.wei_decomp_zero_points_dt));
3062-
}
3133+
if (rbd_blocks > 0) {
3134+
Label rdb_loop_label;
3135+
mov(reg_rdb_loop, rbd_blocks);
3136+
L_aligned(rdb_loop_label, 64);
3137+
{
3138+
for (int rb = 0; rb < rdb_group; rb++) {
3139+
gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, false,
3140+
is_ld_tail, vpad, rows_for_rd_tail);
30633141

3064-
if (brg.with_src_dyn_quant) {
3065-
ic_group_shift(reg_aux_src_scales_offs_, reg_aux2_src_scales_offs_,
3066-
brg.src_scales_group_size, sizeof(float));
3142+
add(reg_aux_A, rdb_A_offset());
3143+
add(reg_aux_B, rdb_B_offset());
30673144

3068-
if (brg.with_wei_decomp_zero_points) {
3069-
ic_group_shift(reg_aux_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_,
3070-
brg.src_sum_group_size, sizeof(int32_t));
3071-
}
3145+
ic_group_shift_opt(rb);
30723146
}
30733147

3074-
mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
3075-
add(reg_local_ic, brg.rd_block);
3076-
mov(ptr[rsp + reg_aux_ic_offs_], reg_local_ic);
3077-
3078-
mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
3079-
mov(reg_aux_D, ptr[rsp + reg_aux2_D_offs_]);
3080-
mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
3081-
mov(reg_a_offset, ptr[rsp + reg_reg_a_offset_offs_]);
3148+
dec(reg_rdb_loop);
3149+
cmp(reg_rdb_loop, 0);
30823150
}
3151+
jg(rdb_loop_label, T_NEAR);
3152+
}
30833153

3084-
const bool is_rd_tail = false;
3085-
gemm_microkernel(bd_block2, is_bdb_tail, ld_block2,
3086-
is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail);
3154+
for (int rb = rbd_blocks * rdb_group; rb < brg.rdb; rb++) {
3155+
gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, false,
3156+
is_ld_tail, vpad, rows_for_rd_tail);
30873157

30883158
add(reg_aux_A, rdb_A_offset());
30893159
add(reg_aux_B, rdb_B_offset());
30903160

3091-
dec(reg_rdb_loop);
3092-
cmp(reg_rdb_loop, 0);
3161+
ic_group_shift_opt(rb);
3162+
3163+
mov(reg_rdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
3164+
}
3165+
} else {
3166+
if (brg.rdb > 0) {
3167+
Label rdb_loop_label;
3168+
mov(reg_rdb_loop, brg.rdb);
3169+
L_aligned(rdb_loop_label, 64);
3170+
{
3171+
const bool is_rd_tail = false;
3172+
gemm_microkernel(bd_block2, is_bdb_tail, ld_block2,
3173+
is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail);
3174+
3175+
add(reg_aux_A, rdb_A_offset());
3176+
add(reg_aux_B, rdb_B_offset());
3177+
3178+
ic_group_shift_generic();
3179+
3180+
dec(reg_rdb_loop);
3181+
cmp(reg_rdb_loop, 0);
3182+
}
3183+
jg(rdb_loop_label, T_NEAR);
30933184
}
3094-
jg(rdb_loop_label, T_NEAR);
30953185
}
30963186
}
30973187
if (brg.rdb_tail != 0) {

src/cpu/x64/jit_brgemm_inner_product_utils.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,22 @@ status_t jit_brgemm_ip_fwd_conf_t::init_conf(cpu_isa_t isa,
699699

700700
// Current implementation of grouped weights decompression algorithm requires K size to be aligned on group size.
701701
// Besides that "batched" usage of brgemm block is not covered, so forcing the value to 1.
702-
if (jbgp.with_grouped_weights_decompression || jbgp.with_src_dynamic_quant) {
702+
if (jbgp.with_src_dynamic_quant) {
703+
size_t max_ic_group_size = k_blk;
704+
if (jbgp.wei_scales_ic_group_size != static_cast<size_t>(jbgp.ic))
705+
max_ic_group_size = std::max(max_ic_group_size, jbgp.wei_scales_ic_group_size);
706+
if (jbgp.wei_zero_points_ic_group_size != static_cast<size_t>(jbgp.ic))
707+
max_ic_group_size = std::max(max_ic_group_size, jbgp.wei_zero_points_ic_group_size);
708+
max_ic_group_size = std::max(max_ic_group_size, jbgp.src_quant_group_size);
709+
max_ic_group_size = std::max(max_ic_group_size, jbgp.src_sum_group_size);
710+
711+
if ((jbgp.nb_ic_blocking * k_blk) % max_ic_group_size != 0) {
712+
jbgp.nb_ic_blocking = max_ic_group_size;
713+
}
714+
jbgp.K = k_blk * jbgp.nb_ic_blocking;
715+
jbgp.gemm_batch_size = 1;
716+
jbgp.nthr_ic_b = 1;
717+
} else if (jbgp.with_grouped_weights_decompression) {
703718
auto min_ic_group_size = std::min(jbgp.wei_scales_ic_group_size, jbgp.wei_zero_points_ic_group_size);
704719
min_ic_group_size = std::min(min_ic_group_size, jbgp.src_quant_group_size);
705720
if ((jbgp.nb_ic_blocking * k_blk) % min_ic_group_size != 0) {

0 commit comments

Comments
 (0)