@@ -5853,9 +5853,12 @@ static inline bool canRelAddr(const RegisterBlock &blockSrc,
5853
5853
}
5854
5854
5855
5855
static inline int block2DWidthAlignment(Type T, const RegisterBlock &block,
5856
+ const MatrixAddressing &atype,
5856
5857
const MatrixAddressingStrategy &astrategy) {
5857
5858
// Block 2D width must be DW-aligned, but generally use QW alignment for better performance for reads.
5858
- return ((astrategy.noExtraPad || block.writable) ? 4 : 8);
5859
+ return ((astrategy.noExtraPad || block.writable || atype.alignment % 8)
5860
+ ? 4
5861
+ : 8);
5859
5862
}
5860
5863
5861
5864
static inline int block2DBaseAlignment(HW hw, int stepping) {
@@ -6077,7 +6080,7 @@ void gemm_kernel_generator_t<hw>::setupAddr(Type T, const GRFRange &addr,
6077
6080
if (doBaseAdjust && !astrategy.address2D) stub();
6078
6081
Subregister baStorage, baseAdjust, baseAdjustElems;
6079
6082
6080
- int widthAlign = block2DWidthAlignment(T, block, astrategy);
6083
+ int widthAlign = block2DWidthAlignment(T, block, atype, astrategy);
6081
6084
6082
6085
if (!astrategy.address2D) mov(4, addr[0].ud(4)(1), 0u);
6083
6086
@@ -6836,6 +6839,7 @@ void gemm_kernel_generator_t<hw>::remaskLayout(Type T, int index, bool column,
6836
6839
}
6837
6840
6838
6841
static bool needsRemask(Type T, bool column, const RegisterBlock &block,
6842
+ const MatrixAddressing &atype,
6839
6843
const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
6840
6844
if (!ignoreMasks)
6841
6845
if (column ? !block.remainderC : !block.remainderR) return false;
@@ -6847,19 +6851,20 @@ static bool needsRemask(Type T, bool column, const RegisterBlock &block,
6847
6851
int maskGranularity = block.ebytes;
6848
6852
if (block.ebytes >= 16) maskGranularity = 4;
6849
6853
if (block2DRemask)
6850
- maskGranularity = std::max(
6851
- maskGranularity, block2DWidthAlignment(T, block, astrategy));
6854
+ maskGranularity = std::max(maskGranularity,
6855
+ block2DWidthAlignment(T, block, atype , astrategy));
6852
6856
if (ignoreMasks && !(block2DRemask && astrategy.address2D))
6853
6857
maskGranularity = 256;
6854
6858
6855
6859
return (T.paddedSize() < maskGranularity);
6856
6860
}
6857
6861
6858
6862
static bool needsRemask(Type T, bool column,
6859
- const vector<RegisterBlock> &layout,
6863
+ const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
6860
6864
const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
6861
6865
for (auto &block : layout)
6862
- if (needsRemask(T, column, block, astrategy, ignoreMasks)) return true;
6866
+ if (needsRemask(T, column, block, atype, astrategy, ignoreMasks))
6867
+ return true;
6863
6868
return false;
6864
6869
}
6865
6870
@@ -14504,11 +14509,11 @@ void gemm_kernel_generator_t<hw>::kLoopActivateSLMRemainder(bool active,
14504
14509
bool asIfMaskedAi = Ai_lateKRem && state.Ai_strategy.padded;
14505
14510
bool asIfMaskedBi = Bi_lateKRem && state.Bi_strategy.padded;
14506
14511
slmRemaskA = slmA && mayAccessAllK && !Ai_remIncrCopy
14507
- && needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai_strategy ,
14508
- asIfMaskedAi);
14512
+ && needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai ,
14513
+ state.Ai_strategy, asIfMaskedAi);
14509
14514
slmRemaskB = slmB && mayAccessAllK && !Bi_remIncrCopy
14510
- && needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi_strategy ,
14511
- asIfMaskedBi);
14515
+ && needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi ,
14516
+ state.Bi_strategy, asIfMaskedBi);
14512
14517
}
14513
14518
14514
14519
static inline void kLoopModifiedFlagAP(GEMMState &state) {
@@ -15363,11 +15368,11 @@ void gemm_kernel_generator_t<hw>::kLoop(KLoop type, const GEMMProblem &problem,
15363
15368
15364
15369
// A/B remasking in k dimension, during remainder handling.
15365
15370
bool remaskA = !slmA && readA && (minOPCount > 1)
15366
- && needsRemask(Ta_load, true, state.A_layoutRem, strategy .A,
15367
- state.A_lateKRem);
15371
+ && needsRemask(Ta_load, true, state.A_layoutRem, problem .A,
15372
+ strategy.A, state.A_lateKRem);
15368
15373
bool remaskB = !slmB && readB && (minOPCount > 1)
15369
- && needsRemask(Tb_load, false, state.B_layoutRem, strategy .B,
15370
- state.B_lateKRem);
15374
+ && needsRemask(Tb_load, false, state.B_layoutRem, problem .B,
15375
+ strategy.B, state.B_lateKRem);
15371
15376
15372
15377
if (Ta.isInteger() && Tb.isInteger() && !calcASums && !calcBSums) {
15373
15378
// Only need to remask one operand for integer A/B. Choose the smaller one.
0 commit comments