Skip to content

Commit 444ff4d

Browse files
authored
Specialize witness table lookups. (shader-slang#1596)
* Specialize witness table lookups. * Remove generated files from vcxproj * Fix call to generic interface methods.
1 parent 94861d5 commit 444ff4d

15 files changed

+405
-23
lines changed

source/slang/slang-ir-generics-lowering-context.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,21 @@ namespace Slang
219219
}
220220
}
221221

222+
List<IRWitnessTable*> SharedGenericsLoweringContext::getWitnessTablesFromInterfaceType(IRInst* interfaceType)
223+
{
224+
List<IRWitnessTable*> witnessTables;
225+
for (auto globalInst : module->getGlobalInsts())
226+
{
227+
if (globalInst->op == kIROp_WitnessTable &&
228+
cast<IRWitnessTableType>(globalInst->getDataType())->getConformanceType() ==
229+
interfaceType)
230+
{
231+
witnessTables.add(cast<IRWitnessTable>(globalInst));
232+
}
233+
}
234+
return witnessTables;
235+
}
236+
222237
IRIntegerValue SharedGenericsLoweringContext::getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLocation)
223238
{
224239
if (auto decor = type->findDecoration<IRAnyValueSizeDecoration>())

source/slang/slang-ir-generics-lowering-context.h

+13
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,19 @@ namespace Slang
7676
{
7777
return lowerType(builder, paramType, Dictionary<IRInst*, IRInst*>());
7878
}
79+
80+
// Get a list of all witness tables whose conformance type is `interfaceType`.
81+
List<IRWitnessTable*> getWitnessTablesFromInterfaceType(IRInst* interfaceType);
82+
83+
IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
84+
{
85+
for (auto entry : table->getEntries())
86+
{
87+
if (entry->getRequirementKey() == key)
88+
return entry->getSatisfyingVal();
89+
}
90+
return nullptr;
91+
}
7992
};
8093

8194
bool isPolymorphicType(IRInst* typeInst);

source/slang/slang-ir-inst-defs.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,12 @@ INST(TypeType, type_t, 0, 0)
187187
// An `IRWitnessTable` has type `WitnessTableType`.
188188
INST(WitnessTableType, witness_table_t, 1, 0)
189189

190-
INST_RANGE(Type, VoidType, WitnessTableType)
190+
// An integer type representing a witness table for targets where
191+
// witness tables are represented as integer IDs. This type is used
192+
// during the lower-generics pass while generating dynamic dispatch
193+
// code and will eventually lower into an uint type.
194+
INST(WitnessTableIDType, witness_table_id_t, 1, 0)
195+
INST_RANGE(Type, VoidType, WitnessTableIDType)
191196

192197
/*IRGlobalValueWithCode*/
193198
/* IRGlobalValueWithParams*/

source/slang/slang-ir-insts.h

+1
Original file line numberDiff line numberDiff line change
@@ -1817,6 +1817,7 @@ struct IRBuilder
18171817

18181818
IRBasicBlockType* getBasicBlockType();
18191819
IRWitnessTableType* getWitnessTableType(IRType* baseType);
1820+
IRWitnessTableIDType* getWitnessTableIDType(IRType* baseType);
18201821
IRType* getTypeType() { return getType(IROp::kIROp_TypeType); }
18211822
IRType* getKeyType() { return nullptr; }
18221823

source/slang/slang-ir-lower-generic-call.cpp

+12-3
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,17 @@ namespace Slang
266266
return;
267267
SLANG_UNEXPECTED("Nested generics specialization.");
268268
}
269+
else if (loweredFunc->op == kIROp_lookup_interface_method)
270+
{
271+
lowerCallToInterfaceMethod(
272+
callInst, cast<IRLookupWitnessMethod>(loweredFunc), specializeInst);
273+
return;
274+
}
269275
IRFuncType* funcType = cast<IRFuncType>(loweredFunc->getDataType());
270276
translateCallInst(callInst, funcType, loweredFunc, specializeInst);
271277
}
272278

