Skip to content

Commit ff65080

Browse files
committed
x64: conv: brdgmm: enable zps per group
1 parent cc35f50 commit ff65080

6 files changed

+128
-44
lines changed

src/common/primitive_attr_quant.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ struct zero_points_t : public c_compatible {
241241

242242
// arg-specific checks
243243
bool common(int arg) const { return get_mask(arg) == 0; }
244+
bool per_dim_1(int arg) const { return get_mask(arg) == 2; }
244245
bool has_default_values(int arg) const {
245246
return is_set(arg) == false && has_default_data_type(arg);
246247
}

src/cpu/x64/brgemm/brgemm.cpp

+19-7
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
205205
brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off;
206206
brgemm_p.a_zp_compensations = post_ops_data.a_zp_compensations;
207207
brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations;
208+
brgemm_p.a_zp_values = post_ops_data.a_zp_values;
208209
brgemm_p.c_zp_values = post_ops_data.c_zp_values;
209210
brgemm_p.ptr_dst_scales = post_ops_data.dst_scales;
210211
if (dynamic_values) {
@@ -458,19 +459,30 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg,
458459
auto zero_points = attr->zero_points_;
459460

460461
// common zero point type is supported for now
461-
if (!zero_points.common(mem_arg)) return status::unimplemented;
462+
const bool is_per_dim_1_bcast = zero_points.per_dim_1(mem_arg);
463+
const bool is_common_bcast = zero_points.common(mem_arg);
464+
if (!is_common_bcast && !is_per_dim_1_bcast)
465+
return status::unimplemented;
462466

463467
const bool skip_zero_point
464468
= mem_arg == DNNL_ARG_WEIGHTS && brg->skip_zp_b_compensation;
465-
zp_type = zero_points.has_default_values(mem_arg) || skip_zero_point
466-
? brgemm_broadcast_t::none
467-
: brgemm_broadcast_t::per_tensor;
469+
470+
zp_type = brgemm_broadcast_t::none;
471+
const bool is_any_bcast
472+
= !(zero_points.has_default_values(mem_arg) || skip_zero_point);
473+
if (is_any_bcast) {
474+
if (is_common_bcast)
475+
zp_type = brgemm_broadcast_t::per_tensor;
476+
else if (is_per_dim_1_bcast)
477+
zp_type = brgemm_broadcast_t::per_n;
478+
}
479+
468480
return status::success;
469481
};
470482

471-
init_zp_type(brg->zp_type_a, DNNL_ARG_SRC);
472-
init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS);
473-
init_zp_type(brg->zp_type_c, DNNL_ARG_DST);
483+
CHECK(init_zp_type(brg->zp_type_a, DNNL_ARG_SRC));
484+
CHECK(init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS));
485+
CHECK(init_zp_type(brg->zp_type_c, DNNL_ARG_DST));
474486

475487
// Post-ops may use vector registers so brgemm/brdgmm blocking may need to
476488
// be updated

src/cpu/x64/brgemm/brgemm_types.hpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ struct brgemm_kernel_params_t {
494494

495495
const void *a_zp_compensations = nullptr;
496496
const void *b_zp_compensations = nullptr;
497+
const void *a_zp_values = nullptr;
497498
const void *c_zp_values = nullptr;
498499
size_t skip_accm = 0;
499500
int32_t zp_a_val = 1;
@@ -604,7 +605,8 @@ struct brgemm_post_ops_data_t {
604605
const void *b_zp_compensations = nullptr,
605606
const void *c_zp_values = nullptr, bool skip_accumulation = false,
606607
int32_t zp_a_val = 1, bool do_only_comp = false,
607-
bool do_only_zp_a_val = false, const float *dst_scales = nullptr)
608+
bool do_only_zp_a_val = false, const float *dst_scales = nullptr,
609+
const void *a_zp_values = nullptr)
608610
: bias(bias)
609611
, scales(scales)
610612
, binary_post_ops_rhs(binary_post_ops_rhs)
@@ -619,7 +621,8 @@ struct brgemm_post_ops_data_t {
619621
, zp_a_val {zp_a_val}
620622
, do_only_comp {do_only_comp}
621623
, do_only_zp_a_val {do_only_zp_a_val}
622-
, dst_scales(dst_scales) {}
624+
, dst_scales(dst_scales)
625+
, a_zp_values(a_zp_values) {}
623626

624627
const void *bias = nullptr;
625628
const float *scales = nullptr;
@@ -636,6 +639,7 @@ struct brgemm_post_ops_data_t {
636639
const bool do_only_comp = false;
637640
const bool do_only_zp_a_val = false;
638641
const float *dst_scales = nullptr;
642+
const void *a_zp_values = nullptr;
639643
};
640644

641645
} // namespace x64

