@@ -766,10 +766,38 @@ void jit_brgemm_kernel_t::read_params() {
766
766
void jit_brgemm_kernel_t::zero_accumulators (int bd_block2, bool is_bdb_tail,
767
767
int ld_block2, bool is_ld_tail, bool skip_accumulation) {
768
768
int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block ;
769
+ const bool need_to_apply_beta = brg.beta != 0 .f ;
769
770
for_ (int bd = 0 ; bd < bd_block; bd++)
770
771
for (int ld = 0 ; ld < ld_block2; ld++) {
771
772
auto zmm = accm (ld_block2, bd, ld);
772
- eor (zmm.d , zmm.d , zmm.d );
773
+ // This part is moved here from the apply_alpha_beta function so that the fadd instruction can be avoided.
774
+ // This is also required only when K is blocked.
775
+ if (need_to_apply_beta) {
776
+ const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
777
+ const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
778
+
779
+ const int offset = C_offset (bd, ld);
780
+
781
+ int base_offset = 0 ;
782
+ auto x_addr = reg_aux_C;
783
+
784
+ if ((unsigned )(offset - base_offset) > cpu_sveLen * 7 ) {
785
+ add_imm (reg_tmp_, reg_aux_C, offset, X_TMP_0);
786
+ base_offset = offset;
787
+ x_addr = reg_tmp_;
788
+ }
789
+ LD_MUL_VL (ld1w, zmm.s , k_mask, x_addr, offset - base_offset, 4 );
790
+
791
+ const bool need_init_beta_vmm = brg.beta != 1 .f ;
792
+ auto vmm_beta = z_tail_mask ();
793
+ if (need_init_beta_vmm) {
794
+ auto wreg_tmp = WReg (reg_tmp_gpr.getIdx ());
795
+ mov_imm (wreg_tmp, float2int (static_cast <float >(brg.beta )));
796
+ dup (vmm_beta.s , wreg_tmp);
797
+ fmul (zmm.s , zmm.s , vmm_beta.s );
798
+ }
799
+ } else
800
+ eor (zmm.d , zmm.d , zmm.d );
773
801
}
774
802
}
775
803
@@ -791,57 +819,7 @@ void jit_brgemm_kernel_t::apply_alpha_beta(
791
819
if (apply_alpha) { fmul (vmm.s , vmm.s , vmm_alpha.s ); }
792
820
}
793
821
794
- if (brg.beta == 0 .f ) return ;
795
- const bool use_vadd_for_beta = brg.beta == 1 .f && !dq2ps_required;
796
- const bool need_init_beta_vmm = brg.beta != 1 .f ;
797
- auto vmm_prev_dst = z_tmp_1 ();
798
- auto vmm_beta = z_tail_mask ();
799
- if (need_init_beta_vmm) {
800
- auto wreg_tmp = WReg (reg_tmp_gpr.getIdx ());
801
- mov_imm (wreg_tmp, float2int (static_cast <float >(brg.beta )));
802
- dup (vmm_beta.s , wreg_tmp);
803
- }
804
-
805
- int base_offset = 0 ;
806
- auto x_addr = reg_aux_C;
807
- for_ (int bd = 0 ; bd < bd_block; bd++)
808
- for (int ld = 0 ; ld < ld_block2; ld++) {
809
- const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
810
- const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
811
- auto vmm = accm (ld_block2, bd, ld);
812
- if (use_vadd_for_beta) {
813
- if (brg.is_int8 ) {
814
- assert (!" unsupported\n " );
815
- } else {
816
- ZRegS z_masked = vmm.s ;
817
- ZRegS z (vmm.getIdx ());
818
-
819
- const int offset = C_offset (bd, ld);
820
-
821
- if ((unsigned )(offset - base_offset) > cpu_sveLen * 7 ) {
822
- add_imm (reg_tmp_, reg_aux_C, offset, X_TMP_0);
823
- base_offset = offset;
824
- x_addr = reg_tmp_;
825
- }
826
- LD_MUL_VL (ld1w, vmm_prev_dst.s , k_mask, x_addr,
827
- offset - base_offset, 4 );
828
- if (is_ld_tail) {
829
- movprfx (z_masked, k_mask / T_z, z);
830
- fadd (z_masked, k_mask / T_m, vmm_prev_dst.s );
831
- } else {
832
- fadd (z_masked, z_masked, vmm_prev_dst.s );
833
- }
834
- }
835
- } else {
836
- add_imm (X_DEFAULT_ADDR, reg_aux_C, C_offset (bd, ld), X_TMP_0);
837
- ld1w (vmm_prev_dst.s , k_mask / T_z, ptr (X_DEFAULT_ADDR));
838
- if (brg.beta == 1 .f ) {
839
- fadd (vmm.s , vmm.s , vmm_prev_dst.s );
840
- } else {
841
- fmla (vmm.s , P_ALL_ONE / T_m, vmm_prev_dst.s , vmm_beta.s );
842
- }
843
- }
844
- }
822
+ // This part is moved to the function zero_accumulators.
845
823
}
846
824
847
825
void jit_brgemm_kernel_t::apply_post_ops (
@@ -1464,7 +1442,6 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
1464
1442
int base_offset = 0 ;
1465
1443
1466
1444
for (int rd = 0 ; rd < rd_loop; rd += brg.rd_step ) {
1467
- int prefetch_count_B = 0 ;
1468
1445
for (int ld = 0 ; ld < ld_block2; ld++) {
1469
1446
const auto mask = is_ld_tail ? ld_tail_mask : P_ALL_ONE;
1470
1447
if (brg.dt_b == data_type::f16) {
@@ -1496,13 +1473,7 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
1496
1473
broadcast (bcst (), A_offset (bd, rd),
1497
1474
have_to_load_bytes && bd_by_load_bytes, brg.dt_a );
1498
1475
}
1499
- if (prefetch_count_B < ld_block2) {
1500
- add_imm (X_DEFAULT_ADDR, reg_aux_B,
1501
- B_offset (prefetch_count_B++, rd)
1502
- + brg.LDB * brg.rd_block * brg.typesize_B ,
1503
- X_TMP_0);
1504
- prfm (PLDL1KEEP, ptr (X_DEFAULT_ADDR));
1505
- }
1476
+ // The current implementation of prefetch is not giving any gain in performance but is rather introducing some latency. Therefore it is removed until a new useful implementation is devised.
1506
1477
for (int ld = 0 ; ld < ld_block2; ld++) {
1507
1478
auto zmm = accm (ld_block2, bd, ld);
1508
1479
if (is_emdbd) {
0 commit comments