Skip to content

Commit 8a292cf

Browse files
authored
SPIRV: Fix matrix layout tests. (shader-slang#3137)
* SPIRV: Fix matrix layout tests. * Remove spaces. * Disable debug output. * Fix. * Update expected-failure list. --------- Co-authored-by: Yong He <yhe@nvidia.com>
1 parent 80c8f13 commit 8a292cf

5 files changed

+169
-62
lines changed

source/slang/hlsl.meta.slang

+3
Original file line numberDiff line numberDiff line change
@@ -3022,6 +3022,7 @@ T mul(vector<T, N> x, vector<T, N> y)
30223022
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
30233023
__target_intrinsic(hlsl)
30243024
__target_intrinsic(glsl, "($1 * $0)")
3025+
__target_intrinsic(spirv, "OpMatrixTimesVector resultType resultId _1 _0")
30253026
[__readNone]
30263027
vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right)
30273028
{
@@ -3078,6 +3079,7 @@ vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right)
30783079
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
30793080
__target_intrinsic(hlsl)
30803081
__target_intrinsic(glsl, "($1 * $0)")
3082+
__target_intrinsic(spirv, "OpVectorTimesMatrix resultType resultId _1 _0")
30813083
[__readNone]
30823084
vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right)
30833085
{
@@ -3134,6 +3136,7 @@ vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right)
31343136
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
31353137
__target_intrinsic(hlsl)
31363138
__target_intrinsic(glsl, "($1 * $0)")
3139+
__target_intrinsic(spirv, "OpMatrixTimesMatrix resultType resultId _1 _0")
31373140
[__readNone]
31383141
matrix<T,R,C> mul(matrix<T,R,N> right, matrix<T,N,C> left)
31393142
{

source/slang/slang-emit-spirv.cpp

+77-27
Original file line numberDiff line numberDiff line change
@@ -1240,9 +1240,9 @@ struct SPIRVEmitContext
12401240
auto matrixType = static_cast<IRMatrixType*>(inst);
12411241
auto vectorSpvType = ensureVectorType(
12421242
static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(),
1243-
static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(),
1243+
static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(),
12441244
nullptr);
1245-
const auto columnCount = static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue();
1245+
const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue();
12461246
auto matrixSPVType = emitOpTypeMatrix(
12471247
inst,
12481248
vectorSpvType,
@@ -2147,6 +2147,16 @@ struct SPIRVEmitContext
21472147
break;
21482148

21492149
case kIROp_SPIRVBufferBlockDecoration:
2150+
{
2151+
emitOpDecorate(
2152+
getSection(SpvLogicalSectionID::Annotations),
2153+
decoration,
2154+
dstID,
2155+
SpvDecorationBufferBlock
2156+
);
2157+
}
2158+
break;
2159+
case kIROp_SPIRVBlockDecoration:
21502160
{
21512161
emitOpDecorate(
21522162
getSection(SpvLogicalSectionID::Annotations),
@@ -2206,25 +2216,55 @@ struct SPIRVEmitContext
22062216
}
22072217
if (matrixType)
22082218
{
2209-
IRSizeAndAlignment matrixSize;
2210-
getSizeAndAlignment(IRTypeLayoutRules::get(layoutRuleName), matrixType, &matrixSize);
2219+
// SPIRV sepc on MatrixStride:
2220+
// Applies only to a member of a structure type.Only valid on a
2221+
// matrix or array whose most basic element is a matrix.Matrix
2222+
// Stride is an unsigned 32 - bit integer specifying the stride
2223+
// of the rows in a RowMajor - decorated matrix or columns in a
2224+
// ColMajor - decorated matrix.
2225+
IRIntegerValue matrixStride = 0;
2226+
auto rule = IRTypeLayoutRules::get(layoutRuleName);
2227+
IRSizeAndAlignment elementSizeAlignment;
2228+
getSizeAndAlignment(rule, matrixType->getElementType(), &elementSizeAlignment);
2229+
22112230
// Reminder: the meaning of row/column major layout
22122231
// in our semantics is the *opposite* of what GLSL/SPIRV
22132232
// calls them, because what they call "columns"
22142233
// are what we call "rows."
22152234
//
2216-
emitOpMemberDecorate(
2217-
getSection(SpvLogicalSectionID::Annotations),
2218-
nullptr,
2219-
spvStructID,
2220-
SpvLiteralInteger::from32(id),
2221-
getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR? SpvDecorationRowMajor : SpvDecorationColMajor);
2235+
if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
2236+
{
2237+
emitOpMemberDecorate(
2238+
getSection(SpvLogicalSectionID::Annotations),
2239+
nullptr,
2240+
spvStructID,
2241+
SpvLiteralInteger::from32(id),
2242+
SpvDecorationRowMajor);
2243+
2244+
auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getRowCount()));
2245+
vectorSize = rule->alignCompositeElement(vectorSize);
2246+
matrixStride = vectorSize.getStride();
2247+
}
2248+
else
2249+
{
2250+
emitOpMemberDecorate(
2251+
getSection(SpvLogicalSectionID::Annotations),
2252+
nullptr,
2253+
spvStructID,
2254+
SpvLiteralInteger::from32(id),
2255+
SpvDecorationColMajor);
2256+
2257+
auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getColumnCount()));
2258+
vectorSize = rule->alignCompositeElement(vectorSize);
2259+
matrixStride = vectorSize.getStride();
2260+
}
2261+
22222262
emitOpMemberDecorateMatrixStride(
22232263
getSection(SpvLogicalSectionID::Annotations),
22242264
nullptr,
22252265
spvStructID,
22262266
SpvLiteralInteger::from32(id),
2227-
SpvLiteralInteger::from32((int32_t)matrixSize.getStride()));
2267+
SpvLiteralInteger::from32((int32_t)matrixStride));
22282268
}
22292269
id++;
22302270
}
@@ -2835,6 +2875,7 @@ struct SPIRVEmitContext
28352875

