Skip to content

Commit b37a777

Browse files
csyongheTim Foley
and
Tim Foley
authored
Lower existential types. (shader-slang#1497)
Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com>
1 parent 99366e7 commit b37a777

13 files changed

+273
-7
lines changed

source/slang/slang-ir-generics-lowering-context.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,15 @@ namespace Slang
157157
return lowerAssociatedType(builder, paramType);
158158
}
159159
case kIROp_InterfaceType:
160+
{
161+
// An existential type translates into a tuple of (AnyValue, WitnessTable, RTTI*)
160162
anyValueSize = getInterfaceAnyValueSize(paramType, paramType->sourceLoc);
161-
return builder->getAnyValueType(anyValueSize);
163+
auto anyValueType = builder->getAnyValueType(anyValueSize);
164+
auto witnessTableType = builder->getWitnessTableType((IRType*)paramType);
165+
auto rttiType = builder->getPtrType(builder->getRTTIType());
166+
auto tupleType = builder->getTupleType(anyValueType, witnessTableType, rttiType);
167+
return tupleType;
168+
}
162169
case kIROp_lookup_interface_method:
163170
{
164171
auto lookupInterface = static_cast<IRLookupWitnessMethod*>(paramType);
+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// slang-ir-lower-generic-existential.cpp
2+
3+
#include "slang-ir-lower-existential.h"
4+
#include "slang-ir-generics-lowering-context.h"
5+
#include "slang-ir.h"
6+
#include "slang-ir-insts.h"
7+
8+
namespace Slang
9+
{
10+
struct ExistentialLoweringContext
11+
{
12+
SharedGenericsLoweringContext* sharedContext;
13+
14+
void processMakeExistential(IRMakeExistential* inst)
15+
{
16+
IRBuilder builderStorage;
17+
auto builder = &builderStorage;
18+
builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
19+
builder->setInsertBefore(inst);
20+
21+
auto value = inst->getWrappedValue();
22+
auto valueType = sharedContext->lowerType(builder, value->getDataType());
23+
auto witnessTableType = cast<IRWitnessTableType>(inst->getWitnessTable()->getDataType());
24+
auto interfaceType = witnessTableType->getConformanceType();
25+
auto anyValueSize = sharedContext->getInterfaceAnyValueSize(interfaceType, inst->sourceLoc);
26+
auto anyValueType = builder->getAnyValueType(anyValueSize);
27+
auto rttiType = builder->getPtrType(builder->getRTTIType());
28+
auto tupleType = builder->getTupleType(anyValueType, witnessTableType, rttiType);
29+
30+
IRInst* rttiObject = nullptr;
31+
if (valueType->op != kIROp_AnyValueType)
32+
{
33+
rttiObject = sharedContext->maybeEmitRTTIObject(valueType);
34+
rttiObject = builder->emitGetAddress(
35+
builder->getPtrType(builder->getRTTIType()),
36+
rttiObject);
37+
}
38+
else
39+
{
40+
rttiObject = valueType;
41+
}
42+
IRInst* packedValue = value;
43+
if (valueType->op != kIROp_AnyValueType)
44+
packedValue = builder->emitPackAnyValue(anyValueType, value);
45+
IRInst* tupleArgs[] = { packedValue, inst->getWitnessTable(), rttiObject };
46+
auto tuple = builder->emitMakeTuple(tupleType, 3, tupleArgs);
47+
inst->replaceUsesWith(tuple);
48+
inst->removeAndDeallocate();
49+
}
50+
51+
IRInst* extractTupleElement(IRBuilder* builder, IRInst* value, UInt index)
52+
{
53+
auto tupleType = cast<IRTupleType>(sharedContext->lowerType(builder, value->getDataType()));
54+
auto getElement = builder->emitGetTupleElement(
55+
(IRType*)tupleType->getOperand(index),
56+
value,
57+
index);
58+
return getElement;
59+
}
60+
61+
void processExtractExistentialElement(IRInst* extractInst, UInt elementId)
62+
{
63+
IRBuilder builderStorage;
64+
auto builder = &builderStorage;
65+
builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
66+
builder->setInsertBefore(extractInst);
67+
68+
auto element = extractTupleElement(builder, extractInst->getOperand(0), elementId);
69+
extractInst->replaceUsesWith(element);
70+
extractInst->removeAndDeallocate();
71+
}
72+
73+
void processExtractExistentialValue(IRExtractExistentialValue* inst)
74+
{
75+
processExtractExistentialElement(inst, 0);
76+
}
77+
78+
void processExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst)
79+
{
80+
processExtractExistentialElement(inst, 1);
81+
}
82+
83+
void processExtractExistentialType(IRExtractExistentialType* inst)
84+
{
85+
processExtractExistentialElement(inst, 2);
86+
}
87+
88+
void processInst(IRInst* inst)
89+
{
90+
if (auto makeExistential = as<IRMakeExistential>(inst))
91+
{
92+
processMakeExistential(makeExistential);
93+
}
94+
else if (auto extractExistentialVal = as<IRExtractExistentialValue>(inst))
95+
{
96+
processExtractExistentialValue(extractExistentialVal);
97+
}
98+
else if (auto extractExistentialType = as<IRExtractExistentialType>(inst))
99+
{
100+
processExtractExistentialType(extractExistentialType);
101+
}
102+
else if (auto extractExistentialWitnessTable = as<IRExtractExistentialWitnessTable>(inst))
103+
{
104+
processExtractExistentialWitnessTable(extractExistentialWitnessTable);
105+
}
106+
}
107+
108+
void processModule()
109+
{
110+
// We start by initializing our shared IR building state,
111+
// since we will re-use that state for any code we
112+
// generate along the way.
113+
//
114+
SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage;
115+
sharedBuilder->module = sharedContext->module;
116+
sharedBuilder->session = sharedContext->module->session;
117+
118+
sharedContext->addToWorkList(sharedContext->module->getModuleInst());
119+
120+
while (sharedContext->workList.getCount() != 0)
121+
{
122+
// We will then iterate until our work list goes dry.
123+
//
124+
while (sharedContext->workList.getCount() != 0)
125+
{
126+
IRInst* inst = sharedContext->workList.getLast();
127+
128+
sharedContext->workList.removeLast();
129+
sharedContext->workListSet.Remove(inst);
130+
131+
processInst(inst);
132+
133+
for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
134+
{
135+
sharedContext->addToWorkList(child);
136+
}
137+
}
138+
}
139+
}
140+
};
141+
142+
143+
void lowerExistentials(SharedGenericsLoweringContext* sharedContext)
144+
{
145+
ExistentialLoweringContext context;
146+
context.sharedContext = sharedContext;
147+
context.processModule();
148+
}
149+
}
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// slang-ir-lower-existential.h
2+
#pragma once
3+
4+
namespace Slang
5+
{
6+
struct SharedGenericsLoweringContext;
7+
8+
/// Lower existential types and related instructions to Tuple types.
9+
void lowerExistentials(
10+
SharedGenericsLoweringContext* sharedContext);
11+
12+
}
13+

