Skip to content

Commit 110d15b

Browse files
authored
Dynamic dispatch for static member functions of associatedtypes. (shader-slang#1404)
1 parent 5fbb9ff commit 110d15b

File tree

4 files changed

+78
-13
lines changed

4 files changed

+78
-13
lines changed

source/slang/slang-emit-c-like.cpp

+5-8
Original file line numberDiff line numberDiff line change
@@ -237,18 +237,15 @@ List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWi
237237
for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
238238
{
239239
auto reqKey = cast<IRStructKey>(interfaceType->getOperand(i));
240-
bool matchingEntryFound = false;
241240
IRWitnessTableEntry* entry = nullptr;
242241
if (witnessTableEntryDictionary.TryGetValue(reqKey, entry))
243242
{
244-
if (entry->requirementKey.get() == reqKey)
245-
{
246-
matchingEntryFound = true;
247-
sortedWitnessTableEntries.add(entry);
248-
break;
249-
}
243+
sortedWitnessTableEntries.add(entry);
244+
}
245+
else
246+
{
247+
SLANG_UNREACHABLE("interface requirement key not found in witness table.");
250248
}
251-
SLANG_ASSERT(matchingEntryFound);
252249
}
253250
return sortedWitnessTableEntries;
254251
}

source/slang/slang-emit-cpp.cpp

+16-5
Original file line numberDiff line numberDiff line change
@@ -1711,6 +1711,15 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
17111711
m_writer->emit("&KernelContext::");
17121712
m_writer->emit(_getWitnessTableWrapperFuncName(funcVal));
17131713
}
1714+
else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal()))
1715+
{
1716+
if (!isFirstEntry)
1717+
m_writer->emit(",\n");
1718+
else
1719+
isFirstEntry = false;
1720+
m_writer->emit("&");
1721+
m_writer->emit(getName(witnessTableVal));
1722+
}
17141723
else
17151724
{
17161725
// TODO: handle other witness table entry types.
@@ -1745,16 +1754,11 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition(
17451754
emitSimpleType(interfaceType);
17461755
m_writer->emit("\n{\n");
17471756
m_writer->indent();
1748-
bool isFirstEntry = true;
17491757
for (Index i = 0; i < sortedWitnessTableEntries.getCount(); i++)
17501758
{
17511759
auto entry = sortedWitnessTableEntries[i];
17521760
if (auto funcVal = as<IRFunc>(entry->satisfyingVal.get()))
17531761
{
1754-
if (!isFirstEntry)
1755-
m_writer->emit(",\n");
1756-
else
1757-
isFirstEntry = false;
17581762
emitType(funcVal->getResultType());
17591763
m_writer->emit(" (KernelContext::*");
17601764
m_writer->emit(getName(entry->requirementKey.get()));
@@ -1777,6 +1781,13 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition(
17771781
}
17781782
m_writer->emit(");\n");
17791783
}
1784+
else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal()))
1785+
{
1786+
emitType(as<IRType>(witnessTableVal->getOperand(0)));
1787+
m_writer->emit("* ");
1788+
m_writer->emit(getName(entry->requirementKey.get()));
1789+
m_writer->emit(";\n");
1790+
}
17801791
else
17811792
{
17821793
// TODO: handle other witness table entry types.
+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code
2+
3+
// Test dynamic dispatch code gen for static member functions
4+
// of associated type.
5+
6+
interface IAssoc
7+
{
8+
int get();
9+
static int getBase();
10+
}
11+
interface IInterface
12+
{
13+
associatedtype Assoc : IAssoc;
14+
int Compute(int inVal);
15+
Assoc getAssoc();
16+
};
17+
18+
int GenericCompute<T:IInterface>(T obj, int inVal)
19+
{
20+
return obj.Compute(inVal) + T.Assoc.getBase();
21+
}
22+
23+
struct Impl : IInterface
24+
{
25+
struct Assoc : IAssoc
26+
{
27+
int val;
28+
int get() { return val; }
29+
static int getBase() { return -1; }
30+
};
31+
int base;
32+
int Compute(int inVal) { return base + inVal * inVal; }
33+
Assoc getAssoc() { Assoc rs; rs.val = 1; return rs; }
34+
};
35+
36+
int test(int inVal)
37+
{
38+
Impl obj;
39+
obj.base = 1;
40+
return GenericCompute<Impl>(obj, inVal);
41+
}
42+
43+
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
44+
RWStructuredBuffer<int> outputBuffer : register(u0);
45+
46+
[numthreads(4, 1, 1)]
47+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
48+
{
49+
uint tid = dispatchThreadID.x;
50+
int inVal = outputBuffer[tid];
51+
int outVal = test(inVal);
52+
outputBuffer[tid] = outVal;
53+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
0
2+
1
3+
4
4+
9

0 commit comments

Comments
 (0)