28362876
SpvInst* emitGetElement(SpvInstParent* parent, IRGetElement* inst)
28372877
{
2878+
// Note: SPIRV only supports the case where `index` is constant.
28382879
auto base = inst->getBase();
28392880
const auto baseTy = base->getDataType();
28402881
SLANG_ASSERT(
@@ -2845,19 +2886,14 @@ struct SPIRVEmitContext
28452886

28462887
IRBuilder builder(m_irModule);
28472888
builder.setInsertBefore(inst);
2889+
auto index = getIntVal(inst->getIndex());
28482890

2849-
auto ptr = emitOpAccessChain(
2850-
parent,
2851-
nullptr,
2852-
builder.getPtrType(inst->getFullType()),
2853-
inst->getBase(),
2854-
makeArray(inst->getIndex())
2855-
);
2856-
return emitOpLoad(
2891+
return emitOpCompositeExtract(
28572892
parent,
28582893
inst,
28592894
inst->getFullType(),
2860-
ptr
2895+
inst->getBase(),
2896+
makeArray(SpvLiteralInteger::from32((int32_t)index))
28612897
);
28622898
}
28632899

@@ -3291,17 +3327,31 @@ SlangResult emitSPIRVFromIR(
32913327
auto targetRequest = codeGenContext->getTargetReq();
32923328
auto sink = codeGenContext->getSink();
32933329

3330+
#if 0
3331+
{
3332+
DiagnosticSinkWriter writer(codeGenContext->getSink());
3333+
dumpIR(
3334+
irModule,
3335+
{ IRDumpOptions::Mode::Simplified, 0 },
3336+
"BEFORE SPIR-V LEGALIZE",
3337+
codeGenContext->getSourceManager(),
3338+
&writer);
3339+
}
3340+
#endif
3341+
32943342
SPIRVEmitContext context(irModule, targetRequest, sink);
32953343
legalizeIRForSPIRV(&context, irModule, irEntryPoints, codeGenContext);
32963344

32973345
#if 0
3298-
DiagnosticSinkWriter writer(codeGenContext->getSink());
3299-
dumpIR(
3300-
irModule,
3301-
{IRDumpOptions::Mode::Simplified, 0},
3302-
"BEFORE SPIR-V EMIT",
3303-
codeGenContext->getSourceManager(),
3304-
&writer);
3346+
{
3347+
DiagnosticSinkWriter writer(codeGenContext->getSink());
3348+
dumpIR(
3349+
irModule,
3350+
{ IRDumpOptions::Mode::Simplified, 0 },
3351+
"BEFORE SPIR-V EMIT",
3352+
codeGenContext->getSourceManager(),
3353+
&writer);
3354+
}
33053355
#endif
33063356

33073357
context.emitFrontMatter();

source/slang/slang-ir-inst-defs.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -904,11 +904,14 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
904904
/// Overrides the floating mode for the target function
905905
INST(FloatingPointModeOverrideDecoration, FloatingPointModeOverride, 1, 0)
906906

907-
/// Marks a struct type as being used as a structured buffer block.
908907
/// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration.
909908
INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0)
910909

911-
INST_RANGE(Decoration, HighLevelDeclDecoration, SPIRVBufferBlockDecoration)
910+
/// Recognized by SPIRV-emit pass so we can emit a SPIRV `Block` decoration.
911+
INST(SPIRVBlockDecoration, spvBlock, 0, 0)
912+
913+
914+
INST_RANGE(Decoration, HighLevelDeclDecoration, SPIRVBlockDecoration)
912915

913916
//
914917

source/slang/slang-ir-spirv-legalize.cpp

+81-24
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "slang-glsl-extension-tracker.h"
1111
#include "slang-ir-lower-buffer-element-type.h"
1212
#include "slang-ir-layout.h"
13+
#include "slang-ir-util.h"
1314

1415
namespace Slang
1516
{
@@ -52,6 +53,36 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
5253
{
5354
}
5455

56+
// Wraps the element type of a constant buffer or parameter block in a struct if it is not already a struct,
57+
// returns the newly created struct type.
58+
IRType* wrapConstantBufferElement(IRInst* cbParamInst)
59+
{
60+
auto innerType = as<IRParameterGroupType>(cbParamInst->getDataType())->getElementType();
61+
IRBuilder builder(cbParamInst);
62+
builder.setInsertBefore(cbParamInst);
63+
auto structType = builder.createStructType();
64+
StringBuilder sb;
65+
sb << "cbuffer_";
66+
getTypeNameHint(sb, innerType);
67+
sb << "_t";
68+
builder.addNameHintDecoration(structType, sb.produceString().getUnownedSlice());
69+
auto key = builder.createStructKey();
70+
builder.createStructField(structType, key, innerType);
71+
builder.setInsertBefore(cbParamInst);
72+
auto newCbType = builder.getType(cbParamInst->getDataType()->getOp(), structType);
73+
cbParamInst->setFullType(newCbType);
74+
auto rules = getTypeLayoutRuleForBuffer(m_sharedContext->m_targetRequest, cbParamInst->getDataType());
75+
IRSizeAndAlignment sizeAlignment;
76+
getSizeAndAlignment(rules, structType, &sizeAlignment);
77+
traverseUses(cbParamInst, [&](IRUse* use)
78+
{
79+
builder.setInsertBefore(use->getUser());
80+
auto addr = builder.emitFieldAddress(builder.getPtrType(kIROp_PtrType, innerType, SpvStorageClassUniform), cbParamInst, key);
81+
use->set(addr);
82+
});
83+
return structType;
84+
}
85+
5586
void processGlobalParam(IRGlobalParam* inst)
5687
{
5788
// If the global param is not a pointer type, make it so and insert explicit load insts.
@@ -86,34 +117,45 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
86117
}
87118

88119
// Strip any HLSL wrappers
120+
IRBuilder builder(m_sharedContext->m_irModule);
121+
bool needLoad = true;
89122
auto innerType = inst->getFullType();
90-
if(const auto constantBufferType = as<IRConstantBufferType>(innerType))
91-
{
92-
innerType = constantBufferType->getElementType();
93-
storageClass = SpvStorageClassUniform;
94-
}
95-
else if (auto paramBlockType = as<IRParameterBlockType>(innerType))
123+
if (as<IRConstantBufferType>(innerType) || as<IRParameterBlockType>(innerType))
96124
{
97-
innerType = paramBlockType->getElementType();
125+
innerType = as<IRUniformParameterGroupType>(innerType)->getElementType();
98126
storageClass = SpvStorageClassUniform;
127+
// Constant buffer is already treated like a pointer type, and
128+
// we are not adding another layer of indirection when replacing it
129+
// with a pointer type. Therefore we don't need to insert a load at
130+
// use sites.
131+
needLoad = false;
132+
// If inner element type is not a struct type, we need to wrap it with
133+
// a struct.
134+
if (!as<IRStructType>(innerType))
135+
{
136+
innerType = wrapConstantBufferElement(inst);
137+
}
138+
builder.addDecoration(innerType, kIROp_SPIRVBlockDecoration);
99139
}
100140

101141
// Make a pointer type of storageClass.
102-
IRBuilder builder(m_sharedContext->m_irModule);
103142
builder.setInsertBefore(inst);
104143
ptrType = builder.getPtrType(kIROp_PtrType, innerType, storageClass);
105144
inst->setFullType(ptrType);
106-
// Insert an explicit load at each use site.
107-
List<IRUse*> uses;
108-
for (auto use = inst->firstUse; use; use = use->nextUse)
109-
{
110-
uses.add(use);
111-
}
112-
for (auto use : uses)
145+
if (needLoad)
113146
{
114-
builder.setInsertBefore(use->getUser());
115-
auto loadedValue = builder.emitLoad(inst);
116-
use->set(loadedValue);
147+
// Insert an explicit load at each use site.
148+
List<IRUse*> uses;
149+
for (auto use = inst->firstUse; use; use = use->nextUse)
150+
{
151+
uses.add(use);
152+
}
153+
for (auto use : uses)
154+
{
155+
builder.setInsertBefore(use->getUser());
156+
auto loadedValue = builder.emitLoad(inst);
157+
use->set(loadedValue);
158+
}
117159
}
118160
}
119161
processGlobalVar(inst);
@@ -206,22 +248,37 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
206248
}
207249
}
208250