source/slang/slang-ir-lower-generic-call.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// slang-ir-lower-generic-function.cpp
2-
#include "slang-ir-lower-generic-function.h"
1+
// slang-ir-lower-generic-call.cpp
2+
#include "slang-ir-lower-generic-call.h"
33
#include "slang-ir-generics-lowering-context.h"
44

55
namespace Slang
@@ -203,7 +203,7 @@ namespace Slang
203203
else if (auto lookupInst = as<IRLookupWitnessMethod>(callInst->getCallee()))
204204
lowerCallToInterfaceMethod(callInst, lookupInst);
205205
}
206-
206+
207207
void processInst(IRInst* inst)
208208
{
209209
if (auto callInst = as<IRCall>(inst))

source/slang/slang-ir-lower-generic-function.cpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,9 @@ namespace Slang
167167
}
168168
}
169169
loweredType = builder.createInterfaceType(newEntries.getCount(), (IRInst**)newEntries.getBuffer());
170-
interfaceType->transferDecorationsTo(loweredType);
171-
interfaceType->replaceUsesWith(loweredType);
170+
IRCloneEnv cloneEnv;
171+
cloneInstDecorationsAndChildren(&cloneEnv, &sharedContext->sharedBuilderStorage,
172+
interfaceType, loweredType);
172173
sharedContext->loweredInterfaceTypes.Add(interfaceType, loweredType);
173174
sharedContext->mapLoweredInterfaceToOriginal[loweredType] = interfaceType;
174175
return loweredType;
@@ -272,6 +273,16 @@ namespace Slang
272273
}
273274
}
274275

