Skip to content

Commit 3d4eaf3

Browse files
committed
Support transitive interfaces
This commit is a bunch of quick hacks to get transitive interfaces to work. The idea is for each concrete type we create one giant witness table that contains entries for all the transitively reachable interface requirements, and then create one copy of that witness table for each interface it implements. `DoLocalLookupImpl` now also looks up in inherited interface decles when looking up for a symbol in an interface decl. When visiting `InheritanceDecl` in `lower-to-ir`, create copies of the giant witness table for each transitively inherited interface, so that these witness tables can be found later when the IR is specialized. Re-enable the `copy all witness tables` hack in `specializeIRForEntryPoint` to ensure those transitive witness tables are copied over.
1 parent 8abae05 commit 3d4eaf3

File tree

5 files changed

+118
-2
lines changed

5 files changed

+118
-2
lines changed

source/slang/ir.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -3690,12 +3690,12 @@ namespace Slang
36903690
cloneFunctionCommon(context, clonedFunc, originalFunc);
36913691

36923692
// for now, clone all unreferenced witness tables
3693-
/*for (auto gv = context->getOriginalModule()->getFirstGlobalValue();
3693+
for (auto gv = context->getOriginalModule()->getFirstGlobalValue();
36943694
gv; gv = gv->getNextValue())
36953695
{
36963696
if (gv->op == kIROp_witness_table)
36973697
cloneGlobalValue(context, (IRWitnessTable*)gv);
3698-
}*/
3698+
}
36993699

37003700
// We need to attach the layout information for
37013701
// the entry point to this declaration, so that
@@ -4746,7 +4746,9 @@ namespace Slang
47464746
//
47474747
// We will first find or construct a specialized version
47484748
// of the callee funciton/
4749+
auto oldFunc = dumpIRFunc(genericFunc);
47494750
auto specFunc = getSpecializedFunc(sharedContext, genericFunc, specDeclRef);
4751+
auto newFunc = dumpIRFunc(specFunc);
47504752
//
47514753
// Then we will replace the use sites for the `specialize`
47524754
// instruction with uses of the specialized function.

source/slang/lookup.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,23 @@ void DoLocalLookupImpl(
297297
session,
298298
name, extDeclRef, request, result, inBreadcrumbs);
299299
}
300+
301+
}
302+
// for interface decls, also lookup in the base interfaces
303+
if (request.semantics)
304+
{
305+
if (auto interfaceDeclRef = containerDeclRef.As<InterfaceDecl>())
306+
{
307+
auto baseInterfaces = getMembersOfType<InheritanceDecl>(interfaceDeclRef);
308+
for (auto inheritanceDeclRef : baseInterfaces)
309+
{
310+
auto baseType = inheritanceDeclRef.getDecl()->base.type.As<DeclRefType>();
311+
SLANG_ASSERT(baseType);
312+
int diff = 0;
313+
auto baseInterfaceDeclRef = baseType->declRef.SubstituteImpl(interfaceDeclRef.substitutions, &diff);
314+
DoLocalLookupImpl(session, name, baseInterfaceDeclRef.As<ContainerDecl>(), request, result, inBreadcrumbs);
315+
}
316+
}
300317
}
301318
}
302319

source/slang/lower-to-ir.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -2781,6 +2781,30 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
27812781
return LoweredValInfo();
27822782
}
27832783

2784+
void walkInheritanceHierarchyAndCreateWitnessTableCopies(IRWitnessTable* witnessTable, Type* subType, InheritanceDecl* inheritanceDecl)
2785+
{
2786+
auto baseDeclRef = inheritanceDecl->base.type.As<DeclRefType>();
2787+
if (auto baseInterfaceDeclRef = baseDeclRef->declRef.As<InterfaceDecl>())
2788+
{
2789+
for (auto subInheritanceDeclRef : getMembersOfType<InheritanceDecl>(baseInterfaceDeclRef))
2790+
{
2791+
auto cpyMangledName = getMangledNameForConformanceWitness(subType, subInheritanceDeclRef.getDecl()->getSup().type);
2792+
if (!witnessTablesDictionary.ContainsKey(cpyMangledName))
2793+
{
2794+
auto cpyTable = context->irBuilder->createWitnessTable();
2795+
cpyTable->mangledName = cpyMangledName;
2796+
context->irBuilder->createWitnessTableEntry(witnessTable,
2797+
context->irBuilder->getDeclRefVal(subInheritanceDeclRef), cpyTable);
2798+
cpyTable->entries = witnessTable->entries;
2799+
witnessTablesDictionary.Add(cpyMangledName, cpyTable);
2800+
walkInheritanceHierarchyAndCreateWitnessTableCopies(witnessTable, subType, subInheritanceDeclRef.getDecl());
2801+
}
2802+
}
2803+
}
2804+
}
2805+
2806+
Dictionary<String, IRWitnessTable*> witnessTablesDictionary;
2807+
27842808
LoweredValInfo visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
27852809
{
27862810
// Construct a type for the parent declaration.
@@ -2817,6 +2841,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
28172841
// conformance of the type to its super-type.
28182842
auto witnessTable = context->irBuilder->createWitnessTable();
28192843
witnessTable->mangledName = mangledName;
2844+
2845+
witnessTablesDictionary.Add(mangledName, witnessTable);
28202846

28212847
if (parentDecl->ParentDecl)
28222848
witnessTable->genericDecl = dynamic_cast<GenericDecl*>(parentDecl->ParentDecl);
@@ -2850,6 +2876,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
28502876
}
28512877

28522878
witnessTable->moveToEnd();
2879+
walkInheritanceHierarchyAndCreateWitnessTableCopies(witnessTable, type, inheritanceDecl);
28532880

28542881
// A direct reference to this inheritance relationship (e.g.,
28552882
// as a subtype witness) will take the form of a reference to
+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
2+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
3+
4+
RWStructuredBuffer<float> outputBuffer;
5+
6+
interface IAdd
7+
{
8+
float addf(float u, float v);
9+
}
10+
11+
interface ISub
12+
{
13+
float subf(float u, float v);
14+
}
15+
16+
interface IAddAndSub : IAdd, ISub
17+
{
18+
}
19+
20+
struct Simple : IAddAndSub
21+
{
22+
float addf(float u, float v)
23+
{
24+
return u+v;
25+
}
26+
float subf(float u, float v)
27+
{
28+
return u-v;
29+
}
30+
};
31+
32+
float testAdd<T:IAdd>(T t)
33+
{
34+
return t.addf(1.0, 1.0);
35+
}
36+
37+
interface IAssoc
38+
{
39+
associatedtype AT : IAdd;
40+
}
41+
42+
struct AssocImpl : IAssoc
43+
{
44+
typedef Simple AT;
45+
};
46+
47+
float testAdd2<T:IAssoc>(T assoc)
48+
{
49+
T.AT obj;
50+
return obj.addf(1.0, 1.0);
51+
}
52+
53+
float testSub<T:ISub>(T t, float base)
54+
{
55+
return t.subf(base, 1.0);
56+
}
57+
58+
[numthreads(4, 1, 1)]
59+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
60+
{
61+
AssocImpl s;
62+
float outVal = testAdd2(s);
63+
Simple s1;
64+
outVal += testSub(s1, outVal);
65+
outputBuffer[dispatchThreadID.x] = outVal;
66+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
3F800000
2+
3F800000
3+
3F800000
4+
3F800000

0 commit comments

Comments
 (0)