@@ -45,6 +45,7 @@ jit_brdgmm_kernel_base_t<Wmm>::jit_brdgmm_kernel_base_t(
     , max_vmms_(isa_num_vregs(brg.isa_impl))
     , compute_dst_zp_(brg.zp_type_c != brgemm_broadcast_t::none)
     , compute_src_zp_(brg.zp_type_a != brgemm_broadcast_t::none)
+    , is_src_zp_bcast_(brg.zp_type_a == brgemm_broadcast_t::per_tensor)
     , compute_compensation_(compute_src_zp_ || brg.req_s8s8_compensation)
     , has_vpad_(brg.brgattr.max_top_vpad > 0 || brg.brgattr.max_bottom_vpad > 0)
     , has_bpad_(brg.brgattr.max_top_bpad > 0 || brg.brgattr.max_bottom_bpad > 0)
@@ -147,7 +148,7 @@ void jit_brdgmm_kernel_base_t<Wmm>::read_params() {
     }
 
     if (compute_src_zp_) {
-        mov(reg_tmp, ptr[param1 + GET_OFF(zp_a_val)]);
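+        // a_zp_values holds a pointer to the zero-point value(s) (a single
+        // common value or a per-channel array) rather than an immediate value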
+        mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_values)]);
         mov(ptr[rsp + src_zp_value_], reg_tmp);
 
         mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_compensations)]);
@@ -609,6 +610,17 @@ void jit_brdgmm_kernel_base_t<Wmm>::maybe_transpose_interleaved_vnni_to_plain(
     }
 }
 
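+// Loads the pointer to the source zero-point value(s) saved on the stack.
+// For a broadcast (per-tensor) zero point the base address is used as is;
+// otherwise it is advanced to the current N block so per-channel values can
+// be addressed directly. On ISAs without EVEX embedded broadcast (e.g. AVX2)
+// the common value is pre-broadcast into vmm_bcast() once here.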
+template <typename Wmm>
+void jit_brdgmm_kernel_base_t<Wmm>::load_src_zp() {
+    mov(reg_src_zero_point, ptr[rsp + src_zp_value_]);
+    lea(reg_src_zero_point,
+            is_src_zp_bcast_
+                    ? ptr_b[reg_src_zero_point]
+                    : ptr[reg_src_zero_point + reg_aux_N * sizeof(int32_t)]);
+    if (!is_superset(brg.isa_impl, avx512_core) && is_src_zp_bcast_)
+        uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
+}
+
 template <typename Wmm>
 void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
         int m_blocks, int n_blocks, bool has_n_tail) {
@@ -620,12 +632,10 @@ void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
         lea(reg_s8s8_comp, ptr[reg_s8s8_comp + reg_aux_N * sizeof(int32_t)]);
     }
     if (compute_src_zp_) {
-        lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
+        load_src_zp();
         mov(reg_zp_compensation, ptr[rsp + zp_compensation_]);
         lea(reg_zp_compensation,
                 ptr[reg_zp_compensation + reg_aux_N * sizeof(int32_t)]);
-        if (!is_superset(brg.isa_impl, avx512_core))
-            uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
     }
 
     for_(int v_i = 0; v_i < v_substep; ++v_i)
@@ -640,16 +650,35 @@ void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
         }
         if (compute_src_zp_) {
             // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
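+            // With opmasks (AVX512) a partial n-block is handled via a
+            // masked vmm_zp; without them (AVX2) the tail falls through to
+            // the load_data() path below, which reads tail_length() elements.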
-            const Vmm vmm_zp = vmm_zp_comp();
-            vmovups(vmm_zp,
-                    maybe_EVEX_compress_addr(reg_zp_compensation, offset));
-            if (is_superset(brg.isa_impl, avx512_core)) {
-                const bool src_zp_is_common = true;
-                vpmulld(vmm_zp, vmm_zp,
-                        maybe_EVEX_compress_addr(
-                                reg_src_zero_point, 0, src_zp_is_common));
+            const bool is_tail
+                    = n + 1 == n_blocks && has_n_tail && substep_simd < simd_w_;
+            const Vmm vmm_zp = isa_has_masks(brg.isa_impl)
+                    ? maybe_mask(vmm_zp_comp(), is_tail, false)
+                    : vmm_zp_comp();
+            if (IMPLICATION(is_tail, isa_has_masks(brg.isa_impl))) {
+                vmovups(vmm_zp,
+                        maybe_EVEX_compress_addr(reg_zp_compensation, offset));
+                if (is_src_zp_bcast_) {
+                    if (is_superset(brg.isa_impl, avx512_core))
+                        vpmulld(vmm_zp, vmm_zp,
+                                maybe_EVEX_compress_addr(
+                                        reg_src_zero_point, 0, true));
+                    else
+                        vpmulld(vmm_zp, vmm_zp, vmm_bcast());
+                } else
+                    vpmulld(vmm_zp, vmm_zp,
+                            maybe_EVEX_compress_addr(
+                                    reg_src_zero_point, offset));
             } else {
-                vpmulld(vmm_zp, vmm_zp, vmm_bcast());
+                const int tail_size = tail_length();
+                const Vmm ymm_tmp
+                        = vmm_bcast(); // used for bcast or tail processing in avx2
+                load_data(data_type::s32, vmm_zp,
+                        ptr[reg_zp_compensation + offset], tail_size);
+                if (!is_src_zp_bcast_)
+                    load_data(data_type::s32, ymm_tmp,
+                            ptr[reg_src_zero_point + offset], tail_size);
+                vpmulld(vmm_zp, vmm_zp, ymm_tmp);
             }
         }
         for (int m = 0; m < m_blocks; m++) {
@@ -795,24 +824,48 @@ void jit_brdgmm_kernel_base_t<Wmm>::load_b(
 
 template <typename Wmm>
 void jit_brdgmm_kernel_base_t<Wmm>::comp_dot_product(
-        compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb) {
+        compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb, int n,
+        bool is_tail_block) {
     switch (kernel_type) {
         case compute_pad_kernel_t::s8s8_kernel:
             vpdpbusd(vmm_acc, vmm_shift(), vmmb,
                     is_superset(brg.isa_impl, avx512_core)
                             ? Xbyak::EvexEncoding
                             : Xbyak::VexEncoding);
            break;
-        case compute_pad_kernel_t::zero_point_kernel:
-            if (is_superset(brg.isa_impl, avx512_core)) {
-                vpmulld(vmm_zp_comp(), vmmb,
-                        maybe_EVEX_compress_addr(reg_src_zero_point, 0, true));
+        case compute_pad_kernel_t::zero_point_kernel: {
+            const Vmm vmm_zp = isa_has_masks(brg.isa_impl)
+                    ? maybe_mask(vmm_zp_comp(), is_tail_block, false)
+                    : vmm_zp_comp();
+            const size_t offset = comp_offset(n);
+            if (IMPLICATION(is_tail_block, isa_has_masks(brg.isa_impl))) {
+                if (is_src_zp_bcast_) {
+                    if (is_superset(brg.isa_impl, avx512_core))
+                        vpmulld(vmm_zp, vmmb,
+                                maybe_EVEX_compress_addr(
+                                        reg_src_zero_point, 0, true));
+                    else
+                        vpmulld(vmm_zp, vmmb, vmm_bcast());
+                } else {
+                    const Xbyak::Address src_zp_addr = maybe_EVEX_compress_addr(
+                            reg_src_zero_point, offset);
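+                    // fast-VNNI keeps vmmb in an interleaved layout, so
+                    // permute the per-channel zp values to match it before
+                    // the multiply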
+                    if (is_fast_vnni_int8()) {
+                        vmovups(vmm_zp, src_zp_addr);
+                        vpermd(vmm_zp, vmm_permute(), vmm_zp);
+                        vpmulld(vmm_zp, vmmb, vmm_zp);
+                    } else
+                        vpmulld(vmm_zp, vmmb, src_zp_addr);
+                }
             } else {
-                uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
-                vpmulld(vmm_zp_comp(), vmmb, vmm_bcast());
+                const Vmm ymm_tmp
+                        = vmm_bcast(); // used for bcast or tail processing in avx2
+                if (!is_src_zp_bcast_)
+                    load_data(data_type::s32, ymm_tmp,
+                            ptr[reg_src_zero_point + offset], tail_length());
+                vpmulld(vmm_zp, vmmb, ymm_tmp);
             }
             vpaddd(vmm_acc, vmm_acc, vmm_zp_comp());
-            break;
+        } break;
         default: assert(!"unsupported comp_kernel type");
     }
 }
@@ -853,21 +906,25 @@ void jit_brdgmm_kernel_base_t<Wmm>::pad_comp_kernel(
 
     for (int pad_i = max_m_unroll; pad_i > 0; --pad_i) {
         L(jmp_table_labels[pad_i]);
-        if (is_zero_point_kernel)
-            lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
+        if (is_zero_point_kernel) load_src_zp();
         if (pad_i > m_blocks) continue;
         const int m_i = get_mi(pad_i);
         int p_b_i = 0;
         for (int n_i = 0; n_i < n_blocks; ++n_i, ++p_b_i) {
-            if (get_substep_simd(n_i, 0, has_tail) <= 0) continue;
+            const int substep_simd = get_substep_simd(n_i, 0, has_tail);
+            if (substep_simd <= 0) continue;
             const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0);
918
+ = n_i + 1 == n_blocks && has_tail && substep_simd < simd_w_;
864
919
if (p_b_i < n_preload_b_vmms) {
865
- comp_dot_product (kernel_type, vmm_acc, vmm_b (p_b_i));
920
+ comp_dot_product (
921
+ kernel_type, vmm_acc, vmm_b (p_b_i), n_i, is_tail_block);
866
922
} else {
867
923
// preloaded vmm_b not available
868
924
const Vmm vmm_wei = vmm_b (max_bvmms - 1 );
869
925
load_b (vmm_wei, n_i, 0 , has_tail, load_broadcast_wei);
870
- comp_dot_product (kernel_type, vmm_acc, vmm_wei);
926
+ comp_dot_product (
927
+ kernel_type, vmm_acc, vmm_wei, n_i, is_tail_block);
871
928
}
872
929
}
873
930
}
@@ -885,8 +942,7 @@ void jit_brdgmm_kernel_base_t<Wmm>::batch_pad_kernel(
     auto kernel_body = [&](compute_pad_kernel_t kernel_type) {
         const bool is_zero_point_kernel
                 = kernel_type == compute_pad_kernel_t::zero_point_kernel;
-        if (is_zero_point_kernel)
-            lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
+        if (is_zero_point_kernel) load_src_zp();
         for (int nb_i = 0; nb_i < n_blocks; nb_i += max_bvmms) {
             const int n_e = nstl::min(nb_i + max_bvmms, n_blocks) - nb_i;
             for (int i = 0; i < n_e; ++i) {
@@ -898,9 +954,13 @@ void jit_brdgmm_kernel_base_t<Wmm>::batch_pad_kernel(
         for_(int m_i = 0; m_i < m_blocks; ++m_i)
         for (int i = 0; i < n_e; ++i) {
             const int n_i = nb_i + i;
-            if (get_substep_simd(n_i, 0, has_tail) <= 0) continue;
+            const int substep_simd = get_substep_simd(n_i, 0, has_tail);
+            if (substep_simd <= 0) continue;
             const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0);
-            comp_dot_product(kernel_type, vmm_acc, vmm_b(i));
+            const bool is_tail_block
+                    = n_i + 1 == n_e && has_tail && substep_simd < simd_w_;
+            comp_dot_product(
+                    kernel_type, vmm_acc, vmm_b(i), n_i, is_tail_block);
         }
     }
 };