276+
void replaceLoweredInterfaceTypes()
277+
{
278+
for (auto lowered : sharedContext->loweredInterfaceTypes)
279+
{
280+
lowered.Key->replaceUsesWith(lowered.Value);
281+
}
282+
// Update hash keys of globalNumberingMap, since the types are modified.
283+
sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
284+
}
285+
275286
void processModule()
276287
{
277288
// We start by initializing our shared IR building state,
@@ -303,6 +314,8 @@ namespace Slang
303314
}
304315
}
305316
}
317+
318+
replaceLoweredInterfaceTypes();
306319
}
307320
};
308321
void lowerGenericFunctions(SharedGenericsLoweringContext* sharedContext)

source/slang/slang-ir-lower-generic-function.h

+7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ namespace Slang
77

88
/// Lower generic and interface-based code to ordinary types and functions using
99
/// dynamic dispatch mechanisms.
10+
/// After this pass, generic type parameters will be lowered into `AnyValue` types,
11+
/// and an existential type I in function signatures will be lowered into
12+
/// `Tuple<AnyValue, WintessTable(I), RTTI*>`.
13+
/// Note that this pass mostly deals with function signatures and interface definitions,
14+
/// and does not modify function bodies.
15+
/// All variable declarations and type uses are handled in `lower-generic-type`,
16+
/// and all call sites are handled in `lower-generic-call`.
1017
void lowerGenericFunctions(
1118
SharedGenericsLoweringContext* sharedContext);
1219

source/slang/slang-ir-lower-generic-type.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ namespace Slang
1616

1717
void processInst(IRInst* inst)
1818
{
19+
// If inst is a type itself, keep its type.
20+
if (as<IRType>(inst))
21+
return;
22+
1923
IRBuilder builderStorage;
2024
auto builder = &builderStorage;
2125
builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
@@ -57,6 +61,7 @@ namespace Slang
5761
}
5862
}
5963
}
64+
sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
6065
}
6166
};
6267

