Skip to content

Commit c6bef5f

Browse files
committed
src: gpu: intel: jit: gemm: add limited A/B sum support for C repack
1 parent 30f0fb7 commit c6bef5f

File tree

4 files changed

+48
-11
lines changed

4 files changed

+48
-11
lines changed

src/gpu/intel/jit/gemm/generator/pieces/gemm_setup.cxx

+8
Original file line number | Diff line number | Diff line change
@@ -1525,6 +1525,10 @@ bool BLASKernelGenerator<hw>::gemmAccumulateCSetup(GEMMProblem &problem, GEMMStr
15251525
state.repackA ? state.Ar_layout :
15261526
state.A_layout;
15271527
makeSumLayout(false, Ta, As_srcLayout, Tc, state.As_layout, strategy, state);
1528+
if (Tc != Tc_compute) {
1529+
std::swap(state.Asr_layout, state.As_layout); /* TODO: trim down */
1530+
makeUnbackedRegLayout(Tc_compute, state.As_layout, unrollM, 1, true, 1);
1531+
}
15281532
if (state.systolicSumA)
15291533
setupTeardownAccumulateSumSystolic(true, Tb, problem, strategy, state);
15301534
}
@@ -1538,6 +1542,10 @@ bool BLASKernelGenerator<hw>::gemmAccumulateCSetup(GEMMProblem &problem, GEMMStr
15381542
state.repackB ? state.Br_layout :
15391543
state.B_layout;
15401544
makeSumLayout(true, Tb, Bs_srcLayout, Tc, state.Bs_layout, strategy, state);
1545+
if (Tc != Tc_compute) {
1546+
std::swap(state.Bsr_layout, state.Bs_layout);
1547+
makeUnbackedRegLayout(Tc_compute, state.Bs_layout, 1, unrollN, false, 1);
1548+
}
15411549
if (state.systolicSumB)
15421550
setupTeardownAccumulateSumSystolic(true, Ta, problem, strategy, state);
15431551
}

src/gpu/intel/jit/gemm/generator/pieces/matrix_multiply.cxx

+36 -11
Original file line number | Diff line number | Diff line change
@@ -588,13 +588,19 @@ void BLASKernelGenerator<hw>::outerProductSystolic(int h, int ha, int hb, int op
588588
B = state.sysSumAll1s[0];
589589
nb = elementsPerGRF(hw, Tb);
590590
B_block = &sumBlock;
591-
C = findBlockReg(Tc, state.As_layout, x, 0, state.As_regs, nc, C_block);
591+
if (repackC)
592+
C = findBlockReg(Tc, state.Asr_layout, x % Cr_unrollM, 0, state.Asr_regs, nc, C_block);
593+
else
594+
C = findBlockReg(Tc, state.As_layout, x, 0, state.As_regs, nc, C_block);
592595
} else {
593596
A = state.sysSumAll1s[0];
594597
na = elementsPerGRF(hw, Ta);
595598
A_block = &sumBlock;
596599
B = findBlockReg(Tb, B_layout, hhb, x, B_regs, nb, B_block);
597-
C = findBlockReg(Tc, state.Bs_layout, 0, x, state.Bs_regs, nc, C_block);
600+
if (repackC)
601+
C = findBlockReg(Tc, state.Bsr_layout, 0, x % Cr_unrollN, state.Bsr_regs, nc, C_block);
602+
else
603+
C = findBlockReg(Tc, state.Bs_layout, 0, x, state.Bs_regs, nc, C_block);
598604
}
599605

