Skip to content

Commit 162d3f3

Browse files
committed
gpu: jit: gemm: only QW-align widths for QW-aligned data
1 parent 56c04f7 commit 162d3f3

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

src/gpu/intel/jit/gemm/gen_gemm_kernel_generator.cpp

+19-14
Original file line numberDiff line numberDiff line change
@@ -5853,9 +5853,12 @@ static inline bool canRelAddr(const RegisterBlock &blockSrc,
58535853
}
58545854

58555855
static inline int block2DWidthAlignment(Type T, const RegisterBlock &block,
5856+
const MatrixAddressing &atype,
58565857
const MatrixAddressingStrategy &astrategy) {
58575858
// Block 2D width must be DW-aligned, but generally use QW alignment for better performance for reads.
5858-
return ((astrategy.noExtraPad || block.writable) ? 4 : 8);
5859+
return ((astrategy.noExtraPad || block.writable || atype.alignment % 8)
5860+
? 4
5861+
: 8);
58595862
}
58605863

58615864
static inline int block2DBaseAlignment(HW hw, int stepping) {
@@ -6077,7 +6080,7 @@ void gemm_kernel_generator_t<hw>::setupAddr(Type T, const GRFRange &addr,
60776080
if (doBaseAdjust && !astrategy.address2D) stub();
60786081
Subregister baStorage, baseAdjust, baseAdjustElems;
60796082

6080-
int widthAlign = block2DWidthAlignment(T, block, astrategy);
6083+
int widthAlign = block2DWidthAlignment(T, block, atype, astrategy);
60816084

60826085
if (!astrategy.address2D) mov(4, addr[0].ud(4)(1), 0u);
60836086

@@ -6836,6 +6839,7 @@ void gemm_kernel_generator_t<hw>::remaskLayout(Type T, int index, bool column,
68366839
}
68376840

68386841
static bool needsRemask(Type T, bool column, const RegisterBlock &block,
6842+
const MatrixAddressing &atype,
68396843
const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
68406844
if (!ignoreMasks)
68416845
if (column ? !block.remainderC : !block.remainderR) return false;
@@ -6847,19 +6851,20 @@ static bool needsRemask(Type T, bool column, const RegisterBlock &block,
68476851
int maskGranularity = block.ebytes;
68486852
if (block.ebytes >= 16) maskGranularity = 4;
68496853
if (block2DRemask)
6850-
maskGranularity = std::max(
6851-
maskGranularity, block2DWidthAlignment(T, block, astrategy));
6854+
maskGranularity = std::max(maskGranularity,
6855+
block2DWidthAlignment(T, block, atype, astrategy));
68526856
if (ignoreMasks && !(block2DRemask && astrategy.address2D))
68536857
maskGranularity = 256;
68546858

68556859
return (T.paddedSize() < maskGranularity);
68566860
}
68576861

68586862
static bool needsRemask(Type T, bool column,
6859-
const vector<RegisterBlock> &layout,
6863+
const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
68606864
const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
68616865
for (auto &block : layout)
6862-
if (needsRemask(T, column, block, astrategy, ignoreMasks)) return true;
6866+
if (needsRemask(T, column, block, atype, astrategy, ignoreMasks))
6867+
return true;
68636868
return false;
68646869
}
68656870

@@ -14504,11 +14509,11 @@ void gemm_kernel_generator_t<hw>::kLoopActivateSLMRemainder(bool active,
1450414509
bool asIfMaskedAi = Ai_lateKRem && state.Ai_strategy.padded;
1450514510
bool asIfMaskedBi = Bi_lateKRem && state.Bi_strategy.padded;
1450614511
slmRemaskA = slmA && mayAccessAllK && !Ai_remIncrCopy
14507-
&& needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai_strategy,
14508-
asIfMaskedAi);
14512+
&& needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai,
14513+
state.Ai_strategy, asIfMaskedAi);
1450914514
slmRemaskB = slmB && mayAccessAllK && !Bi_remIncrCopy
14510-
&& needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi_strategy,
14511-
asIfMaskedBi);
14515+
&& needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi,
14516+
state.Bi_strategy, asIfMaskedBi);
1451214517
}
1451314518

1451414519
static inline void kLoopModifiedFlagAP(GEMMState &state) {
@@ -15363,11 +15368,11 @@ void gemm_kernel_generator_t<hw>::kLoop(KLoop type, const GEMMProblem &problem,
1536315368

1536415369
// A/B remasking in k dimension, during remainder handling.
1536515370
bool remaskA = !slmA && readA && (minOPCount > 1)
15366-
&& needsRemask(Ta_load, true, state.A_layoutRem, strategy.A,
15367-
state.A_lateKRem);
15371+
&& needsRemask(Ta_load, true, state.A_layoutRem, problem.A,
15372+
strategy.A, state.A_lateKRem);
1536815373
bool remaskB = !slmB && readB && (minOPCount > 1)
15369-
&& needsRemask(Tb_load, false, state.B_layoutRem, strategy.B,
15370-
state.B_lateKRem);
15374+
&& needsRemask(Tb_load, false, state.B_layoutRem, problem.B,
15375+
strategy.B, state.B_lateKRem);
1537115376

1537215377
if (Ta.isInteger() && Tb.isInteger() && !calcASums && !calcBSums) {
1537315378
// Only need to remask one operand for integer A/B. Choose the smaller one.

0 commit comments

Comments
 (0)