source/slang/slang-ir-lower-generic-type.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ namespace Slang
55
{
66
struct SharedGenericsLoweringContext;
77

8-
/// Lower all references to generic types (ThisType, AssociatedType, etc.) into IRAnyValueType.
8+
/// Lower all references to generic types (ThisType, AssociatedType, etc.) into IRAnyValueType,
9+
/// and existential types into Tuple<AnyValue, WitnessTable(I), Ptr(RTTIType)>.
910
void lowerGenericType(
1011
SharedGenericsLoweringContext* sharedContext);
1112

source/slang/slang-ir-lower-generics.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "slang-ir-any-value-marshalling.h"
55
#include "slang-ir-generics-lowering-context.h"
6+
#include "slang-ir-lower-existential.h"
67
#include "slang-ir-lower-generic-function.h"
78
#include "slang-ir-lower-generic-call.h"
89
#include "slang-ir-lower-generic-type.h"
@@ -20,11 +21,29 @@ namespace Slang
2021
sharedContext.module = module;
2122
sharedContext.sink = sink;
2223

24+
lowerExistentials(&sharedContext);
25+
if (sink->getErrorCount() != 0)
26+
return;
27+
2328
lowerGenericFunctions(&sharedContext);
29+
if (sink->getErrorCount() != 0)
30+
return;
31+
2432
lowerGenericType(&sharedContext);
33+
if (sink->getErrorCount() != 0)
34+
return;
35+
2536
lowerGenericCalls(&sharedContext);
37+
if (sink->getErrorCount() != 0)
38+
return;
39+
2640
generateWitnessTableWrapperFunctions(&sharedContext);
41+
if (sink->getErrorCount() != 0)
42+
return;
43+
2744
generateAnyValueMarshallingFunctions(&sharedContext);
45+
if (sink->getErrorCount() != 0)
46+
return;
2847
// We might have generated new temporary variables during lowering.
2948
// An SSA pass can clean up unnecessary load/stores.
3049
constructSSA(module);

source/slang/slang.vcxproj

+2
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@
241241
<ClInclude Include="slang-ir-layout.h" />
242242
<ClInclude Include="slang-ir-legalize-varying-params.h" />
243243
<ClInclude Include="slang-ir-link.h" />
244+
<ClInclude Include="slang-ir-lower-existential.h" />
244245
<ClInclude Include="slang-ir-lower-generic-call.h" />
245246
<ClInclude Include="slang-ir-lower-generic-function.h" />
246247
<ClInclude Include="slang-ir-lower-generic-type.h" />
@@ -345,6 +346,7 @@
345346
<ClCompile Include="slang-ir-legalize-types.cpp" />
346347
<ClCompile Include="slang-ir-legalize-varying-params.cpp" />
347348
<ClCompile Include="slang-ir-link.cpp" />
349+
<ClCompile Include="slang-ir-lower-existential.cpp" />
348350
<ClCompile Include="slang-ir-lower-generic-call.cpp" />
349351
<ClCompile Include="slang-ir-lower-generic-function.cpp" />
350352
<ClCompile Include="slang-ir-lower-generic-type.cpp" />

source/slang/slang.vcxproj.filters

+6
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@
174174
<ClInclude Include="slang-ir-link.h">
175175
<Filter>Header Files</Filter>
176176
</ClInclude>
177+
<ClInclude Include="slang-ir-lower-existential.h">
178+
<Filter>Header Files</Filter>
179+
</ClInclude>
177180
<ClInclude Include="slang-ir-lower-generic-call.h">
178181
<Filter>Header Files</Filter>
179182
</ClInclude>
@@ -482,6 +485,9 @@
482485
<ClCompile Include="slang-ir-link.cpp">
483486
<Filter>Source Files</Filter>
484487
</ClCompile>
488+
<ClCompile Include="slang-ir-lower-existential.cpp">
489+
<Filter>Source Files</Filter>
490+
</ClCompile>
485491
<ClCompile Include="slang-ir-lower-generic-call.cpp">
486492
<Filter>Source Files</Filter>
487493
</ClCompile>
+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code
2+
//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -allow-dynamic-code
3+
4+
// Test dynamic dispatch code gen for extential type parameters.
5+
6+
[anyValueSize(16)]
7+
interface IInterface
8+
{
9+
int Compute(int inVal);
10+
};
11+
12+
int GenericCompute(IInterface obj, int inVal)
13+
{
14+
return obj.Compute(inVal);
15+
}
16+
17+
struct Impl : IInterface
18+
{
19+
int base;
20+
int Compute(int inVal) { return base + inVal * inVal; }
21+
};
22+
23+
int test(int inVal)
24+
{
25+
Impl obj;
26+
obj.base = 1;
27+
return GenericCompute(obj, inVal);
28+
}
29+
30+
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
31+
RWStructuredBuffer<int> outputBuffer : register(u0);
32+
33+
[numthreads(4, 1, 1)]
34+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
35+
{
36+
uint tid = dispatchThreadID.x;
37+
int inVal = outputBuffer[tid];
38+
int outVal = test(inVal);
39+
outputBuffer[tid] = outVal;
40+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
1
2+
2
3+
5
4+
A

0 commit comments

Comments
 (0)