209-
// Replace getElement(x, i) with, y = store(x); p = getElementPtr(y, i); load(p)
210-
// SPIR-V has no support for dynamic indexing into values like we do.
251+
Dictionary<IRInst*, IRInst*> m_mapArrayValueToVar;
252+
253+
// Replace getElement(x, i) with, y = store(x); p = getElementPtr(y, i); load(p),
254+
// when i is not a constant. SPIR-V has no support for dynamic indexing into values like we do.
211255
// It may be advantageous however to do this further up the pipeline
212256
void processGetElement(IRGetElement* inst)
213257
{
214-
const auto x = inst->getBase();
258+
IRInst* x = nullptr;
215259
List<IRInst*> indices;
216260
IRGetElement* c = inst;
217261
do
218262
{
263+
if (as<IRIntLit>(c->getIndex()))
264+
break;
265+
x = c->getBase();
219266
indices.add(c->getIndex());
220267
} while(c = as<IRGetElement>(c->getBase()), c);
268+
269+
if (!x)
270+
return;
271+
221272
IRBuilder builder(m_sharedContext->m_irModule);
273+
IRInst* y = nullptr;
274+
if (!m_mapArrayValueToVar.tryGetValue(x, y))
275+
{
276+
setInsertAfterOrdinaryInst(&builder, x);
277+
y = builder.emitVar(x->getDataType(), SpvStorageClassFunction);
278+
builder.emitStore(y, x);
279+
m_mapArrayValueToVar.set(x, y);
280+
}
222281
builder.setInsertBefore(inst);
223-
IRInst* y = builder.emitVar(x->getDataType(), SpvStorageClassFunction);
224-
builder.emitStore(y, x);
225282
for(Index i = indices.getCount() - 1; i >= 0; --i)
226283
y = builder.emitElementAddress(y, indices[i]);
227284
const auto newInst = builder.emitLoad(y);
@@ -367,7 +424,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
367424
break;
368425
}
369426
builder.addNameHintDecoration(structType, nameSb.getUnownedSlice());
370-
builder.addDecoration(structType, kIROp_SPIRVBufferBlockDecoration);
427+
builder.addDecoration(structType, kIROp_SPIRVBlockDecoration);
371428
inst->replaceUsesWith(ptrType);
372429
inst->removeAndDeallocate();
373430
addUsersToWorkList(ptrType);

0 commit comments

Comments
 (0)