Skip to content

Commit d33e6b7

Browse files
committed
allow extension of a concrete type to implement additional interface
Also support the scenario that the extension declares conformance to interface I, and a method M in I is already supported by the base implementation.
1 parent d4dab2c commit d33e6b7

7 files changed

+189
-29
lines changed

source/slang/check.cpp

+23-16
Original file line numberDiff line numberDiff line change
@@ -1794,7 +1794,7 @@ namespace Slang
17941794
// `requiredMemberDeclRef` is a required member of
17951795
// the interface.
17961796
RefPtr<Decl> findWitnessForInterfaceRequirement(
1797-
DeclRef<AggTypeDecl> typeDeclRef,
1797+
DeclRef<AggTypeDeclBase> typeDeclRef,
17981798
InheritanceDecl* inheritanceDecl,
17991799
DeclRef<InterfaceDecl> interfaceDeclRef,
18001800
DeclRef<Decl> requiredMemberDeclRef,
@@ -1833,22 +1833,20 @@ namespace Slang
18331833

18341834
// Make sure that by-name lookup is possible.
18351835
buildMemberDictionary(typeDeclRef.getDecl());
1836-
1837-
Decl* firstMemberOfName = nullptr;
1838-
typeDeclRef.getDecl()->memberDictionary.TryGetValue(name, firstMemberOfName);
1839-
1840-
if (!firstMemberOfName)
1836+
auto lookupResult = lookUpLocal(getSession(), this, name, typeDeclRef);
1837+
1838+
if (!lookupResult.isValid())
18411839
{
18421840
getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef);
18431841
return nullptr;
18441842
}
18451843

18461844
// Iterate over the members and look for one that matches
18471845
// the expected signature for the requirement.
1848-
for (auto memberDecl = firstMemberOfName; memberDecl; memberDecl = memberDecl->nextInContainerWithSameName)
1846+
for (auto member : lookupResult)
18491847
{
1850-
if (doesMemberSatisfyRequirement(DeclRef<Decl>(memberDecl, typeDeclRef.substitutions), requiredMemberDeclRef, requirementWitness))
1851-
return memberDecl;
1848+
if (doesMemberSatisfyRequirement(member.declRef, requiredMemberDeclRef, requirementWitness))
1849+
return member.declRef.getDecl();
18521850
}
18531851

18541852
// No suitable member found, although there were candidates.
@@ -1867,7 +1865,7 @@ namespace Slang
18671865
// (via the given `inheritanceDecl`) actually provides
18681866
// members to satisfy all the requirements in the interface.
18691867
bool checkInterfaceConformance(
1870-
DeclRef<AggTypeDecl> typeDeclRef,
1868+
DeclRef<AggTypeDeclBase> typeDeclRef,
18711869
InheritanceDecl* inheritanceDecl,
18721870
DeclRef<InterfaceDecl> interfaceDeclRef)
18731871
{
@@ -1925,7 +1923,7 @@ namespace Slang
19251923
}
19261924

19271925
bool checkConformanceToType(
1928-
DeclRef<AggTypeDecl> typeDeclRef,
1926+
DeclRef<AggTypeDeclBase> typeDeclRef,
19291927
InheritanceDecl* inheritanceDecl,
19301928
Type* baseType)
19311929
{
@@ -1953,7 +1951,7 @@ namespace Slang
19531951
// `inheritanceDecl` actually does what it needs to
19541952
// for that inheritance to be valid.
19551953
bool checkConformance(
1956-
DeclRef<AggTypeDecl> typeDecl,
1954+
DeclRef<AggTypeDeclBase> typeDecl,
19571955
InheritanceDecl* inheritanceDecl)
19581956
{
19591957
// Look at the type being inherited from, and validate
@@ -1963,10 +1961,10 @@ namespace Slang
19631961
}
19641962

19651963
bool checkConformance(
1966-
AggTypeDecl* typeDecl,
1964+
AggTypeDeclBase* typeDecl,
19671965
InheritanceDecl* inheritanceDecl)
19681966
{
1969-
return checkConformance(DeclRef<AggTypeDecl>(typeDecl, SubstitutionSet()), inheritanceDecl);
1967+
return checkConformance(DeclRef<AggTypeDeclBase>(typeDecl, SubstitutionSet()), inheritanceDecl);
19701968
}
19711969

19721970
void visitAggTypeDecl(AggTypeDecl* decl)
@@ -3479,10 +3477,11 @@ namespace Slang
34793477

34803478
// TODO: need to check that the target type names a declaration...
34813479

3480+
DeclRef<AggTypeDecl> aggTypeDeclRef;
34823481
if (auto targetDeclRefType = decl->targetType->As<DeclRefType>())
34833482
{
34843483
// Attach our extension to that type as a candidate...
3485-
if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDecl>())
3484+
if (aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDecl>())
34863485
{
34873486
auto aggTypeDecl = aggTypeDeclRef.getDecl();
34883487
decl->nextCandidateExtension = aggTypeDecl->candidateExtensions;
@@ -3516,6 +3515,14 @@ namespace Slang
35163515
EnsureDecl(m);
35173516
}
35183517

3518+
if (aggTypeDeclRef)
3519+
{
3520+
for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>())
3521+
{
3522+
checkConformance(aggTypeDeclRef.getDecl(), inheritanceDecl);
3523+
}
3524+
}
3525+
35193526
decl->SetCheckState(DeclCheckState::Checked);
35203527
}
35213528

