Skip to content

Commit 17fa424

Browse files
author
Tim Foley
authored
Fix output of groupshared with IR type system (shader-slang#492)
The basic problem was that the lowering logic was constructing (more or less) `Ptr<@GroupShared X>` instead of `@GroupShared Ptr<X>`; that is, the rate qualifier was being applied to the variable's data type rather than to the variable itself. There were also problems with passes that should have propagated rates but did not (e.g., type legalization). I've added a test case to actually validate `groupshared` support.
1 parent c3a27c0 commit 17fa424

File tree

6 files changed

+114
-57
lines changed

6 files changed

+114
-57
lines changed

source/slang/emit.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,7 @@ struct EmitVisitor
12431243
auto rateQualifiedType = cast<IRRateQualifiedType>(type);
12441244
emitTypeImpl(rateQualifiedType->getValueType(), declarator);
12451245
}
1246+
break;
12461247

12471248
case kIROp_ArrayType:
12481249
emitArrayTypeImpl(cast<IRArrayType>(type), declarator);
@@ -3249,6 +3250,7 @@ struct EmitVisitor
32493250
auto valType = ptrType->getValueType();
32503251

32513252
auto name = getIRName(inst);
3253+
emitIRRateQualifiers(ctx, inst);
32523254
emitIRType(ctx, valType, name);
32533255
emit(";\n");
32543256
}
@@ -4155,7 +4157,7 @@ struct EmitVisitor
41554157
{
41564158
for (auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam())
41574159
{
4158-
emitIRType(ctx, pp->getDataType(), getIRName(pp));
4160+
emitIRType(ctx, pp->getFullType(), getIRName(pp));
41594161
emit(";\n");
41604162
}
41614163
}
@@ -5165,6 +5167,7 @@ struct EmitVisitor
51655167

51665168
emitIRVarModifiers(ctx, layout, varDecl, varType);
51675169

5170+
emitIRRateQualifiers(ctx, varDecl);
51685171
emitIRType(ctx, varType, getIRName(varDecl));
51695172

51705173
emitIRSemantics(ctx, varDecl);
@@ -5212,6 +5215,7 @@ struct EmitVisitor
52125215
emit("static ");
52135216
}
52145217
emit("const ");
5218+
emitIRRateQualifiers(ctx, valDecl);
52155219
emitIRType(ctx, valType, getIRName(valDecl));
52165220

52175221
if (valDecl->getFirstBlock())

source/slang/ir-legalize-types.cpp

+18-5
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,8 @@ static LegalVal legalizeLocalVar(
600600
context,
601601
irLocalVar->getDataType()->getValueType());
602602

603+
auto originalRate = irLocalVar->getRate();
604+
603605
RefPtr<VarLayout> varLayout = findVarLayout(irLocalVar);
604606
RefPtr<TypeLayout> typeLayout = varLayout ? varLayout->typeLayout : nullptr;
605607