273-
void lowerCallToInterfaceMethod(IRCall* callInst, IRLookupWitnessMethod* lookupInst)
279+
void lowerCallToInterfaceMethod(IRCall* callInst, IRLookupWitnessMethod* lookupInst, IRSpecialize* specializeInst)
274280
{
275281
// If we see a call(lookup_interface_method(...), ...), we need to translate
276282
// all occurences of associatedtypes.
@@ -312,15 +318,18 @@ namespace Slang
312318
// Translate the new call inst as normal, taking care of packing/unpacking inputs
313319
// and outputs.
314320
translateCallInst(
315-
newCall, cast<IRFuncType>(dispatchFunc->getFullType()), dispatchFunc, nullptr);
321+
newCall,
322+
cast<IRFuncType>(dispatchFunc->getFullType()),
323+
dispatchFunc,
324+
specializeInst);
316325
}
317326

318327
void lowerCall(IRCall* callInst)
319328
{
320329
if (auto specializeInst = as<IRSpecialize>(callInst->getCallee()))
321330
lowerCallToSpecializedFunc(callInst, specializeInst);
322331
else if (auto lookupInst = as<IRLookupWitnessMethod>(callInst->getCallee()))
323-
lowerCallToInterfaceMethod(callInst, lookupInst);
332+
lowerCallToInterfaceMethod(callInst, lookupInst, nullptr);
324333
}
325334

326335
void processInst(IRInst* inst)

source/slang/slang-ir-lower-generics.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "slang-ir-lower-generic-call.h"
1010
#include "slang-ir-lower-generic-type.h"
1111
#include "slang-ir-specialize-dispatch.h"
12+
#include "slang-ir-specialize-dynamic-associatedtype-lookup.h"
1213
#include "slang-ir-witness-table-wrapper.h"
1314
#include "slang-ir-ssa.h"
1415
#include "slang-ir-dce.h"
@@ -63,6 +64,10 @@ namespace Slang
6364
if (sink->getErrorCount() != 0)
6465
return;
6566

67+
specializeDynamicAssociatedTypeLookup(&sharedContext);
68+
if (sink->getErrorCount() != 0)
69+
return;
70+
6671
// We might have generated new temporary variables during lowering.
6772
// An SSA pass can clean up unnecessary load/stores.
6873
constructSSA(module);

source/slang/slang-ir-specialize-dispatch.cpp

+3-19
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,13 @@
66

77
namespace Slang
88
{
9-
IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
10-
{
11-
for (auto entry : table->getEntries())
12-
{
13-
if (entry->getRequirementKey() == key)
14-
return entry->getSatisfyingVal();
15-
}
16-
return nullptr;
17-
}
18-
199
IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
2010
{
2111
auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0);
2212

2313
// Collect all witness tables of `witnessTableType` in current module.
24-
List<IRWitnessTable*> witnessTables;
25-
for (auto globalInst : sharedContext->module->getGlobalInsts())
26-
{
27-
if (globalInst->op == kIROp_WitnessTable && globalInst->getDataType() == witnessTableType)
28-
{
29-
witnessTables.add(cast<IRWitnessTable>(globalInst));
30-
}
31-
}
14+
List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType(
15+
cast<IRWitnessTableType>(witnessTableType)->getConformanceType());
3216

3317
SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock());
3418
auto block = dispatchFunc->getFirstBlock();
@@ -119,7 +103,7 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext,
119103
builder->setInsertInto(defaultBlock);
120104
}
121105

