Skip to content

Commit 10190da

Browse files
author
Tim Foley
authored
Handle structure initializers in IR type legalization (shader-slang#567)
Fixes shader-slang#566 The basic problem here is that the front-end translates a structure initializer-list expression into a `makeStruct` instruction (with one argument per field), but the IR type legalization logic wasn't handling the case where a `makeStruct` is used to construct a struct value that needs to get split by legalization. The implementation is relatively straightforward, and like the other cases of instruction legalization for compound types, it follows the shape of the `LegalType`/`LegalVal` cases. The one interesting bit is that we need to be a bit careful and filter the single argument list for `makeStruct` into two in the case where we generate a "pair" type for something that has both "ordinary" and "special" (resource) fields. Luckily the `PairInfo` data that was generated by type legalization has exactly the information we need (by design). This change does not address several issues that could be handled in follow-on changes: * The `makeArray` instruction will face similar issues if it is applied to a type that requires legalization: we'd need to turn an array of `LegalVal`s into a bunch of distinct arrays. * The error message when we hit the unimplemented case here isn't great. Ideally we should provide the line number of the instruction that fails in an error message when legalization fails. This change tries to focus narrowly on the bug at hand, and leave these issues for later changes.
1 parent e2c2c22 commit 10190da

File tree

4 files changed

+164
-3
lines changed

4 files changed

+164
-3
lines changed

source/slang/ir-legalize-types.cpp

+123
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,122 @@ static LegalVal legalizeGetElementPtr(
669669
indexOperand);
670670
}
671671