600606
int nv = globalCM ? na : nb;
@@ -672,6 +678,10 @@ void BLASKernelGenerator<hw>::outerProductRepackC(int x0, int xr0, int nx, int h
672678
bool globalCM = isLayoutColMajor(C_layout);
673679
bool scaleA = state.lateScale2DA, scaleB = state.lateScale2DB;
674680

681+
bool sumA = problem.needsASums();
682+
bool sumB = problem.needsBSums();
683+
if (globalCM ? sumB : sumA) stub();
684+
675685
if (Tc.size() != Tc_compute.size()) stub();
676686
if (state.C_buffers > 1) stub();
677687

@@ -712,41 +722,56 @@ void BLASKernelGenerator<hw>::outerProductRepackC(int x0, int xr0, int nx, int h
712722
for (int x1 = 0; x1 < nx; x1 += 2 * nec) {
713723
int x = x0 + x1, xr = xr0 + x1;
714724
int xchunk = std::min(nx - x1, 2 * nec);
715-
for (int y = 0; y < ny; y++) {
725+
for (int y = 0; y < ny + sumA + sumB; y++) {
716726
auto i = globalCM ? x : y;
717727
auto j = globalCM ? y : x;
718728
auto ir = globalCM ? xr : y;
719729
auto jr = globalCM ? y : xr;
720730

721-
int ne, ner, nes[2];
722-
const RegisterBlock *C_block, *Cr_block, *sblock;
723-
auto C = findBlockReg(Tc, C_layout, i, j, C_regs, ne, C_block);
724-
auto Cr = findBlockReg(Tc_compute, Cr_layout, ir, jr, Cr_regs, ner, Cr_block);
731+
int ne = 0, ner = 0, nes[2] = {0, 0};
732+
const RegisterBlock *C_block = nullptr, *Cr_block = nullptr;
733+
const RegisterBlock *sblock = nullptr;
734+
Subregister C, Cr;
735+
736+
bool doASum = sumA && y == ny;
737+
bool doBSum = sumB && y == ny;
738+
739+
if (y < ny) {
740+
C = findBlockReg(Tc, C_layout, i, j, C_regs, ne, C_block);
741+
Cr = findBlockReg(Tc_compute, Cr_layout, ir, jr, Cr_regs, ner, Cr_block);
742+
} else if (doASum) {
743+
C = findBlockReg(Tc, state.As_layout, x, 0, state.As_regs, ne, C_block);
744+
Cr = findBlockReg(Tc_compute, state.Asr_layout, xr, 0, state.Asr_regs, ner, Cr_block);
745+
} else if (doBSum) {
746+
C = findBlockReg(Tc, state.Bs_layout, 0, x, state.Bs_regs, ne, C_block);
747+
Cr = findBlockReg(Tc_compute, state.Bsr_layout, 0, xr, state.Bsr_regs, ner, Cr_block);
748+
}
725749

726750
std::array<Subregister, 2> scale;
727751
std::array<int, 2> scaleStride = {0, 0};
728752
int nscale = 0;
729-
if (scaleA) {
753+
if (scaleA && !doBSum) {
730754
int js = ((jr + h) / problem.aqGroupK) % state.kaqLate;
731755
scale[nscale] = findBlockReg(state.Ta_scaleInt, state.Ar_scaleLayout,
732756
i, js, state.Ar_scaleRegs, nes[0], sblock);
733757
scaleStride[nscale] = globalCM ? 1 : 0;
734758
nscale++;
735759
}
736-
if (scaleB) {
760+
if (scaleB && !doASum) {
737761
int is = ((ir + h) / problem.bqGroupK) % state.kbqLate;
738762
scale[nscale] = findBlockReg(state.Tb_scaleInt, state.Br_scaleLayout,
739763
is, j, state.Br_scaleRegs, nes[1], sblock);
740764
scaleStride[nscale] = globalCM ? 0 : 1;
741765
nscale++;
742766
}
743767

744-
ne = std::min(ne, ner);
768+
ne = std::min({ne, ner, xchunk});
745769
if (scaleStride[0] == 1) ne = std::min(ne, nes[0]);
746770
if (scaleStride[1] == 1) ne = std::min(ne, nes[1]);
747771

748772
if (ne < xchunk) stub();
749-
if (C_block->crosspack != 1 || Cr_block->crosspack != 1) stub();
773+
if ((C_block && C_block->crosspack != 1)
774+
|| (Cr_block && Cr_block->crosspack != 1)) stub();
750775

751776
WorkItem item = {C, Cr, ne, iacc, scale, scaleStride};
752777
bool coalesce = false;

src/gpu/intel/jit/gemm/generator/pieces/register_allocation.cxx

+2
Original file line number | Diff line number | Diff line change
@@ -613,6 +613,8 @@ void BLASKernelGenerator<hw>::gemmAllocRegs(GEMMProblem &problem, GEMMStrategy &
613613
state.Bi_regs[q] = state.ra.alloc_range(state.Bi_regCount);
614614

615615
// Allocate registers for A/B sums.
616+
state.Asr_regs = state.ra.alloc_range(getRegCount(state.Asr_layout));
617+
state.Bsr_regs = state.ra.alloc_range(getRegCount(state.Bsr_layout));
616618
state.As_regs = state.ra.alloc_range(getRegCount(state.As_layout));
617619
state.Bs_regs = state.ra.alloc_range(getRegCount(state.Bs_layout));
618620

src/gpu/intel/jit/gemm/generator/pieces/state.hpp

+2
Original file line number | Diff line number | Diff line change
@@ -261,6 +261,7 @@ struct GEMMState : public CommonState {
261261
GRFMultirange Ao_regs, Bo_regs; // Outgoing data to copy to SLM.
262262
GRFMultirange Ao_regsRem, Bo_regsRem;
263263
GRFMultirange As_regs, Bs_regs; // A row sums/B column sums.
264+
GRFMultirange Asr_regs, Bsr_regs; // A row sums/B column sums to be repacked.
264265
GRFMultirange Ap_regs, Bp_regs, Cp_regs; // A/B/C prefetch registers.
265266
GRFMultirange A_offsetRegs, B_offsetRegs; // A/B offsets (grouped).
266267
GRFMultirange A_scaleRegs, B_scaleRegs; // A/B scales (grouped).
@@ -327,6 +328,7 @@ struct GEMMState : public CommonState {
327328
std::vector<RegisterBlock> Ai_layoutRem, Bi_layoutRem;
328329
std::vector<RegisterBlock> Ao_layout, Bo_layout;
329330
std::vector<RegisterBlock> As_layout, Bs_layout;
331+
std::vector<RegisterBlock> Asr_layout, Bsr_layout;
330332
std::vector<RegisterBlock> Ap_layout, Bp_layout, Cp_layout;
331333
std::vector<RegisterBlock> Ap_layoutAlt, Bp_layoutAlt;
332334
std::vector<RegisterBlock> A_offsetLayout, B_offsetLayout;

0 commit comments

Comments
 (0)