@@ -3802,7 +3809,7 @@ namespace Slang
38023809

38033810
if( auto aggTypeDeclRef = declRef.As<AggTypeDecl>() )
38043811
{
3805-
for( auto inheritanceDeclRef : getMembersOfType<InheritanceDecl>(aggTypeDeclRef))
3812+
for( auto inheritanceDeclRef : getMembersOfTypeWithExt<InheritanceDecl>(aggTypeDeclRef))
38063813
{
38073814
EnsureDecl(inheritanceDeclRef.getDecl());
38083815

source/slang/lower-to-ir.cpp

+18-7
Original file line numberDiff line numberDiff line change
@@ -2788,19 +2788,30 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
27882788
// TODO: if this inheritance declaration is under an extension,
27892789
// then we should construct the type that is being extended,
27902790
// and not a reference to the extension itself.
2791-
auto parentDecl = inheritanceDecl->ParentDecl;
2792-
RefPtr<Type> type = DeclRefType::Create(
2793-
context->getSession(),
2794-
makeDeclRef(parentDecl));
27952791

2792+
auto parentDecl = inheritanceDecl->ParentDecl;
2793+
RefPtr<Type> type;
2794+
if (auto extParentDecl = dynamic_cast<ExtensionDecl*>(parentDecl))
2795+
{
2796+
type = extParentDecl->targetType.type;
2797+
if (auto declRefType = type.As<DeclRefType>())
2798+
{
2799+
if (auto aggTypeDecl = declRefType->declRef.As<AggTypeDecl>())
2800+
parentDecl = aggTypeDecl.getDecl();
2801+
}
2802+
}
2803+
else
2804+
{
2805+
type = DeclRefType::Create(
2806+
context->getSession(),
2807+
makeDeclRef(parentDecl));
2808+
}
27962809
// What is the super-type that we have declared we inherit from?
27972810
RefPtr<Type> superType = inheritanceDecl->base.type;
27982811

27992812
// Construct the mangled name for the witness table, which depends
28002813
// on the type that is conforming, and the type that it conforms to.
2801-
String mangledName = getMangledNameForConformanceWitness(
2802-
makeDeclRef(parentDecl),
2803-
superType);
2814+
String mangledName = getMangledNameForConformanceWitness(type, superType);
28042815

28052816
// Build an IR level witness table, which will represent the
28062817
// conformance of the type to its super-type.

source/slang/syntax.h

+46-6
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,30 @@ namespace Slang
978978
{
979979
return items.Count() > 1 ? items[0].declRef.GetName() : item.declRef.GetName();
980980
}
981+
LookupResultItem* begin()
982+
{
983+
if (isValid())
984+
{
985+
if (isOverloaded())
986+
return items.begin();
987+
else
988+
return &item;
989+
}
990+
else
991+
return nullptr;
992+
}
993+
LookupResultItem* end()
994+
{
995+
if (isValid())
996+
{
997+
if (isOverloaded())
998+
return items.end();
999+
else
1000+
return &item + 1;
1001+
}
1002+
else
1003+
return nullptr;
1004+
}
9811005
};
9821006

9831007
struct SemanticsVisitor;
@@ -1085,6 +1109,27 @@ namespace Slang
10851109
return FilteredMemberRefList<T>(declRef.getDecl()->Members, declRef.substitutions);
10861110
}
10871111

