@@ -767,10 +767,38 @@ void jit_brgemm_kernel_t::read_params() {
void jit_brgemm_kernel_t::zero_accumulators(int bd_block2, bool is_bdb_tail,
        int ld_block2, bool is_ld_tail, bool skip_accumulation) {
    int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
+    const bool need_to_apply_beta = brg.beta != 0.f;
    for_(int bd = 0; bd < bd_block; bd++)
    for (int ld = 0; ld < ld_block2; ld++) {
        auto zmm = accm(ld_block2, bd, ld);
-        eor(zmm.d, zmm.d, zmm.d);
+        // This part is moved here from the apply_alpha_beta function so that
+        // the fadd instruction can be avoided. It is also required only when K is blocked.
+        if (need_to_apply_beta) {
+            const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
+            const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
+
+            const int offset = C_offset(bd, ld);
+
+            int base_offset = 0;
+            auto x_addr = reg_aux_C;
+
+            if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) {
+                add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0);
+                base_offset = offset;
+                x_addr = reg_tmp_;
+            }
+            LD_MUL_VL(ld1w, zmm.s, k_mask, x_addr, offset - base_offset, 4);
+
+            const bool need_init_beta_vmm = brg.beta != 1.f;
+            auto vmm_beta = z_tail_mask();
+            if (need_init_beta_vmm) {
+                auto wreg_tmp = WReg(reg_tmp_gpr.getIdx());
+                mov_imm(wreg_tmp, float2int(static_cast<float>(brg.beta)));
+                dup(vmm_beta.s, wreg_tmp);
+                fmul(zmm.s, zmm.s, vmm_beta.s);
+            }
+        } else
+            eor(zmm.d, zmm.d, zmm.d);
    }
}
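What the new initialization computes, in scalar form: the sketch below is illustrative only and uses hypothetical names (acc, c_prev, n); the kernel itself works on SVE registers through the eor, ld1w, dup, and fmul instructions shown in the hunk above.

    #include <cstddef>

    // Scalar sketch (illustrative only) of how zero_accumulators now seeds the
    // accumulator tile: instead of always zeroing and folding in beta * C later,
    // it loads beta * C up front, so apply_alpha_beta needs no extra fadd.
    static void init_accumulators(
            float *acc, const float *c_prev, std::size_t n, float beta) {
        for (std::size_t i = 0; i < n; ++i) {
            if (beta == 0.f)
                acc[i] = 0.f;              // eor path: plain zeroing
            else if (beta == 1.f)
                acc[i] = c_prev[i];        // ld1w only, fmul is skipped
            else
                acc[i] = beta * c_prev[i]; // ld1w, then dup + fmul by beta
        }
    }
    // The GEMM body then accumulates A * B on top of this initial value.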
@@ -791,58 +819,6 @@ void jit_brgemm_kernel_t::apply_alpha_beta(
        if (dq2ps_required) { scvtf(vmm.s, P_ALL_ONE / T_m, vmm.s); }
        if (apply_alpha) { fmul(vmm.s, vmm.s, vmm_alpha.s); }
    }
-
-    if (brg.beta == 0.f) return;
-    const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required;
-    const bool need_init_beta_vmm = brg.beta != 1.f;
-    auto vmm_prev_dst = z_tmp_1();
-    auto vmm_beta = z_tail_mask();
-    if (need_init_beta_vmm) {
-        auto wreg_tmp = WReg(reg_tmp_gpr.getIdx());
-        mov_imm(wreg_tmp, float2int(static_cast<float>(brg.beta)));
-        dup(vmm_beta.s, wreg_tmp);
-    }
-
-    int base_offset = 0;
-    auto x_addr = reg_aux_C;
-    for_(int bd = 0; bd < bd_block; bd++)
-    for (int ld = 0; ld < ld_block2; ld++) {
-        const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-        const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
-        auto vmm = accm(ld_block2, bd, ld);
-        if (use_vadd_for_beta) {
-            if (brg.is_int8) {
-                assert(!"unsupported\n");
-            } else {
-                ZRegS z_masked = vmm.s;
-                ZRegS z(vmm.getIdx());
-
-                const int offset = C_offset(bd, ld);
-
-                if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) {
-                    add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0);
-                    base_offset = offset;
-                    x_addr = reg_tmp_;
-                }
-                LD_MUL_VL(ld1w, vmm_prev_dst.s, k_mask, x_addr,
-                        offset - base_offset, 4);
-                if (is_ld_tail) {
-                    movprfx(z_masked, k_mask / T_z, z);
-                    fadd(z_masked, k_mask / T_m, vmm_prev_dst.s);
-                } else {
-                    fadd(z_masked, z_masked, vmm_prev_dst.s);
-                }
-            }
-        } else {
-            add_imm(X_DEFAULT_ADDR, reg_aux_C, C_offset(bd, ld), X_TMP_0);
-            ld1w(vmm_prev_dst.s, k_mask / T_z, ptr(X_DEFAULT_ADDR));
-            if (brg.beta == 1.f) {
-                fadd(vmm.s, vmm.s, vmm_prev_dst.s);
-            } else {
-                fmla(vmm.s, P_ALL_ONE / T_m, vmm_prev_dst.s, vmm_beta.s);
-            }
-        }
-    }
}

void jit_brgemm_kernel_t::apply_post_ops(
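For reference, the work that stays in apply_alpha_beta after this removal is just the conversion and alpha scaling kept in the context lines above. A minimal scalar sketch, under the same caveat that the names (acc, acc_s32, n) are illustrative and that in the kernel the int32 values actually live in the same accumulator registers:

    #include <cstddef>

    // Scalar sketch (illustrative only) of the remaining apply_alpha_beta work:
    // optional int32 -> float conversion (scvtf) and scaling by alpha (fmul).
    static void apply_alpha_only(float *acc, const int *acc_s32, std::size_t n,
            float alpha, bool dq2ps_required, bool apply_alpha) {
        for (std::size_t i = 0; i < n; ++i) {
            float v = dq2ps_required ? static_cast<float>(acc_s32[i]) : acc[i];
            if (apply_alpha) v *= alpha;
            acc[i] = v;
        }
    }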
@@ -1465,7 +1441,6 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
    int base_offset = 0;

    for (int rd = 0; rd < rd_loop; rd += brg.rd_step) {
-        int prefetch_count_B = 0;
        for (int ld = 0; ld < ld_block2; ld++) {
            const auto mask = is_ld_tail ? ld_tail_mask : P_ALL_ONE;
            if (brg.dt_b == data_type::f16) {
@@ -1497,13 +1472,7 @@
            broadcast(bcst(), A_offset(bd, rd),
                    have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
        }
-        if (prefetch_count_B < ld_block2) {
-            add_imm(X_DEFAULT_ADDR, reg_aux_B,
-                    B_offset(prefetch_count_B++, rd)
-                            + brg.LDB * brg.rd_block * brg.typesize_B,
-                    X_TMP_0);
-            prfm(PLDL1KEEP, ptr(X_DEFAULT_ADDR));
-        }
+        // The current implementation of prefetch gives no performance gain and instead introduces some latency, so it is removed until a more useful implementation is devised.
        for (int ld = 0; ld < ld_block2; ld++) {
            auto zmm = accm(ld_block2, bd, ld);
            if (is_emdbd) {