@@ -614,14 +616,25 @@ static LegalVal legalizeLocalVar(
614616
switch (maybeSimpleType.flavor)
615617
{
616618
case LegalType::Flavor::simple:
617-
// Easy case: the type is usable as-is, and we
618-
// should just do that.
619-
irLocalVar->setFullType(context->builder->getPtrType(
620-
maybeSimpleType.getSimple()));
621-
return LegalVal::simple(irLocalVar);
619+
{
620+
// Easy case: the type is usable as-is, and we
621+
// should just do that.
622+
auto type = maybeSimpleType.getSimple();
623+
type = context->builder->getPtrType(type);
624+
if( originalRate )
625+
{
626+
type = context->builder->getRateQualifiedType(
627+
originalRate,
628+
type);
629+
}
630+
irLocalVar->setFullType(type);
631+
return LegalVal::simple(irLocalVar);
632+
}
622633

623634
default:
624635
{
636+
// TODO: We don't handle rates in this path.
637+
625638
context->insertBeforeLocalVar = irLocalVar;
626639

627640
LegalVarChain* varChain = nullptr;

source/slang/ir.cpp

+25-5
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,9 @@ namespace Slang
548548
//
549549
IRParentInst* mergeCandidateParentsForHoistableInst(IRParentInst* left, IRParentInst* right)
550550
{
551+
// If the candidates are both the same, then who cares?
552+
if(left == right) return left;
553+
551554
// If either `left` or `right` is a block, then we need to be
552555
// a bit careful, because blocks can see other values just using
553556
// the dominance relationship, without a direct parent-child relationship.
@@ -4805,6 +4808,27 @@ namespace Slang
48054808
IRGlobalValueWithCode* clonedValue,
48064809
IRGlobalValueWithCode* originalValue);
48074810

4811+
IRRate* cloneRate(
4812+
IRSpecContextBase* context,
4813+
IRRate* rate)
4814+
{
4815+
return (IRRate*) cloneType(context, rate);
4816+
}
4817+
4818+
void maybeSetClonedRate(
4819+
IRSpecContextBase* context,
4820+
IRBuilder* builder,
4821+
IRInst* clonedValue,
4822+
IRInst* originalValue)
4823+
{
4824+
if(auto rate = originalValue->getRate() )
4825+
{
4826+
clonedValue->setFullType(builder->getRateQualifiedType(
4827+
cloneRate(context, rate),
4828+
clonedValue->getFullType()));
4829+
}
4830+
}
4831+
48084832
IRGlobalVar* cloneGlobalVarImpl(
48094833
IRSpecContextBase* context,
48104834
IRBuilder* builder,
@@ -4814,11 +4838,7 @@ namespace Slang
48144838
auto clonedVar = builder->createGlobalVar(
48154839
cloneType(context, originalVar->getDataType()->getValueType()));
48164840

4817-
if(auto rate = originalVar->getRate() )
4818-
{
4819-
clonedVar->setFullType(builder->getRateQualifiedType(
4820-
rate, clonedVar->getFullType()));
4821-
}
4841+
maybeSetClonedRate(context, builder, clonedVar, originalVar);
48224842

48234843
registerClonedValue(context, clonedVar, originalValues);
48244844

source/slang/lower-to-ir.cpp

+23-46
Original file line numberDiff line numberDiff line change
@@ -1195,6 +1195,25 @@ void addVarDecorations(
11951195
}
11961196
}
11971197

1198+
/// If `decl` has a modifier that should turn into a
1199+
/// rate qualifier, then apply it to `inst`.
1200+
void maybeSetRate(
1201+
IRGenContext* context,
1202+
IRInst* inst,
1203+
Decl* decl)
1204+
{
1205+
auto builder = context->irBuilder;
1206+
1207+
if (decl->HasModifier<HLSLGroupSharedModifier>())
1208+
{
1209+
inst->setFullType(builder->getRateQualifiedType(
1210+
builder->getGroupSharedRate(),
1211+
inst->getFullType()));
1212+
}
1213+
}
1214+
1215+
1216+
11981217
LoweredValInfo createVar(
11991218
IRGenContext* context,
12001219
IRType* type,
@@ -1205,6 +1224,8 @@ LoweredValInfo createVar(
12051224

12061225
if (decl)
12071226
{
1227+
maybeSetRate(context, irAlloc, decl);
1228+
12081229
addVarDecorations(context, irAlloc, decl);
12091230

12101231
builder->addHighLevelDeclDecoration(irAlloc, decl);
@@ -3192,22 +3213,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
31923213
{
31933214
IRType* varType = lowerType(context, decl->getType());
31943215

3195-
if (decl->HasModifier<HLSLGroupSharedModifier>())
3196-
{
3197-
// TODO: here we are applying the rate qualifier to
3198-
// the *data type* of the variable, when we really
3199-
// should be applying the rate to the variable itself.
3200-
//
3201-
// This ends up making a distinction between
3202-
// `Ptr<@GroupShared X>` and `@GroupShared Ptr<X>`.
3203-
// The latter is more technically correct, but the
3204-
// code generation logic currently looks for the former.
3205-
3206-
varType = getBuilder()->getRateQualifiedType(
3207-
getBuilder()->getGroupSharedRate(),
3208-
varType);
3209-
}
3210-
32113216
auto builder = getBuilder();
32123217

32133218
IRGlobalValueWithCode* irGlobal = nullptr;
@@ -3226,6 +3231,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
32263231
}
32273232
irGlobal->mangledName = context->getSession()->getNameObj(getMangledName(decl));
32283233

3234+
maybeSetRate(context, irGlobal, decl);
3235+
32293236
if (decl)
32303237
{
32313238
builder->addHighLevelDeclDecoration(irGlobal, decl);
@@ -3300,36 +3307,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
33003307
// initializer expression a bit carefully (it should only
33013308
// be initialized on-demand at its first use).
33023309

3303-
// Some qualifiers on a variable will change how we allocate it,
3304-
// so we need to reflect that somehow. The first example
3305-
// we run into is the `groupshared` qualifier, which marks
3306-
// a variable in a compute shader as having per-group allocation
3307-
// rather than the traditional per-thread (or rather per-thread
3308-
// per-activation-record) allocation.
3309-
//
3310-
// Options include:
3311-
//
3312-
// - Use a distinct allocation operation, so that the type
3313-
// of the variable address/value is unchanged.
3314-
//
3315-
// - Add a notion of an "address space" to pointer types,
3316-
// so that we can allocate things in distinct spaces.
3317-
//
3318-
// - Add a notion of a "rate" so that we can declare a
3319-
// variable with a distinct rate.
3320-
//
3321-
// For now we might do the expedient thing and handle this
3322-
// via a notion of an "address space."
3323-
3324-
if (decl->HasModifier<HLSLGroupSharedModifier>())
3325-
{
3326-
// TODO: This logic is duplicated with the global-variable
3327-
// case. We should seek to share it.
3328-
varType = getBuilder()->getRateQualifiedType(
3329-
getBuilder()->getGroupSharedRate(),
3330-
varType);
3331-
}
3332-
33333310
LoweredValInfo varVal = createVar(context, varType, decl);
33343311

33353312
if( auto initExpr = decl->initExpr )

tests/compute/groupshared.slang

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// groupshared.slang
2+
3+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute
4+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12
5+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute
6+
7+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
8+
9+
#define THREAD_COUNT 4
10+
11+
groupshared int gA[THREAD_COUNT];
12+
int test(int val)
13+
{
14+
gA[val] = val;
15+
GroupMemoryBarrierWithGroupSync();
16+
val = gA[val ^ 1];
17+
18+
/* TODO: once function-scope `static` works
19+
static groupshared int gB[THREAD_COUNT];
20+
21+
gB[val] = val;
22+
GroupMemoryBarrierWithGroupSync();
23+
val = gB[val ^ 2];
24+
*/
25+
26+
return val;
27+
}
28+
29+
RWStructuredBuffer<int> gBuffer;
30+
31+
[numthreads(THREAD_COUNT, 1, 1)]
32+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
33+
{
34+
uint tid = dispatchThreadID.x;
35+
36+
int val = int(tid);
37+
val = test(val);
38+
gBuffer[tid] = val;
39+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
1
2+
0
3+
3
4+
2

0 commit comments

Comments
 (0)