Skip to content

Commit 68fd448

Browse files
authored
Merge pull request shader-slang#367 from csyonghe/extension2
Support transitive interfaces
2 parents 8abae05 + 513f56b commit 68fd448

File tree

5 files changed

+116
-2
lines changed

5 files changed

+116
-2
lines changed

source/slang/ir.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -3690,12 +3690,12 @@ namespace Slang
36903690
cloneFunctionCommon(context, clonedFunc, originalFunc);
36913691

36923692
// for now, clone all unreferenced witness tables
3693-
/*for (auto gv = context->getOriginalModule()->getFirstGlobalValue();
3693+
for (auto gv = context->getOriginalModule()->getFirstGlobalValue();
36943694
gv; gv = gv->getNextValue())
36953695
{
36963696
if (gv->op == kIROp_witness_table)
36973697
cloneGlobalValue(context, (IRWitnessTable*)gv);
3698-
}*/
3698+
}
36993699

37003700
// We need to attach the layout information for
37013701
// the entry point to this declaration, so that

source/slang/lookup.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,23 @@ void DoLocalLookupImpl(
297297
session,
298298
name, extDeclRef, request, result, inBreadcrumbs);
299299
}
300+
301+
}
302+
// for interface decls, also lookup in the base interfaces
303+
if (request.semantics)
304+
{
305+
if (auto interfaceDeclRef = containerDeclRef.As<InterfaceDecl>())
306+
{
307+
auto baseInterfaces = getMembersOfType<InheritanceDecl>(interfaceDeclRef);
308+
for (auto inheritanceDeclRef : baseInterfaces)
309+
{
310+
auto baseType = inheritanceDeclRef.getDecl()->base.type.As<DeclRefType>();
311+
SLANG_ASSERT(baseType);
312+
int diff = 0;
313+
auto baseInterfaceDeclRef = baseType->declRef.SubstituteImpl(interfaceDeclRef.substitutions, &diff);
314+
DoLocalLookupImpl(session, name, baseInterfaceDeclRef.As<ContainerDecl>(), request, result, inBreadcrumbs);
315+
}
316+
}
300317
}
301318
}
302319

source/slang/lower-to-ir.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -2781,6 +2781,30 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
27812781
return LoweredValInfo();
27822782
}
27832783

2784+
void walkInheritanceHierarchyAndCreateWitnessTableCopies(IRWitnessTable* witnessTable, Type* subType, InheritanceDecl* inheritanceDecl)
2785+
{
2786+
auto baseDeclRef = inheritanceDecl->base.type.As<DeclRefType>();
2787+
if (auto baseInterfaceDeclRef = baseDeclRef->declRef.As<InterfaceDecl>())
2788+
{
2789+
for (auto subInheritanceDeclRef : getMembersOfType<InheritanceDecl>(baseInterfaceDeclRef))
2790+
{
2791+
auto cpyMangledName = getMangledNameForConformanceWitness(subType, subInheritanceDeclRef.getDecl()->getSup().type);
2792+
if (!witnessTablesDictionary.ContainsKey(cpyMangledName))
2793+
{
2794+
auto cpyTable = context->irBuilder->createWitnessTable();
2795+
cpyTable->mangledName = cpyMangledName;
2796+
context->irBuilder->createWitnessTableEntry(witnessTable,
2797+
context->irBuilder->getDeclRefVal(subInheritanceDeclRef), cpyTable);
2798+
cpyTable->entries = witnessTable->entries;
2799+
witnessTablesDictionary.Add(cpyMangledName, cpyTable);
2800+
walkInheritanceHierarchyAndCreateWitnessTableCopies(witnessTable, subType, subInheritanceDeclRef.getDecl());
2801+
}
2802+
}
2803+
}
2804+
}
2805+
2806+
Dictionary<String, IRWitnessTable*> witnessTablesDictionary;
2807+
27842808
LoweredValInfo visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
27852809
{
27862810
// Construct a type for the parent declaration.
@@ -2817,6 +2841,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
28172841
// conformance of the type to its super-type.
28182842
auto witnessTable = context->irBuilder->createWitnessTable();
28192843
witnessTable->mangledName = mangledName;
2844+
2845+
witnessTablesDictionary.Add(mangledName, witnessTable);
28202846

28212847
if (parentDecl->ParentDecl)
28222848
witnessTable->genericDecl = dynamic_cast<GenericDecl*>(parentDecl->ParentDecl);
@@ -2850,6 +2876,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
28502876
}
28512877

28522878
witnessTable->moveToEnd();
2879+
walkInheritanceHierarchyAndCreateWitnessTableCopies(witnessTable, type, inheritanceDecl);
28532880

28542881
// A direct reference to this inheritance relationship (e.g.,
28552882
// as a subtype witness) will take the form of a reference to
+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
2+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
3+
4+
RWStructuredBuffer<float> outputBuffer;
5+
6+
interface IAdd
7+
{
8+
float addf(float u, float v);
9+
}
10+
11+
interface ISub
12+
{
13+
float subf(float u, float v);
14+
}
15+
16+
interface IAddAndSub : IAdd, ISub
17+
{
18+
}
19+
20+
struct Simple : IAddAndSub
21+
{
22+
float addf(float u, float v)
23+
{
24+
return u+v;
25+
}
26+
float subf(float u, float v)
27+
{
28+
return u-v;
29+
}
30+
};
31+
32+
float testAdd<T:IAdd>(T t)
33+
{
34+
return t.addf(1.0, 1.0);
35+
}
36+
37+
interface IAssoc
38+
{
39+
associatedtype AT : IAdd;
40+
}
41+
42+
struct AssocImpl : IAssoc
43+
{
44+
typedef Simple AT;
45+
};
46+
47+
float testAdd2<T:IAssoc>(T assoc)
48+
{
49+
T.AT obj;
50+
return obj.addf(1.0, 1.0);
51+
}
52+
53+
float testSub<T:ISub>(T t, float base)
54+
{
55+
return t.subf(base, 1.0);
56+
}
57+
58+
[numthreads(4, 1, 1)]
59+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
60+
{
61+
AssocImpl s;
62+
float outVal = testAdd2(s);
63+
Simple s1;
64+
outVal += testSub(s1, outVal);
65+
outputBuffer[dispatchThreadID.x] = outVal;
66+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
3F800000
2+
3F800000
3+
3F800000
4+
3F800000

0 commit comments

Comments
 (0)