672+
static LegalVal legalizeMakeStruct(
673+
IRTypeLegalizationContext* context,
674+
LegalType legalType,
675+
LegalVal const* legalArgs,
676+
UInt argCount)
677+
{
678+
auto builder = context->builder;
679+
680+
switch(legalType.flavor)
681+
{
682+
case LegalType::Flavor::simple:
683+
{
684+
List<IRInst*> args;
685+
for(UInt aa = 0; aa < argCount; ++aa)
686+
{
687+
// Note: we assume that all the arguments
688+
// must be simple here, because otherwise
689+
// the `struct` type with them as fields
690+
// would not be simple...
691+
//
692+
args.Add(legalArgs[aa].getSimple());
693+
}
694+
return LegalVal::simple(
695+
builder->emitMakeStruct(
696+
legalType.getSimple(),
697+
argCount,
698+
args.Buffer()));
699+
}
700+
701+
case LegalType::Flavor::pair:
702+
{
703+
// There are two sides, the ordinary and the special,
704+
// and we basically just dispatch to both of them.
705+
auto pairType = legalType.getPair();
706+
auto pairInfo = pairType->pairInfo;
707+
LegalType ordinaryType = pairType->ordinaryType;
708+
LegalType specialType = pairType->specialType;
709+
710+
List<LegalVal> ordinaryArgs;
711+
List<LegalVal> specialArgs;
712+
UInt argCounter = 0;
713+
for(auto ee : pairInfo->elements)
714+
{
715+
UInt argIndex = argCounter++;
716+
LegalVal arg = legalArgs[argIndex];
717+
718+
if((ee.flags & Slang::PairInfo::kFlag_hasOrdinaryAndSpecial) == Slang::PairInfo::kFlag_hasOrdinaryAndSpecial)
719+
{
720+
// The field is itself a pair type, so we expect
721+
// the argument value to be one too...
722+
auto argPair = arg.getPair();
723+
ordinaryArgs.Add(argPair->ordinaryVal);
724+
specialArgs.Add(argPair->specialVal);
725+
}
726+
else if(ee.flags & Slang::PairInfo::kFlag_hasOrdinary)
727+
{
728+
ordinaryArgs.Add(arg);
729+
}
730+
else if(ee.flags & Slang::PairInfo::kFlag_hasSpecial)
731+
{
732+
specialArgs.Add(arg);
733+
}
734+
}
735+
736+
LegalVal ordinaryVal = legalizeMakeStruct(
737+
context,
738+
ordinaryType,
739+
ordinaryArgs.Buffer(),
740+
ordinaryArgs.Count());
741+
742+
LegalVal specialVal = legalizeMakeStruct(
743+
context,
744+
specialType,
745+
specialArgs.Buffer(),
746+
specialArgs.Count());
747+
748+
return LegalVal::pair(ordinaryVal, specialVal, pairInfo);
749+
}
750+
break;
751+
752+
case LegalType::Flavor::tuple:
753+
{
754+
// We are constructing a tuple of values from
755+
// the individual fields. We need to identify
756+
// for each tuple element what field it uses,
757+
// and then extract that field's value.
758+
759+
auto tupleType = legalType.getTuple();
760+
761+
RefPtr<TuplePseudoVal> resTupleInfo = new TuplePseudoVal();
762+
UInt argCounter = 0;
763+
for(auto typeElem : tupleType->elements)
764+
{
765+
auto elemKey = typeElem.key;
766+
UInt argIndex = argCounter++;
767+
SLANG_ASSERT(argIndex < argCount);
768+
769+
LegalVal argVal = legalArgs[argIndex];
770+
771+
TuplePseudoVal::Element resElem;
772+
resElem.key = elemKey;
773+
resElem.val = argVal;
774+
775+
resTupleInfo->elements.Add(resElem);
776+
}
777+
return LegalVal::tuple(resTupleInfo);
778+
}
779+
780+
default:
781+
SLANG_UNEXPECTED("unhandled");
782+
UNREACHABLE_RETURN(LegalVal());
783+
}
784+
}
785+
786+
787+
672788
static LegalVal legalizeInst(
673789
IRTypeLegalizationContext* context,
674790
IRInst* inst,
@@ -695,6 +811,13 @@ static LegalVal legalizeInst(
695811
case kIROp_Call:
696812
return legalizeCall(context, (IRCall*)inst);
697813

814+
case kIROp_makeStruct:
815+
return legalizeMakeStruct(
816+
context,
817+
type,
818+
args,
819+
inst->getOperandCount());
820+
698821
default:
699822
// TODO: produce a user-visible diagnostic here
700823
SLANG_UNEXPECTED("non-simple operand(s)!");

source/slang/legalize-types.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -290,15 +290,15 @@ struct LegalVal
290290
return result;
291291
}
292292

293-
IRInst* getSimple()
293+
IRInst* getSimple() const
294294
{
295295
SLANG_ASSERT(flavor == Flavor::simple);
296296
return irValue;
297297
}
298298

299299
static LegalVal tuple(RefPtr<TuplePseudoVal> tupleVal);
300300

301-
RefPtr<TuplePseudoVal> getTuple()
301+
RefPtr<TuplePseudoVal> getTuple() const
302302
{
303303
SLANG_ASSERT(flavor == Flavor::tuple);
304304
return obj.As<TuplePseudoVal>();
@@ -313,7 +313,7 @@ struct LegalVal
313313
LegalVal const& specialVal,
314314
RefPtr<PairInfo> pairInfo);
315315

316-
RefPtr<PairPseudoVal> getPair()
316+
RefPtr<PairPseudoVal> getPair() const
317317
{
318318
SLANG_ASSERT(flavor == Flavor::pair);
319319
return obj.As<PairPseudoVal>();

tests/bugs/gh-566.slang

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// legalize-struct-init.slang
2+
3+
//TEST(compute):COMPARE_COMPUTE:
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
5+
//TEST_INPUT:ubuffer(data=[4 3 2 1], stride=4):dxbinding(1),glbinding(1)
6+
7+
8+
RWStructuredBuffer<uint> outputBuffer;
9+
RWStructuredBuffer<uint> inputBuffer;
10+
11+
struct Helper
12+
{
13+
RWStructuredBuffer<uint> o;
14+
RWStructuredBuffer<uint> i;
15+
uint t;
16+
};
17+
18+
void test(Helper h)
19+
{
20+
h.o[h.t] = h.i[h.t] * 16 + h.t;
21+
}
22+
23+
void test(uint t)
24+
{
25+
Helper h = { outputBuffer, inputBuffer, t };
26+
test(h);
27+
}
28+
29+
[numthreads(4, 1, 1)]
30+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
31+
{
32+
uint tid = dispatchThreadID.x;
33+
test(tid);
34+
}

tests/bugs/gh-566.slang.expected.txt

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
40
2+
31
3+
22
4+
13

0 commit comments

Comments
 (0)