File tree 4 files changed +78
-13
lines changed
4 files changed +78
-13
lines changed Original file line number Diff line number Diff line change @@ -237,18 +237,15 @@ List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWi
237
237
for (UInt i = 0 ; i < interfaceType->getOperandCount (); i++)
238
238
{
239
239
auto reqKey = cast<IRStructKey>(interfaceType->getOperand (i));
240
- bool matchingEntryFound = false ;
241
240
IRWitnessTableEntry* entry = nullptr ;
242
241
if (witnessTableEntryDictionary.TryGetValue (reqKey, entry))
243
242
{
244
- if (entry->requirementKey .get () == reqKey)
245
- {
246
- matchingEntryFound = true ;
247
- sortedWitnessTableEntries.add (entry);
248
- break ;
249
- }
243
+ sortedWitnessTableEntries.add (entry);
244
+ }
245
+ else
246
+ {
247
+ SLANG_UNREACHABLE (" interface requirement key not found in witness table." );
250
248
}
251
- SLANG_ASSERT (matchingEntryFound);
252
249
}
253
250
return sortedWitnessTableEntries;
254
251
}
Original file line number Diff line number Diff line change @@ -1711,6 +1711,15 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
1711
1711
m_writer->emit (" &KernelContext::" );
1712
1712
m_writer->emit (_getWitnessTableWrapperFuncName (funcVal));
1713
1713
}
1714
+ else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal ()))
1715
+ {
1716
+ if (!isFirstEntry)
1717
+ m_writer->emit (" ,\n " );
1718
+ else
1719
+ isFirstEntry = false ;
1720
+ m_writer->emit (" &" );
1721
+ m_writer->emit (getName (witnessTableVal));
1722
+ }
1714
1723
else
1715
1724
{
1716
1725
// TODO: handle other witness table entry types.
@@ -1745,16 +1754,11 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition(
1745
1754
emitSimpleType (interfaceType);
1746
1755
m_writer->emit (" \n {\n " );
1747
1756
m_writer->indent ();
1748
- bool isFirstEntry = true ;
1749
1757
for (Index i = 0 ; i < sortedWitnessTableEntries.getCount (); i++)
1750
1758
{
1751
1759
auto entry = sortedWitnessTableEntries[i];
1752
1760
if (auto funcVal = as<IRFunc>(entry->satisfyingVal .get ()))
1753
1761
{
1754
- if (!isFirstEntry)
1755
- m_writer->emit (" ,\n " );
1756
- else
1757
- isFirstEntry = false ;
1758
1762
emitType (funcVal->getResultType ());
1759
1763
m_writer->emit (" (KernelContext::*" );
1760
1764
m_writer->emit (getName (entry->requirementKey .get ()));
@@ -1777,6 +1781,13 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition(
1777
1781
}
1778
1782
m_writer->emit (" );\n " );
1779
1783
}
1784
+ else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal ()))
1785
+ {
1786
+ emitType (as<IRType>(witnessTableVal->getOperand (0 )));
1787
+ m_writer->emit (" * " );
1788
+ m_writer->emit (getName (entry->requirementKey .get ()));
1789
+ m_writer->emit (" ;\n " );
1790
+ }
1780
1791
else
1781
1792
{
1782
1793
// TODO: handle other witness table entry types.
Original file line number Diff line number Diff line change
1
+ // TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code
2
+
3
+ // Test dynamic dispatch code gen for static member functions
4
+ // of associated type.
5
+
6
+ interface IAssoc
7
+ {
8
+ int get();
9
+ static int getBase();
10
+ }
11
+ interface IInterface
12
+ {
13
+ associatedtype Assoc : IAssoc;
14
+ int Compute(int inVal);
15
+ Assoc getAssoc();
16
+ };
17
+
18
+ int GenericCompute< T:IInterface> (T obj, int inVal)
19
+ {
20
+ return obj .Compute (inVal) + T .Assoc .getBase ();
21
+ }
22
+
23
+ struct Impl : IInterface
24
+ {
25
+ struct Assoc : IAssoc
26
+ {
27
+ int val;
28
+ int get() { return val; }
29
+ static int getBase() { return - 1 ; }
30
+ };
31
+ int base;
32
+ int Compute(int inVal) { return base + inVal * inVal; }
33
+ Assoc getAssoc() { Assoc rs; rs .val = 1 ; return rs; }
34
+ };
35
+
36
+ int test(int inVal)
37
+ {
38
+ Impl obj;
39
+ obj .base = 1 ;
40
+ return GenericCompute< Impl> (obj, inVal);
41
+ }
42
+
43
+ // TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
44
+ RWStructuredBuffer < int > outputBuffer : register (u0);
45
+
46
+ [numthreads(4 , 1 , 1 )]
47
+ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
48
+ {
49
+ uint tid = dispatchThreadID .x ;
50
+ int inVal = outputBuffer [tid];
51
+ int outVal = test(inVal);
52
+ outputBuffer [tid] = outVal;
53
+ }
Original file line number Diff line number Diff line change
1
+ 0
2
+ 1
3
+ 4
4
+ 9
You can’t perform that action at this time.
0 commit comments