Skip to content

Commit 3fe4a77

Browse files
authored
Fix crash when using optional type in a generic. (shader-slang#4341)
1 parent 5da06d4 commit 3fe4a77

File tree

3 files changed

+102
-25
lines changed

3 files changed

+102
-25
lines changed

source/slang/slang-ir-lower-optional-type.cpp

+38-25
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ namespace Slang
1515
InstWorkList workList;
1616
InstHashSet workListSet;
1717

18+
IRGeneric* genericOptionalStructType = nullptr;
19+
IRStructKey* valueKey = nullptr;
20+
IRStructKey* hasValueKey = nullptr;
21+
1822
OptionalTypeLoweringContext(IRModule* inModule)
1923
:module(inModule), workList(inModule), workListSet(inModule)
2024
{}
@@ -24,8 +28,6 @@ namespace Slang
2428
IRType* optionalType = nullptr;
2529
IRType* valueType = nullptr;
2630
IRType* loweredType = nullptr;
27-
IRStructField* valueField = nullptr;
28-
IRStructField* hasValueField = nullptr;
2931
};
3032
Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> mapLoweredTypeToOptionalTypeInfo;
3133
Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> loweredOptionalTypes;
@@ -38,6 +40,34 @@ namespace Slang
3840
return type;
3941
}
4042

43+
IRInst* getOrCreateGenericOptionalStruct()
44+
{
45+
if (genericOptionalStructType)
46+
return genericOptionalStructType;
47+
IRBuilder builder(module);
48+
builder.setInsertInto(module->getModuleInst());
49+
50+
valueKey = builder.createStructKey();
51+
builder.addNameHintDecoration(valueKey, UnownedStringSlice("value"));
52+
hasValueKey = builder.createStructKey();
53+
builder.addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue"));
54+
55+
genericOptionalStructType = builder.emitGeneric();
56+
builder.addNameHintDecoration(genericOptionalStructType, UnownedStringSlice("_slang_Optional"));
57+
58+
builder.setInsertInto(genericOptionalStructType);
59+
auto block = builder.emitBlock();
60+
auto typeParam = builder.emitParam(builder.getTypeKind());
61+
auto structType = builder.createStructType();
62+
builder.addNameHintDecoration(structType, UnownedStringSlice("_slang_Optional"));
63+
builder.createStructField(structType, valueKey, (IRType*)typeParam);
64+
builder.createStructField(structType, hasValueKey, builder.getBoolType());
65+
builder.setInsertInto(block);
66+
builder.emitReturn(structType);
67+
genericOptionalStructType->setFullType(builder.getTypeKind());
68+
return genericOptionalStructType;
69+
}
70+
4171
bool typeHasNullValue(IRInst* type)
4272
{
4373
switch (type->getOp())
@@ -78,19 +108,10 @@ namespace Slang
78108
}
79109
else
80110
{
81-
auto structType = builder->createStructType();
82-
info->loweredType = structType;
83-
builder->addNameHintDecoration(structType, UnownedStringSlice("OptionalType"));
84-
85-
info->valueType = valueType;
86-
auto valueKey = builder->createStructKey();
87-
builder->addNameHintDecoration(valueKey, UnownedStringSlice("value"));
88-
info->valueField = builder->createStructField(structType, valueKey, (IRType*)valueType);
89-
90-
auto boolType = builder->getBoolType();
91-
auto hasValueKey = builder->createStructKey();
92-
builder->addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue"));
93-
info->hasValueField = builder->createStructField(structType, hasValueKey, (IRType*)boolType);
111+
auto genericType = getOrCreateGenericOptionalStruct();
112+
IRInst* args[] = { valueType };
113+
auto specializedType = builder->emitSpecializeInst(builder->getTypeKind(), genericType, 1, args);
114+
info->loweredType = (IRType*)specializedType;
94115
}
95116
mapLoweredTypeToOptionalTypeInfo[info->loweredType] = info;
96117
loweredOptionalTypes[type] = info;
@@ -100,12 +121,6 @@ namespace Slang
100121
void addToWorkList(
101122
IRInst* inst)
102123
{
103-
for (auto ii = inst->getParent(); ii; ii = ii->getParent())
104-
{
105-
if (as<IRGeneric>(ii))
106-
return;
107-
}
108-
109124
if (workListSet.contains(inst))
110125
return;
111126

@@ -169,7 +184,7 @@ namespace Slang
169184
result = builder->emitFieldExtract(
170185
builder->getBoolType(),
171186
optionalInst,
172-
loweredOptionalTypeInfo->hasValueField->getKey());
187+
hasValueKey);
173188
}
174189
else
175190
{
@@ -201,11 +216,10 @@ namespace Slang
201216
if (loweredOptionalTypeInfo->loweredType != loweredOptionalTypeInfo->valueType)
202217
{
203218
SLANG_ASSERT(loweredOptionalTypeInfo);
204-
SLANG_ASSERT(loweredOptionalTypeInfo->valueField);
205219
auto getElement = builder->emitFieldExtract(
206220
loweredOptionalTypeInfo->valueType,
207221
base,
208-
loweredOptionalTypeInfo->valueField->getKey());
222+
valueKey);
209223
inst->replaceUsesWith(getElement);
210224
}
211225
else
@@ -257,7 +271,6 @@ namespace Slang
257271
while (workList.getCount() != 0)
258272
{
259273
IRInst* inst = workList.getLast();
260-
261274
workList.removeLast();
262275
workListSet.remove(inst);
263276

tests/bugs/optional-generic.slang

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute
2+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute -vk
3+
4+
5+
Optional<T> genFunc<T : IArithmetic>(T v)
6+
{
7+
if (v is int)
8+
return v;
9+
return none;
10+
}
11+
12+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name buffer
13+
14+
RWStructuredBuffer<int> buffer;
15+
16+
[numthreads(1,1,1)]
17+
void computeMain()
18+
{
19+
// BUF: 2
20+
buffer[0] = genFunc(2).value;
21+
}
22+

tests/bugs/optional.slang

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute
2+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute -vk
3+
4+
interface IFoo
5+
{
6+
void foo();
7+
}
8+
9+
struct S : IFoo { int x; void foo(); }
10+
11+
struct P
12+
{
13+
IFoo f;
14+
}
15+
struct Tr
16+
{
17+
int test<T:IArithmetic>(T t, inout P p)
18+
{
19+
const IFoo hit = p.f;
20+
let castResult = hit as S;
21+
if (!castResult.hasValue)
22+
return 0;
23+
return castResult.value.x;
24+
}
25+
}
26+
27+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name buffer
28+
29+
RWStructuredBuffer<int> buffer;
30+
31+
[numthreads(1,1,1)]
32+
void computeMain()
33+
{
34+
P p;
35+
S s;
36+
s.x = 2;
37+
p.f = s;
38+
Tr tt;
39+
// BUF: 2
40+
buffer[0] = tt.test(0, p);
41+
}
42+

0 commit comments

Comments
 (0)