@@ -588,13 +588,19 @@ void BLASKernelGenerator<hw>::outerProductSystolic(int h, int ha, int hb, int op
588
588
B = state.sysSumAll1s [0 ];
589
589
nb = elementsPerGRF (hw, Tb);
590
590
B_block = &sumBlock;
591
- C = findBlockReg (Tc, state.As_layout , x, 0 , state.As_regs , nc, C_block);
591
+ if (repackC)
592
+ C = findBlockReg (Tc, state.Asr_layout , x % Cr_unrollM, 0 , state.Asr_regs , nc, C_block);
593
+ else
594
+ C = findBlockReg (Tc, state.As_layout , x, 0 , state.As_regs , nc, C_block);
592
595
} else {
593
596
A = state.sysSumAll1s [0 ];
594
597
na = elementsPerGRF (hw, Ta);
595
598
A_block = &sumBlock;
596
599
B = findBlockReg (Tb, B_layout, hhb, x, B_regs, nb, B_block);
597
- C = findBlockReg (Tc, state.Bs_layout , 0 , x, state.Bs_regs , nc, C_block);
600
+ if (repackC)
601
+ C = findBlockReg (Tc, state.Bsr_layout , 0 , x % Cr_unrollN, state.Bsr_regs , nc, C_block);
602
+ else
603
+ C = findBlockReg (Tc, state.Bs_layout , 0 , x, state.Bs_regs , nc, C_block);
598
604
}
599
605
600
606
int nv = globalCM ? na : nb;
@@ -672,6 +678,10 @@ void BLASKernelGenerator<hw>::outerProductRepackC(int x0, int xr0, int nx, int h
672
678
bool globalCM = isLayoutColMajor (C_layout);
673
679
bool scaleA = state.lateScale2DA , scaleB = state.lateScale2DB ;
674
680
681
+ bool sumA = problem.needsASums ();
682
+ bool sumB = problem.needsBSums ();
683
+ if (globalCM ? sumB : sumA) stub ();
684
+
675
685
if (Tc.size () != Tc_compute.size ()) stub ();
676
686
if (state.C_buffers > 1 ) stub ();
677
687
@@ -712,41 +722,56 @@ void BLASKernelGenerator<hw>::outerProductRepackC(int x0, int xr0, int nx, int h
712
722
for (int x1 = 0 ; x1 < nx; x1 += 2 * nec) {
713
723
int x = x0 + x1, xr = xr0 + x1;
714
724
int xchunk = std::min (nx - x1, 2 * nec);
715
- for (int y = 0 ; y < ny; y++) {
725
+ for (int y = 0 ; y < ny + sumA + sumB ; y++) {
716
726
auto i = globalCM ? x : y;
717
727
auto j = globalCM ? y : x;
718
728
auto ir = globalCM ? xr : y;
719
729
auto jr = globalCM ? y : xr;
720
730
721
- int ne, ner, nes[2 ];
722
- const RegisterBlock *C_block, *Cr_block, *sblock;
723
- auto C = findBlockReg (Tc, C_layout, i, j, C_regs, ne, C_block);
724
- auto Cr = findBlockReg (Tc_compute, Cr_layout, ir, jr, Cr_regs, ner, Cr_block);
731
+ int ne = 0 , ner = 0 , nes[2 ] = {0 , 0 };
732
+ const RegisterBlock *C_block = nullptr , *Cr_block = nullptr ;
733
+ const RegisterBlock *sblock = nullptr ;
734
+ Subregister C, Cr;
735
+
736
+ bool doASum = sumA && y == ny;
737
+ bool doBSum = sumB && y == ny;
738
+
739
+ if (y < ny) {
740
+ C = findBlockReg (Tc, C_layout, i, j, C_regs, ne, C_block);
741
+ Cr = findBlockReg (Tc_compute, Cr_layout, ir, jr, Cr_regs, ner, Cr_block);
742
+ } else if (doASum) {
743
+ C = findBlockReg (Tc, state.As_layout , x, 0 , state.As_regs , ne, C_block);
744
+ Cr = findBlockReg (Tc_compute, state.Asr_layout , xr, 0 , state.Asr_regs , ner, Cr_block);
745
+ } else if (doBSum) {
746
+ C = findBlockReg (Tc, state.Bs_layout , 0 , x, state.Bs_regs , ne, C_block);
747
+ Cr = findBlockReg (Tc_compute, state.Bsr_layout , 0 , xr, state.Bsr_regs , ner, Cr_block);
748
+ }
725
749
726
750
std::array<Subregister, 2 > scale;
727
751
std::array<int , 2 > scaleStride = {0 , 0 };
728
752
int nscale = 0 ;
729
- if (scaleA) {
753
+ if (scaleA && !doBSum ) {
730
754
int js = ((jr + h) / problem.aqGroupK ) % state.kaqLate ;
731
755
scale[nscale] = findBlockReg (state.Ta_scaleInt , state.Ar_scaleLayout ,
732
756
i, js, state.Ar_scaleRegs , nes[0 ], sblock);
733
757
scaleStride[nscale] = globalCM ? 1 : 0 ;
734
758
nscale++;
735
759
}
736
- if (scaleB) {
760
+ if (scaleB && !doASum ) {
737
761
int is = ((ir + h) / problem.bqGroupK ) % state.kbqLate ;
738
762
scale[nscale] = findBlockReg (state.Tb_scaleInt , state.Br_scaleLayout ,
739
763
is, j, state.Br_scaleRegs , nes[1 ], sblock);
740
764
scaleStride[nscale] = globalCM ? 0 : 1 ;
741
765
nscale++;
742
766
}
743
767
744
- ne = std::min (ne, ner);
768
+ ne = std::min ({ ne, ner, xchunk} );
745
769
if (scaleStride[0 ] == 1 ) ne = std::min (ne, nes[0 ]);
746
770
if (scaleStride[1 ] == 1 ) ne = std::min (ne, nes[1 ]);
747
771
748
772
if (ne < xchunk) stub ();
749
- if (C_block->crosspack != 1 || Cr_block->crosspack != 1 ) stub ();
773
+ if ((C_block && C_block->crosspack != 1 )
774
+ || (Cr_block && Cr_block->crosspack != 1 )) stub ();
750
775
751
776
WorkItem item = {C, Cr, ne, iacc, scale, scaleStride};
752
777
bool coalesce = false ;
0 commit comments