Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FORK][FEATURE] DQ IP: performance enhancements #272

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/common/memory_tracking.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ enum {
key_decompression_zero_points,
key_src_quantized,
key_src_dequantized_scales,
key_src_grouped_sum,
// These two keys should always be the last ones,
// even though they are not in alphabetical order
key_nested,
Expand Down
22 changes: 18 additions & 4 deletions src/cpu/x64/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ void brgemm_desc_t::cleanup_dst_md() {
void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
const brgemm_batch_element_t *batch, void *ptr_C, void *scratch,
const brgemm_dynamic_values_t *dynamic_values,
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
brgemm_kernel_params_t brgemm_p;

brgemm_p.batch = batch;
Expand All @@ -105,6 +106,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.ptr_wei_scales = ptr_wei_scales;
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
brgemm_p.ptr_src_scales = ptr_src_scales;
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
brgemm_p.ic = ic;

assert(brg_kernel);
Expand All @@ -116,7 +118,8 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
const void *addr_A, const void *addr_B,
const brgemm_batch_element_t *batch, void *ptr_C, void *scratch,
const brgemm_dynamic_values_t *dynamic_values,
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
brgemm_kernel_params_t brgemm_p;

brgemm_p.batch = batch;
Expand All @@ -133,6 +136,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.ptr_wei_scales = ptr_wei_scales;
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
brgemm_p.ptr_src_scales = ptr_src_scales;
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
brgemm_p.ic = ic;
if (dynamic_values) {
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
Expand All @@ -148,7 +152,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
const brgemm_post_ops_data_t &post_ops_data, void *scratch,
const brgemm_dynamic_values_t *dynamic_values,
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
brgemm_kernel_params_t brgemm_p;

brgemm_p.batch = batch;
Expand Down Expand Up @@ -178,6 +183,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.ptr_wei_scales = ptr_wei_scales;
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
brgemm_p.ptr_src_scales = ptr_src_scales;
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
brgemm_p.ic = ic;
if (dynamic_values) {
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
Expand All @@ -194,7 +200,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
const brgemm_post_ops_data_t &post_ops_data, void *scratch,
const brgemm_dynamic_values_t *dynamic_values,
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
brgemm_kernel_params_t brgemm_p;

brgemm_p.batch = batch;
Expand Down Expand Up @@ -224,6 +231,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.ptr_wei_scales = ptr_wei_scales;
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
brgemm_p.ptr_src_scales = ptr_src_scales;
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
brgemm_p.ic = ic;
if (dynamic_values) {
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
Expand Down Expand Up @@ -318,6 +326,12 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa,

CHECK(brgemm_blocking(brg));

brg->src_sum_group_size = wei_d.dims()[1];
if (brg->with_src_dyn_quant) {
brg->src_sum_group_size = brg->rd_block;
brg->src_grouped_sum_stride = div_up(wei_d.dims()[1], brg->src_sum_group_size);
}

// avx2_vnni_2 kernel with xf16 data type requires blocked weights.
if (brg->isa_impl == avx2_vnni_2 && brg->is_xf16()
&& brg->LDB % brg->ld_block > 0)
Expand Down
8 changes: 4 additions & 4 deletions src/cpu/x64/brgemm/brgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ void DNNL_API brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
void *scratch = nullptr,
const brgemm_dynamic_values_t *dynamic_values = nullptr,
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
const void *ptr_src_scales = nullptr, size_t ic = 0);
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);

/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
///
Expand Down Expand Up @@ -205,7 +205,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
void *scratch = nullptr,
const brgemm_dynamic_values_t *dynamic_values = nullptr,
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
const void *ptr_src_scales = nullptr, size_t ic = 0);
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);

/// Execute BRGEMM kernel (brgemm_addr version)
///
Expand Down Expand Up @@ -234,7 +234,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr,
const brgemm_dynamic_values_t *dynamic_values = nullptr,
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
const void *ptr_src_scales = nullptr, size_t ic = 0);
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);

/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
///
Expand Down Expand Up @@ -267,7 +267,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr,
const brgemm_dynamic_values_t *dynamic_values = nullptr,
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
const void *ptr_src_scales = nullptr, size_t ic = 0);
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);

/// AMX utilities: Creates a palette based on BRGEMM descriptor
///
Expand Down
3 changes: 3 additions & 0 deletions src/cpu/x64/brgemm/brgemm_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ struct brgemm_desc_t {
bool with_src_dyn_quant = false;
int src_scales_group_size = 0;
int src_scales_stride = 0;
int src_sum_group_size = 0;
int src_grouped_sum_stride = 0;

bool is_row_major() const {
assert(layout != brgemm_layout_undef);
Expand Down Expand Up @@ -500,6 +502,7 @@ struct brgemm_kernel_params_t {
const void *ptr_wei_scales = nullptr;
const void *ptr_wei_zero_points = nullptr;
const void *ptr_src_scales = nullptr;
const void *ptr_src_grouped_sum = nullptr;
size_t ic;
dim_t dynamic_LDA = 0;
dim_t dynamic_LDB = 0;
Expand Down
19 changes: 11 additions & 8 deletions src/cpu/x64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,10 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
if (one_of(brg->dt_b, data_type::nf4) && brg->isa_impl == avx2) max_bcast_block -= 5;
if (one_of(brg->dt_b, data_type::f4_e2m1) && brg->isa_impl == avx2) max_bcast_block -= 2;
if (one_of(brg->dt_b, data_type::nf4, data_type::f4_e2m1) && brg->isa_impl != avx2) max_bcast_block -= 1;
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0) max_bcast_block -= 1;
if (brg->with_src_dyn_quant) max_bcast_block -= 2;
if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2;
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0 && !brg->with_src_dyn_quant) max_bcast_block -= 1;
if (brg->with_src_dyn_quant) max_bcast_block -= 1;

max_bcast_block /= adj_ld_block2;
if (brg->with_src_dyn_quant) {
max_bcast_block /= 2;
}

return max_bcast_block;
}
Expand Down Expand Up @@ -301,15 +297,22 @@ status_t brgemm_blocking(brgemm_desc_t *brg) {
= (brg->is_f16 && brg->isa_impl == avx512_core_fp16)
? 1
: data_type_vnni_granularity(brg->dt_a);

int rd_unroll = one_of(brg->dt_b, data_type::nf4, data_type::u4, data_type::s4, data_type::f4_e2m1) ? 32 : 4;
if (brg->with_grouped_wei_decomp) {
if (brg->with_grouped_wei_decomp && !brg->with_src_dyn_quant) {
auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size);
min_group_size = nstl::min(min_group_size, brg->src_scales_group_size);
rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity);
rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity);
brg->rd_block = rd_unroll * vnni_granularity;
} else if (brg->with_src_dyn_quant) {
brg->rd_block = brg->src_scales_group_size;
auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size);
brg->rd_block = nstl::min(brg->rd_block, min_group_size);
} else {
brg->rd_block = rd_unroll * vnni_granularity;
}

brg->rd_block = rd_unroll * vnni_granularity;
brg->rdb = brg->reduce_dim / brg->rd_block;
brg->rdb_tail = brg->reduce_dim % brg->rd_block;

Expand Down
Loading