
Commit 280bd28

petercad authored and karturov committed
Masks: restrict rdivide field to powers of 2
1 parent 05d68df, commit 280bd28

3 files changed: +19 -18 lines

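In short, the commit stores the base-2 logarithm of the old rdivide divisor (the new rshift field) instead of the divisor itself, so index scaling becomes a shift. The standalone sketch below is not code from the generator; ilog2_sketch is a hypothetical stand-in for the generator's ilog2 helper. It illustrates the equivalence the change relies on: for a power-of-two divisor, index / rdivide == index >> rshift and rdivide == 1 << rshift.

    // Minimal standalone sketch of the power-of-two divide/shift equivalence.
    #include <cassert>
    #include <cstdint>
    #include <initializer_list>

    // Hypothetical stand-in for the generator's ilog2(): floor(log2(x)).
    static int ilog2_sketch(uint32_t x) {
        int n = -1;
        while (x) { x >>= 1; ++n; }
        return n;
    }

    int main() {
        for (uint32_t rdivide : {1u, 2u, 4u, 8u, 16u, 32u}) {
            int rshift = ilog2_sketch(rdivide);
            assert((1u << rshift) == rdivide);          // divisor recoverable from the shift
            for (uint32_t index = 0; index < 256; ++index)
                assert(index / rdivide == index >> rshift);  // same quotient either way
        }
        return 0;
    }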
src/gpu/intel/jit/gemm/generator/pieces/layout_setup.cxx

+11 -10

@@ -262,7 +262,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
 vymask.bitRep = consecutive;
 vymask.maskRep = 1;
 vymask.rsize = *yblock;
-vymask.rdivide = 1;
+vymask.rshift = 0;
 } else if (logicalSlots < slots) {
 auto &fymask = block.colMajor ? block.rowMask.fixed : block.colMask.fixed;
 fymask.isFixed = true;
@@ -279,7 +279,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
 vxmask.bitRep = (block.simdSize > 16) ? 32 : 16;
 vxmask.maskRep = 1;
 vxmask.rsize = 1;
-vxmask.rdivide = 1;
+vxmask.rshift = 0;
 } else if (allowDesc && (channelScattered || astrategy.newDP) && *xblock > 1 && !byte) {
 fragment = std::min(*xblock, 4 * width / T);
 if (block.colMajor) // Clang can't handle the ternary operator equivalent of this.
@@ -482,7 +482,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
 vrmask.rsize = rblock;
 vrmask.bitRep = std::max<int>(T.paddedSize() / maskGranularity, 1);
 vrmask.maskRep = cblock;
-vrmask.rdivide = std::max<int>(maskGranularity / T, 1);
+vrmask.rshift = ilog2(std::max<int>(maskGranularity / T, 1));
 }
 } else {
 if (avoidFragment) {
@@ -491,8 +491,8 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
 vrmask.isFixed = false;
 vrmask.bitRep = 0; /* will be filled in later */
 vrmask.maskRep = 1;
-vrmask.rdivide = 1;
 vrmask.rsize = 1;
+vrmask.rshift = 0;
 } else {
 // Fragment it. Could actually handle rowFragment = 2 by changing descriptor.
 block.rowFragment = 1;
@@ -520,7 +520,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
 vcmask.rsize = cblock;
 vcmask.bitRep = std::max<int>(T.paddedSize() / maskGranularity, 1);
 vcmask.maskRep = rblock;
-vcmask.rdivide = std::max<int>(maskGranularity / T, 1);
+vcmask.rshift = ilog2(std::max<int>(maskGranularity / T, 1));
 }
 } else {
 if (avoidFragment) {
@@ -529,8 +529,8 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
 vcmask.isFixed = false;
 vcmask.bitRep = 0;
 vcmask.maskRep = 1;
-vcmask.rdivide = 1;
 vcmask.rsize = 1;
+vcmask.rshift = 0;
 } else {
 // Fragment it. Could actually handle colFragment = 2 by changing descriptor.
 block.colFragment = 1;
@@ -719,7 +719,8 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
 auto &vxmask = block.colMajor ? block.rowMask.variable : block.colMask.variable;
 vxmask.isFixed = false;
 vxmask.bitRep = block.simdSize;
-vxmask.maskRep = vxmask.rdivide = vxmask.rsize = 1;
+vxmask.maskRep = vxmask.rsize = 1;
+vxmask.rshift = 0;
 }

 if (remainderY) {
@@ -728,7 +729,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
 vymask.bitRep = xCacheLines;
 vymask.maskRep = 1;
 vymask.rsize = yblock;
-vymask.rdivide = 1;
+vymask.rshift = 0;
 }
 break;
 }
@@ -739,13 +740,13 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
 if (block.rowMask && !block.rowMask.fixed.isFixed) {
 if (vrmask.rsize == 0)
 vrmask.rsize = rblock;
-vrmask.maskRep = std::min<int>(vrmask.maskRep, std::max<int>(1, vrmask.rdivide * block.simdSize / (vrmask.bitRep * vrmask.rsize)));
+vrmask.maskRep = std::min<int>(vrmask.maskRep, std::max<int>(1, (block.simdSize << vrmask.rshift) / (vrmask.bitRep * vrmask.rsize)));
 block.noRowsOK = true; // All-zero masks are always OK.
 }
 if (block.colMask && !block.colMask.fixed.isFixed) {
 if (vcmask.rsize == 0)
 vcmask.rsize = cblock;
-vcmask.maskRep = std::min<int>(vcmask.maskRep, std::max<int>(1, vcmask.rdivide * block.simdSize / (vcmask.bitRep * vcmask.rsize)));
+vcmask.maskRep = std::min<int>(vcmask.maskRep, std::max<int>(1, (block.simdSize << vcmask.rshift) / (vcmask.bitRep * vcmask.rsize)));
 block.noColsOK = true;
 }

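In the last hunk above, multiplying by the old rdivide becomes a left shift by rshift, which is equivalent because rdivide was always a power of two here. A minimal host-side sketch, with assumed example values (simdSize, bitRep, rsize, and the incoming maskRep are illustrative, not taken from a real kernel), checking that the two forms of the maskRep clamp agree:

    // Sketch: old clamp (multiply by rdivide) vs. new clamp (shift by rshift).
    #include <algorithm>
    #include <cassert>

    int main() {
        const int simdSize = 16, bitRep = 2, rsize = 8;   // assumed example values
        const int rdivide  = 4, rshift = 2;               // rdivide == 1 << rshift
        const int maskRepIn = 8;                          // assumed incoming maskRep

        int oldRep = std::min(maskRepIn, std::max(1, rdivide * simdSize / (bitRep * rsize)));
        int newRep = std::min(maskRepIn, std::max(1, (simdSize << rshift) / (bitRep * rsize)));
        assert(oldRep == newRep);   // both evaluate to 4 for these values
        return 0;
    }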
src/gpu/intel/jit/gemm/generator/pieces/masks.cxx

+7 -7

@@ -127,7 +127,7 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
 // Load a variable mask, which requires some minor bit-twiddling.
 auto &vmask = assignment.mask.variable;

-uint32_t rsizeScaled = vmask.rsize / vmask.rdivide;
+uint32_t rsizeScaled = vmask.rsize >> vmask.rshift;
 uint32_t maskLen = vmask.bitRep * vmask.maskRep * rsizeScaled;
 uint32_t fullMask = (uint64_t(1) << maskLen) - 1;
 uint32_t rep1Mask = (uint64_t(1) << (vmask.bitRep * rsizeScaled)) - 1;
@@ -136,7 +136,7 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
 auto flagType = flag.getType();
 auto mask0Type = getBytes(flagType) >= 4 ? DataType::uq : flagType;

-if (vmask.rsize == 1 && vmask.rdivide == 1) {
+if (vmask.rsize == 1 && vmask.rshift == 0) {
 // Simple threshold comparison.
 offset += assignment.offset;
 if (flag.isARF())
@@ -152,11 +152,11 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
 auto mask0 = state.ra.alloc_sub(mask0Type, getHint(HintType::Bank1));
 auto mask = mask0.reinterpret(0, flagType);
 auto mindex = index;
+auto rdivide = 1 << vmask.rshift;

-if (vmask.rdivide > 1) {
-if (!is_zero_or_pow2(vmask.rdivide)) stub();
-add(1 | sat, temp, mindex, -offset + vmask.rdivide - 1);
-shr(1, temp, temp, uint16_t(ilog2(vmask.rdivide)));
+if (vmask.rshift) {
+add(1 | sat, temp, mindex, -offset + rdivide - 1);
+shr(1, temp, temp, uint16_t(vmask.rshift));
 mindex = temp;
 offset = 0;
 }
@@ -169,7 +169,7 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
 mulConstant(1, temp, mindex, vmask.bitRep);
 mindex = temp;
 }
-uint16_t tshift = vmask.bitRep * (rsizeScaled + div_up(assignment.offset + offset, vmask.rdivide));
+uint16_t tshift = vmask.bitRep * (rsizeScaled + div_up(assignment.offset + offset, rdivide));
 add(1 | sat, temp, -mindex, tshift);
 if (tshift >= 32)
 min_(1, temp, temp, vmask.bitRep * rsizeScaled); // Ensure shift count doesn't overflow.

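The loadMask change keeps the same add/shr instruction sequence but derives the divisor as 1 << rshift, and the is_zero_or_pow2 guard is dropped because the field can now only encode powers of two. A minimal host-side sketch (plain C++, not nGEN instruction emission; div_up_sketch assumes the usual round-up division semantics) of what the add/shr pair computes:

    // Sketch: adding (rdivide - 1) before shifting right by rshift rounds up,
    // matching div_up(x, rdivide) for a power-of-two rdivide.
    #include <cassert>
    #include <cstdint>

    // Assumed semantics of the generator's div_up(): ceiling division.
    static uint32_t div_up_sketch(uint32_t x, uint32_t d) { return (x + d - 1) / d; }

    int main() {
        const uint32_t rshift = 2, rdivide = 1u << rshift;
        for (uint32_t x = 0; x < 64; ++x) {
            uint32_t shifted = (x + rdivide - 1) >> rshift;  // what add + shr compute
            assert(shifted == div_up_sketch(x, rdivide));
        }
        return 0;
    }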
src/gpu/intel/jit/gemm/generator/pieces/register_block.hpp

+1 -1

@@ -34,7 +34,7 @@ struct MaskInfo {
 struct {
 uint8_t isFixed : 1; // = false (variable mask)
 uint8_t reverse : 1; // True to reverse mask.
-uint8_t rdivide : 6; // Amount by which to divide index before forming mask. Fractions are rounded up.
+uint8_t rshift : 6; // Power of 2 by which to divide index before forming mask. Fractions are rounded up.
 // Note maskRep * bitRep * (rsize >> rshift) = # mask bits.
 uint8_t rsize; // Maximum remainder value. (e.g. 16 if we need the last 4 bits of the index).
 uint8_t maskRep; // # of repetitions of mask pattern.

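The header change repurposes the same 6-bit field from a raw divisor to a shift count, so the struct layout is unchanged. A simplified standalone sketch of the variable-mask bitfield and the bit count noted in its comment; field names follow the diff, but the widths of fields not shown there (bitRep) are assumed:

    // Sketch of the variable-mask fields and the invariant
    // maskRep * bitRep * (rsize >> rshift) = number of mask bits.
    #include <cassert>
    #include <cstdint>

    struct VariableMaskSketch {
        uint8_t isFixed : 1;
        uint8_t reverse : 1;
        uint8_t rshift  : 6;   // log2 of the old rdivide; 6 bits allow shifts up to 63
        uint8_t rsize;
        uint8_t maskRep;
        uint8_t bitRep;        // assumed width; not shown in the diff excerpt
    };

    int main() {
        VariableMaskSketch m{};
        m.rshift  = 2;         // stands in for the old rdivide = 4
        m.rsize   = 16;
        m.maskRep = 1;
        m.bitRep  = 2;
        unsigned maskBits = m.maskRep * m.bitRep * (m.rsize >> m.rshift);
        assert(maskBits == 8); // 1 * 2 * (16 >> 2)
        return 0;
    }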
0 commit comments