src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp

+90-30
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ jit_brdgmm_kernel_base_t<Wmm>::jit_brdgmm_kernel_base_t(
4545
, max_vmms_(isa_num_vregs(brg.isa_impl))
4646
, compute_dst_zp_(brg.zp_type_c != brgemm_broadcast_t::none)
4747
, compute_src_zp_(brg.zp_type_a != brgemm_broadcast_t::none)
48+
, is_src_zp_bcast_(brg.zp_type_a == brgemm_broadcast_t::per_tensor)
4849
, compute_compensation_(compute_src_zp_ || brg.req_s8s8_compensation)
4950
, has_vpad_(brg.brgattr.max_top_vpad > 0 || brg.brgattr.max_bottom_vpad > 0)
5051
, has_bpad_(brg.brgattr.max_top_bpad > 0 || brg.brgattr.max_bottom_bpad > 0)
@@ -147,7 +148,7 @@ void jit_brdgmm_kernel_base_t<Wmm>::read_params() {
147148
}
148149

149150
if (compute_src_zp_) {
150-
mov(reg_tmp, ptr[param1 + GET_OFF(zp_a_val)]);
151+
mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_values)]);
151152
mov(ptr[rsp + src_zp_value_], reg_tmp);
152153

153154
mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_compensations)]);
@@ -609,6 +610,17 @@ void jit_brdgmm_kernel_base_t<Wmm>::maybe_transpose_interleaved_vnni_to_plain(
609610
}
610611
}
611612

613+
template <typename Wmm>
614+
void jit_brdgmm_kernel_base_t<Wmm>::load_src_zp() {
615+
mov(reg_src_zero_point, ptr[rsp + src_zp_value_]);
616+
lea(reg_src_zero_point,
617+
is_src_zp_bcast_
618+
? ptr_b[reg_src_zero_point]
619+
: ptr[reg_src_zero_point + reg_aux_N * sizeof(int32_t)]);
620+
if (!is_superset(brg.isa_impl, avx512_core) && is_src_zp_bcast_)
621+
uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
622+
}
623+
612624
template <typename Wmm>
613625
void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
614626
int m_blocks, int n_blocks, bool has_n_tail) {
@@ -620,12 +632,10 @@ void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
620632
lea(reg_s8s8_comp, ptr[reg_s8s8_comp + reg_aux_N * sizeof(int32_t)]);
621633
}
622634
if (compute_src_zp_) {
623-
lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
635+
load_src_zp();
624636
mov(reg_zp_compensation, ptr[rsp + zp_compensation_]);
625637
lea(reg_zp_compensation,
626638
ptr[reg_zp_compensation + reg_aux_N * sizeof(int32_t)]);
627-
if (!is_superset(brg.isa_impl, avx512_core))
628-
uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
629639
}
630640