122-
auto callee = findWitnessTableEntry(witnessTable, requirementKey);
106+
auto callee = sharedContext->findWitnessTableEntry(witnessTable, requirementKey);
123107
SLANG_ASSERT(callee);
124108
auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params);
125109
if (callInst->getDataType()->op == kIROp_VoidType)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
#include "slang-ir-specialize-dispatch.h"
2+
3+
#include "slang-ir-generics-lowering-context.h"
4+
#include "slang-ir-insts.h"
5+
#include "slang-ir.h"
6+
7+
namespace Slang
8+
{
9+
10+
struct AssociatedTypeLookupSpecializationContext
11+
{
12+
SharedGenericsLoweringContext* sharedContext;
13+
14+
IRFunc* createWitnessTableLookupFunc(IRInterfaceType* interfaceType, IRInst* key)
15+
{
16+
IRBuilder builder;
17+
builder.sharedBuilder = &sharedContext->sharedBuilderStorage;
18+
builder.setInsertBefore(interfaceType);
19+
20+
auto inputWitnessTableIDType = builder.getWitnessTableIDType(interfaceType);
21+
auto requirementEntry = sharedContext->findInterfaceRequirementVal(interfaceType, key);
22+
23+
auto resultWitnessTableType = cast<IRWitnessTableType>(requirementEntry);
24+
auto resultWitnessTableIDType =
25+
builder.getWitnessTableIDType((IRType*)resultWitnessTableType->getConformanceType());
26+
27+
auto funcType =
28+
builder.getFuncType(1, (IRType**)&inputWitnessTableIDType, resultWitnessTableIDType);
29+
auto func = builder.createFunc();
30+
func->setFullType(funcType);
31+
32+
if (auto linkage = key->findDecoration<IRLinkageDecoration>())
33+
builder.addNameHintDecoration(func, linkage->getMangledName());
34+
35+
builder.setInsertInto(func);
36+
37+
auto block = builder.emitBlock();
38+
auto witnessTableParam = builder.emitParam(inputWitnessTableIDType);
39+
40+
// Collect all witness tables of `witnessTableType` in current module.
41+
List<IRWitnessTable*> witnessTables =
42+
sharedContext->getWitnessTablesFromInterfaceType(interfaceType);
43+
44+
// Generate case blocks for each possible witness table.
45+
IRBlock* defaultBlock = nullptr;
46+
List<IRInst*> caseBlocks;
47+
for (Index i = 0; i < witnessTables.getCount(); i++)
48+
{
49+
auto witnessTable = witnessTables[i];
50+
auto seqIdDecoration = witnessTable->findDecoration<IRSequentialIDDecoration>();
51+
SLANG_ASSERT(seqIdDecoration);
52+
53+
if (i != witnessTables.getCount() - 1)
54+
{
55+
// Create a case block if we are not the last case.
56+
caseBlocks.add(seqIdDecoration->getSequentialIDOperand());
57+
builder.setInsertInto(func);
58+
auto caseBlock = builder.emitBlock();
59+
caseBlocks.add(caseBlock);
60+
}
61+
else
62+
{
63+
// Generate code for the last possible value in the `default` block.
64+
builder.setInsertInto(func);
65+
defaultBlock = builder.emitBlock();
66+
builder.setInsertInto(defaultBlock);
67+
}
68+
69+
auto resultWitnessTable = sharedContext->findWitnessTableEntry(witnessTable, key);
70+
auto resultWitnessTableIDDecoration =
71+
resultWitnessTable->findDecoration<IRSequentialIDDecoration>();
72+
SLANG_ASSERT(resultWitnessTableIDDecoration);
73+
builder.emitReturn(resultWitnessTableIDDecoration->getSequentialIDOperand());
74+
}
75+
76+
// Emit a switch statement to return the correct witness table ID based on
77+
// the witness table ID passed in.
78+
builder.setInsertInto(func);
79+
auto breakBlock = builder.emitBlock();
80+
builder.setInsertInto(breakBlock);
81+
builder.emitUnreachable();
82+
83+
builder.setInsertInto(block);
84+
builder.emitSwitch(
85+
witnessTableParam,
86+
breakBlock,
87+
defaultBlock,
88+
caseBlocks.getCount(),
89+
caseBlocks.getBuffer());
90+
91+
return func;
92+
}
93+
94+
// Retrieves the conformance type from a WitnessTableType or a WitnessTableIDType.
95+
IRInterfaceType* getInterfaceTypeFromWitnessTableTypes(IRInst* witnessTableType)
96+
{
97+
switch (witnessTableType->op)
98+
{
99+
case kIROp_WitnessTableType:
100+
return cast<IRInterfaceType>(
101+
cast<IRWitnessTableType>(witnessTableType)->getConformanceType());
102+
case kIROp_WitnessTableIDType:
103+
return cast<IRInterfaceType>(
104+
cast<IRWitnessTableIDType>(witnessTableType)->getConformanceType());
105+
default:
106+
return nullptr;
107+
}
108+
}
109+
110+
void processLookupInterfaceMethodInst(IRLookupWitnessMethod* inst)
111+
{
112+
// Ignore lookups for RTTI objects for now, since they are not used anywhere.
113+
if (!as<IRWitnessTableType>(inst->getDataType()))
114+
return;
115+
116+
// Replace all witness table lookups with calls to specialized functions that directly
117+
// returns the sequential ID of the resulting witness table, effectively getting rid
118+
// of actual witness table objects in the target code (they all become IDs).
119+
auto witnessTableType = inst->getWitnessTable()->getDataType();
120+
IRInterfaceType* interfaceType = getInterfaceTypeFromWitnessTableTypes(witnessTableType);
121+
if (!interfaceType)
122+
return;
123+
auto key = inst->getRequirementKey();
124+
IRFunc* func = nullptr;
125+
if (!sharedContext->mapInterfaceRequirementKeyToDispatchMethods.TryGetValue(key, func))
126+
{
127+
func = createWitnessTableLookupFunc(interfaceType, key);
128+
sharedContext->mapInterfaceRequirementKeyToDispatchMethods[key] = func;
129+
}
130+
IRBuilder builder;
131+
builder.sharedBuilder = &sharedContext->sharedBuilderStorage;
132+
builder.setInsertBefore(inst);
133+
auto witnessTableArg = inst->getWitnessTable();
134+
if (witnessTableArg->getDataType()->op == kIROp_WitnessTableType)
135+
{
136+
witnessTableArg = builder.emitGetSequentialIDInst(witnessTableArg);
137+
}
138+
auto callInst = builder.emitCallInst(
139+
builder.getWitnessTableIDType(interfaceType), func, witnessTableArg);
140+
inst->replaceUsesWith(callInst);
141+
inst->removeAndDeallocate();
142+
}
143+
144+
void cleanUpWitnessTableIDType()
145+
{
146+
List<IRInst*> instsToRemove;
147+
for (auto inst : sharedContext->module->getGlobalInsts())
148+
{
149+
if (inst->op == kIROp_WitnessTableIDType)
150+
{
151+
IRBuilder builder;
152+
builder.sharedBuilder = &sharedContext->sharedBuilderStorage;
153+
builder.setInsertBefore(inst);
154+
inst->replaceUsesWith(builder.getUIntType());
155+
instsToRemove.add(inst);
156+
}
157+
}
158+
for (auto inst : instsToRemove)
159+
inst->removeAndDeallocate();
160+
}
161+
162+
void processGetSequentialIDInst(IRGetSequentialID* inst)
163+
{
164+
if (inst->getRTTIOperand()->getDataType()->op == kIROp_WitnessTableIDType)
165+
{
166+
inst->replaceUsesWith(inst->getRTTIOperand());
167+
inst->removeAndDeallocate();
168+
}
169+
}
170+
171+
void processModule()
172+
{
173+
SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage;
174+
sharedBuilder->module = sharedContext->module;
175+
sharedBuilder->session = sharedContext->module->session;
176+
177+
sharedContext->addToWorkList(sharedContext->module->getModuleInst());
178+
179+
while (sharedContext->workList.getCount() != 0)
180+
{
181+
IRInst* inst = sharedContext->workList.getLast();
182+
183+
sharedContext->workList.removeLast();
184+
sharedContext->workListSet.Remove(inst);
185+
186+
if (inst->op == kIROp_lookup_interface_method)
187+
{
188+
processLookupInterfaceMethodInst(cast<IRLookupWitnessMethod>(inst));
189+
}
190+
191+
for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
192+
{
193+
sharedContext->addToWorkList(child);
194+
}
195+
}
196+
197+
// `GetSequentialID(WitnessTableIDOperand)` becomes just `WitnessTableIDOperand`.
198+
sharedContext->addToWorkList(sharedContext->module->getModuleInst());
199+
while (sharedContext->workList.getCount() != 0)
200+
{
201+
IRInst* inst = sharedContext->workList.getLast();
202+
203+
sharedContext->workList.removeLast();
204+
sharedContext->workListSet.Remove(inst);
205+
206+
if (inst->op == kIROp_GetSequentialID)
207+
{
208+
processGetSequentialIDInst(cast<IRGetSequentialID>(inst));
209+
}
210+
211+
for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
212+
{
213+
sharedContext->addToWorkList(child);
214+
}
215+
}
216+
217+
cleanUpWitnessTableIDType();
218+
}
219+
};
220+
221+
void specializeDynamicAssociatedTypeLookup(SharedGenericsLoweringContext* sharedContext)
222+
{
223+
AssociatedTypeLookupSpecializationContext context;
224+
context.sharedContext = sharedContext;
225+
context.processModule();
226+
}
227+
228+
} // namespace Slang

0 commit comments

Comments
 (0)