1112+
inline ExtensionDecl* GetCandidateExtensions(DeclRef<AggTypeDecl> const& declRef)
1113+
{
1114+
return declRef.getDecl()->candidateExtensions;
1115+
}
1116+
1117+
template<typename T>
1118+
inline FilteredMemberRefList<T> getMembersOfTypeWithExt(DeclRef<ContainerDecl> const& declRef)
1119+
{
1120+
auto rs = getMembersOfType<T>(declRef);
1121+
if (auto aggDeclRef = declRef.As<AggTypeDecl>())
1122+
{
1123+
for (auto ext = GetCandidateExtensions(aggDeclRef); ext; ext = ext->nextCandidateExtension)
1124+
{
1125+
auto extMembers = getMembersOfType<T>(DeclRef<ContainerDecl>(ext, declRef.substitutions));
1126+
const_cast<List<RefPtr<Decl>>&>(rs.decls).AddRange(extMembers.decls);
1127+
}
1128+
}
1129+
return rs;
1130+
}
1131+
1132+
10881133
inline RefPtr<Type> GetType(DeclRef<VarDeclBase> const& declRef)
10891134
{
10901135
return declRef.Substitute(declRef.getDecl()->type.Ptr());
@@ -1099,12 +1144,7 @@ namespace Slang
10991144
{
11001145
return declRef.Substitute(declRef.getDecl()->targetType.Ptr());
11011146
}
1102-
1103-
inline ExtensionDecl* GetCandidateExtensions(DeclRef<AggTypeDecl> const& declRef)
1104-
{
1105-
return declRef.getDecl()->candidateExtensions;
1106-
}
1107-
1147+
11081148
inline FilteredMemberRefList<StructField> GetFields(DeclRef<StructDecl> const& declRef)
11091149
{
11101150
return getMembersOfType<StructField>(declRef);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
2+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
3+
4+
RWStructuredBuffer<float> outputBuffer;
5+
6+
interface IAdd
7+
{
8+
float addf(float u, float v);
9+
}
10+
11+
interface ISub
12+
{
13+
float subf(float u, float v);
14+
}
15+
16+
interface IAddAndSub
17+
{
18+
float addf(float u, float v);
19+
float subf(float u, float v);
20+
}
21+
22+
struct Simple : IAdd
23+
{
24+
float addf(float u, float v)
25+
{
26+
return u+v;
27+
}
28+
};
29+
30+
__extension Simple : ISub, IAddAndSub
31+
{
32+
float subf(float u, float v)
33+
{
34+
return u-v;
35+
}
36+
};
37+
38+
float testAddSub<T:IAddAndSub>(T t)
39+
{
40+
return t.subf(t.addf(1.0, 1.0), 1.0);
41+
}
42+
43+
[numthreads(4, 1, 1)]
44+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
45+
{
46+
Simple s;
47+
float outVal = testAddSub(s);
48+
outputBuffer[dispatchThreadID.x] = outVal;
49+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
3F800000
2+
3F800000
3+
3F800000
4+
3F800000

tests/compute/multi-interface.slang

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
2+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
3+
4+
RWStructuredBuffer<float> outputBuffer;
5+
6+
interface IAdd
7+
{
8+
float addf(float u, float v);
9+
}
10+
11+
interface ISub
12+
{
13+
float subf(float u, float v);
14+
}
15+
16+
interface IAddAndSub
17+
{
18+
float addf(float u, float v);
19+
float subf(float u, float v);
20+
}
21+
22+
struct Simple : IAdd, ISub, IAddAndSub
23+
{
24+
float addf(float u, float v)
25+
{
26+
return u+v;
27+
}
28+
float subf(float u, float v)
29+
{
30+
return u-v;
31+
}
32+
};
33+
34+
float testAddSub<T:IAddAndSub>(T t)
35+
{
36+
return t.subf(t.addf(1.0, 1.0), 1.0);
37+
}
38+
39+
[numthreads(4, 1, 1)]
40+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
41+
{
42+
Simple s;
43+
float outVal = testAddSub(s);
44+
outputBuffer[dispatchThreadID.x] = outVal;
45+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
3F800000
2+
3F800000
3+
3F800000
4+
3F800000

0 commit comments

Comments
 (0)