631641
for_(int v_i = 0; v_i < v_substep; ++v_i)
@@ -640,16 +650,35 @@ void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
640650
}
641651
if (compute_src_zp_) {
642652
// zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
643-
const Vmm vmm_zp = vmm_zp_comp();
644-
vmovups(vmm_zp,
645-
maybe_EVEX_compress_addr(reg_zp_compensation, offset));
646-
if (is_superset(brg.isa_impl, avx512_core)) {
647-
const bool src_zp_is_common = true;
648-
vpmulld(vmm_zp, vmm_zp,
649-
maybe_EVEX_compress_addr(
650-
reg_src_zero_point, 0, src_zp_is_common));
653+
const bool is_tail
654+
= n + 1 == n_blocks && has_n_tail && substep_simd < simd_w_;
655+
const Vmm vmm_zp = isa_has_masks(brg.isa_impl)
656+
? maybe_mask(vmm_zp_comp(), is_tail, false)
657+
: vmm_zp_comp();
658+
if (IMPLICATION(is_tail, isa_has_masks(brg.isa_impl))) {
659+
vmovups(vmm_zp,
660+
maybe_EVEX_compress_addr(reg_zp_compensation, offset));
661+
if (is_src_zp_bcast_) {
662+
if (is_superset(brg.isa_impl, avx512_core))
663+
vpmulld(vmm_zp, vmm_zp,
664+
maybe_EVEX_compress_addr(
665+
reg_src_zero_point, 0, true));
666+
else
667+
vpmulld(vmm_zp, vmm_zp, vmm_bcast());
668+
} else
669+
vpmulld(vmm_zp, vmm_zp,
670+
maybe_EVEX_compress_addr(
671+
reg_src_zero_point, offset));
651672
} else {
652-
vpmulld(vmm_zp, vmm_zp, vmm_bcast());
673+
const int tail_size = tail_length();
674+
const Vmm ymm_tmp
675+
= vmm_bcast(); // used for bcast or tail processing in avx2
676+
load_data(data_type::s32, vmm_zp,
677+
ptr[reg_zp_compensation + offset], tail_size);
678+
if (!is_src_zp_bcast_)
679+
load_data(data_type::s32, ymm_tmp,
680+
ptr[reg_src_zero_point + offset], tail_size);
681+
vpmulld(vmm_zp, vmm_zp, ymm_tmp);
653682
}
654683
}
655684
for (int m = 0; m < m_blocks; m++) {
@@ -795,24 +824,48 @@ void jit_brdgmm_kernel_base_t<Wmm>::load_b(
795824

796825
template <typename Wmm>
797826
void jit_brdgmm_kernel_base_t<Wmm>::comp_dot_product(
798-
compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb) {
827+
compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb, int n,
828+
bool is_tail_block) {
799829
switch (kernel_type) {
800830
case compute_pad_kernel_t::s8s8_kernel:
801831
vpdpbusd(vmm_acc, vmm_shift(), vmmb,
802832
is_superset(brg.isa_impl, avx512_core)
803833
? Xbyak::EvexEncoding
804834
: Xbyak::VexEncoding);
805835
break;
806-
case compute_pad_kernel_t::zero_point_kernel:
807-
if (is_superset(brg.isa_impl, avx512_core)) {
808-
vpmulld(vmm_zp_comp(), vmmb,
809-
maybe_EVEX_compress_addr(reg_src_zero_point, 0, true));
836+
case compute_pad_kernel_t::zero_point_kernel: {
837+
const Vmm vmm_zp = isa_has_masks(brg.isa_impl)
838+
? maybe_mask(vmm_zp_comp(), is_tail_block, false)
839+
: vmm_zp_comp();
840+
const size_t offset = comp_offset(n);
841+
if (IMPLICATION(is_tail_block, isa_has_masks(brg.isa_impl))) {
842+
if (is_src_zp_bcast_) {
843+
if (is_superset(brg.isa_impl, avx512_core))
844+
vpmulld(vmm_zp, vmmb,
845+
maybe_EVEX_compress_addr(
846+
reg_src_zero_point, 0, true));
847+
else
848+
vpmulld(vmm_zp, vmmb, vmm_bcast());
849+
} else {
850+
const Xbyak::Address src_zp_addr = maybe_EVEX_compress_addr(
851+
reg_src_zero_point, offset);
852+
if (is_fast_vnni_int8()) {
853+
vmovups(vmm_zp, src_zp_addr);
854+
vpermd(vmm_zp, vmm_permute(), vmm_zp);
855+
vpmulld(vmm_zp, vmmb, vmm_zp);
856+
} else
857+
vpmulld(vmm_zp, vmmb, src_zp_addr);
858+
}
810859
} else {
811-
uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
812-
vpmulld(vmm_zp_comp(), vmmb, vmm_bcast());
860+
const Vmm ymm_tmp
861+
= vmm_bcast(); // used for bcast or tail processing in avx2
862+
if (!is_src_zp_bcast_)
863+
load_data(data_type::s32, ymm_tmp,
864+
ptr[reg_src_zero_point + offset], tail_length());
865+
vpmulld(vmm_zp, vmmb, ymm_tmp);
813866
}
814867
vpaddd(vmm_acc, vmm_acc, vmm_zp_comp());
815-
break;
868+
} break;
816869
default: assert(!"unsupported comp_kernel type");
817870
}
818871
}
@@ -853,21 +906,25 @@ void jit_brdgmm_kernel_base_t<Wmm>::pad_comp_kernel(
853906

854907
for (int pad_i = max_m_unroll; pad_i > 0; --pad_i) {
855908
L(jmp_table_labels[pad_i]);
856-
if (is_zero_point_kernel)
857-
lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
909+
if (is_zero_point_kernel) load_src_zp();
858910
if (pad_i > m_blocks) continue;
859911
const int m_i = get_mi(pad_i);
860912
int p_b_i = 0;
861913
for (int n_i = 0; n_i < n_blocks; ++n_i, ++p_b_i) {
862-
if (get_substep_simd(n_i, 0, has_tail) <= 0) continue;
914+
const int substep_simd = get_substep_simd(n_i, 0, has_tail);
915+
if (substep_simd <= 0) continue;
863916
const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0);
917+
const bool is_tail_block
918+
= n_i + 1 == n_blocks && has_tail && substep_simd < simd_w_;
864919
if (p_b_i < n_preload_b_vmms) {
865-
comp_dot_product(kernel_type, vmm_acc, vmm_b(p_b_i));
920+
comp_dot_product(
921+
kernel_type, vmm_acc, vmm_b(p_b_i), n_i, is_tail_block);
866922
} else {
867923
// preloaded vmm_b not available
868924
const Vmm vmm_wei = vmm_b(max_bvmms - 1);
869925
load_b(vmm_wei, n_i, 0, has_tail, load_broadcast_wei);
870-
comp_dot_product(kernel_type, vmm_acc, vmm_wei);
926+
comp_dot_product(
927+
kernel_type, vmm_acc, vmm_wei, n_i, is_tail_block);
871928
}
872929
}
873930
}
@@ -885,8 +942,7 @@ void jit_brdgmm_kernel_base_t<Wmm>::batch_pad_kernel(
885942
auto kernel_body = [&](compute_pad_kernel_t kernel_type) {
886943
const bool is_zero_point_kernel
887944
= kernel_type == compute_pad_kernel_t::zero_point_kernel;
888-
if (is_zero_point_kernel)
889-
lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
945+
if (is_zero_point_kernel) load_src_zp();
890946
for (int nb_i = 0; nb_i < n_blocks; nb_i += max_bvmms) {
891947
const int n_e = nstl::min(nb_i + max_bvmms, n_blocks) - nb_i;
892948
for (int i = 0; i < n_e; ++i) {
@@ -898,9 +954,13 @@ void jit_brdgmm_kernel_base_t<Wmm>::batch_pad_kernel(
898954
for_(int m_i = 0; m_i < m_blocks; ++m_i)
899955
for (int i = 0; i < n_e; ++i) {
900956
const int n_i = nb_i + i;
901-
if (get_substep_simd(n_i, 0, has_tail) <= 0) continue;
957+
const int substep_simd = get_substep_simd(n_i, 0, has_tail);
958+
if (substep_simd <= 0) continue;
902959
const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0);
903-
comp_dot_product(kernel_type, vmm_acc, vmm_b(i));
960+
const bool is_tail_block
961+
= n_i + 1 == n_e && has_tail && substep_simd < simd_w_;
962+
comp_dot_product(
963+
kernel_type, vmm_acc, vmm_b(i), n_i, is_tail_block);
904964
}
905965
}
906966
};

src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
230230
const int simd_w_;
231231
const int max_vmms_;
232232
const bool compute_dst_zp_, compute_src_zp_;
233+
const bool is_src_zp_bcast_;
233234
const bool compute_compensation_; // code-path for either s8s8 or src_zp
234235
const bool has_vpad_; // vertical padding w.r.t. M dimension
235236
const bool has_bpad_; // batch pad is computed for the overlap between the
@@ -341,7 +342,8 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
341342
void load_b(
342343
Vmm vmmb, int n_i, int v_i, bool has_n_tail, bool wei_zp = false);
343344
void comp_dot_product(compute_pad_kernel_t kernel_type, Vmm vmm_acc,
344-
Vmm vmmb); // int8 compensation dot_product (zp and s8s8)
345+
Vmm vmmb, int n,
346+
bool is_tail_block); // int8 compensation dot_product (zp and s8s8)
345347
void pad_comp_kernel(compute_pad_kernel_t kernel_type, int m_blocks,
346348
int n_blocks, int padding, const Xbyak::Reg64 reg_pad,
347349
const std::function<int(int)> &get_mi, bool has_tail = false);
@@ -360,6 +362,7 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
360362
void apply_post_ops(int m_blocks, int n_blocks, bool has_n_tail);
361363
void maybe_transpose_interleaved_vnni_to_plain(
362364
int m_blocks, int n_blocks, bool has_n_tail);
365+
void load_src_zp();
363366
void compute_int8_compensation(int m_blocks, int n_blocks, bool has_n_tail);
364367
void store_accumulators(int m_blocks, int n_blocks, bool has_n_tail);
365368
void store_accumulators_without_post_ops(

src/cpu/x64/jit_brdgmm_dw_conv.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) {
255255
const bool params_ok
256256
= IMPLICATION(has_zero_points, utils::one_of(jcp.src_dt, u8, s8))
257257
&& IMPLICATION(jcp.src_zero_point,
258-
attr()->zero_points_.common(DNNL_ARG_SRC))
258+
attr()->zero_points_.common(DNNL_ARG_SRC)
259+
|| attr()->zero_points_.per_dim_1(DNNL_ARG_SRC))
259260
&& IMPLICATION(jcp.dst_zero_point,
260261
attr()->zero_points_.common(DNNL_ARG_DST));
261262
VDISPATCH_CONV(params_ok, VERBOSE_UNSUPPORTED_ZP_CFG);
@@ -583,7 +584,7 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
583584
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
584585
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
585586

586-
DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC);
587+
DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
587588
DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
588589

589590
const int wei_scale_mask
@@ -753,8 +754,11 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
753754
post_ops_data.scales = &oscales[jcp.is_oc_scale * ch];
754755
post_ops_data.oc_logical_off = ch;
755756
post_ops_data.dst_scales = dst_scales;
756-
post_ops_data.zp_a_val
757-
= jcp.src_zero_point ? src_zero_point : 1;
757+
const bool is_bcast_zp
758+
= pd()->attr()->zero_points_.common(DNNL_ARG_SRC);
759+
post_ops_data.a_zp_values = jcp.src_zero_point
760+
? src_zero_point + ch * !is_bcast_zp
761+
: nullptr;
758762
post_ops_data.c_zp_values
759763
= jcp.dst_zero_point ? dst_zero_point : nullptr;
760764
post_ops_data.a_zp_compensations

0 commit comments

Comments
 (0)