Skip to content

Commit 8abae05

Browse files
authored
Merge pull request shader-slang#365 from csyonghe/extension
Allow extension of a concrete type to implement additional interface
2 parents e86ab5f + cff418b commit 8abae05

8 files changed

+211
-33
lines changed

source/slang/check.cpp

+30-19
Original file line numberDiff line numberDiff line change
@@ -1076,8 +1076,12 @@ namespace Slang
10761076
{
10771077
// The user is asking for us to actually perform the conversion,
10781078
// so we need to generate an appropriate expression here.
1079-
1080-
throw "foo bar baz";
1079+
1080+
// YONGH: I am confused why we are not hitting this case before
1081+
//throw "foo bar baz";
1082+
// YONGH: temporary work around, may need to create the actual
1083+
// invocation expr to the constructor call
1084+
*outToExpr = fromExpr;
10811085
}
10821086

10831087
return true;
@@ -1794,7 +1798,7 @@ namespace Slang
17941798
// `requiredMemberDeclRef` is a required member of
17951799
// the interface.
17961800
RefPtr<Decl> findWitnessForInterfaceRequirement(
1797-
DeclRef<AggTypeDecl> typeDeclRef,
1801+
DeclRef<AggTypeDeclBase> typeDeclRef,
17981802
InheritanceDecl* inheritanceDecl,
17991803
DeclRef<InterfaceDecl> interfaceDeclRef,
18001804
DeclRef<Decl> requiredMemberDeclRef,
@@ -1833,24 +1837,22 @@ namespace Slang
18331837

18341838
// Make sure that by-name lookup is possible.
18351839
buildMemberDictionary(typeDeclRef.getDecl());
1836-
1837-
Decl* firstMemberOfName = nullptr;
1838-
typeDeclRef.getDecl()->memberDictionary.TryGetValue(name, firstMemberOfName);
1839-
1840-
if (!firstMemberOfName)
1840+
auto lookupResult = lookUpLocal(getSession(), this, name, typeDeclRef);
1841+
1842+
if (!lookupResult.isValid())
18411843
{
18421844
getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef);
18431845
return nullptr;
18441846
}
18451847

18461848
// Iterate over the members and look for one that matches
18471849
// the expected signature for the requirement.
1848-
for (auto memberDecl = firstMemberOfName; memberDecl; memberDecl = memberDecl->nextInContainerWithSameName)
1850+
for (auto member : lookupResult)
18491851
{
1850-
if (doesMemberSatisfyRequirement(DeclRef<Decl>(memberDecl, typeDeclRef.substitutions), requiredMemberDeclRef, requirementWitness))
1851-
return memberDecl;
1852+
if (doesMemberSatisfyRequirement(member.declRef, requiredMemberDeclRef, requirementWitness))
1853+
return member.declRef.getDecl();
18521854
}
1853-
1855+
18541856
// No suitable member found, although there were candidates.
18551857
//
18561858
// TODO: Eventually we might want something akin to the current
@@ -1867,7 +1869,7 @@ namespace Slang
18671869
// (via the given `inheritanceDecl`) actually provides
18681870
// members to satisfy all the requirements in the interface.
18691871
bool checkInterfaceConformance(
1870-
DeclRef<AggTypeDecl> typeDeclRef,
1872+
DeclRef<AggTypeDeclBase> typeDeclRef,
18711873
InheritanceDecl* inheritanceDecl,
18721874
DeclRef<InterfaceDecl> interfaceDeclRef)
18731875
{
@@ -1925,7 +1927,7 @@ namespace Slang
19251927
}
19261928

19271929
bool checkConformanceToType(
1928-
DeclRef<AggTypeDecl> typeDeclRef,
1930+
DeclRef<AggTypeDeclBase> typeDeclRef,
19291931
InheritanceDecl* inheritanceDecl,
19301932
Type* baseType)
19311933
{
@@ -1953,7 +1955,7 @@ namespace Slang
19531955
// `inheritanceDecl` actually does what it needs to
19541956
// for that inheritance to be valid.
19551957
bool checkConformance(
1956-
DeclRef<AggTypeDecl> typeDecl,
1958+
DeclRef<AggTypeDeclBase> typeDecl,
19571959
InheritanceDecl* inheritanceDecl)
19581960
{
19591961
// Look at the type being inherited from, and validate
@@ -1963,10 +1965,10 @@ namespace Slang
19631965
}
19641966

19651967
bool checkConformance(
1966-
AggTypeDecl* typeDecl,
1968+
AggTypeDeclBase* typeDecl,
19671969
InheritanceDecl* inheritanceDecl)
19681970
{
1969-
return checkConformance(DeclRef<AggTypeDecl>(typeDecl, SubstitutionSet()), inheritanceDecl);
1971+
return checkConformance(DeclRef<AggTypeDeclBase>(typeDecl, SubstitutionSet()), inheritanceDecl);
19701972
}
19711973

19721974
void visitAggTypeDecl(AggTypeDecl* decl)
@@ -3479,10 +3481,11 @@ namespace Slang
34793481

34803482
// TODO: need to check that the target type names a declaration...
34813483

3484+
DeclRef<AggTypeDecl> aggTypeDeclRef;
34823485
if (auto targetDeclRefType = decl->targetType->As<DeclRefType>())
34833486
{
34843487
// Attach our extension to that type as a candidate...
3485-
if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDecl>())
3488+
if (aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDecl>())
34863489
{
34873490
auto aggTypeDecl = aggTypeDeclRef.getDecl();
34883491
decl->nextCandidateExtension = aggTypeDecl->candidateExtensions;
@@ -3516,6 +3519,14 @@ namespace Slang
35163519
EnsureDecl(m);
35173520
}
35183521

3522+
if (aggTypeDeclRef)
3523+
{
3524+
for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>())
3525+
{
3526+
checkConformance(aggTypeDeclRef.getDecl(), inheritanceDecl);
3527+
}
3528+
}
3529+
35193530
decl->SetCheckState(DeclCheckState::Checked);
35203531
}
35213532

@@ -3802,7 +3813,7 @@ namespace Slang
38023813

38033814
if( auto aggTypeDeclRef = declRef.As<AggTypeDecl>() )
38043815
{
3805-
for( auto inheritanceDeclRef : getMembersOfType<InheritanceDecl>(aggTypeDeclRef))
3816+
for( auto inheritanceDeclRef : getMembersOfTypeWithExt<InheritanceDecl>(aggTypeDeclRef))
38063817
{
38073818
EnsureDecl(inheritanceDeclRef.getDecl());
38083819

source/slang/lookup.cpp

+13-1
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,19 @@ void DoLookupImpl(
345345

346346
// Now perform "local" lookup in the context of the container,
347347
// as if we were looking up a member directly.
348-
//
348+
349+
// if we are currently in an extension decl, perform local lookup
350+
// in the target decl we are extending
351+
if (auto extDeclRef = containerDeclRef.As<ExtensionDecl>())
352+
{
353+
if (auto targetDeclRef = extDeclRef.getDecl()->targetType->AsDeclRefType())
354+
{
355+
if (auto aggDeclRef = targetDeclRef->declRef.As<AggTypeDecl>())
356+
{
357+
containerDeclRef = extDeclRef.Substitute(aggDeclRef);
358+
}
359+
}
360+
}
349361
DoLocalLookupImpl(
350362
session,
351363
name, containerDeclRef, request, result, breadcrumbs);

source/slang/lower-to-ir.cpp

+18-7
Original file line numberDiff line numberDiff line change
@@ -2788,19 +2788,30 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
27882788
// TODO: if this inheritance declaration is under an extension,
27892789
// then we should construct the type that is being extended,
27902790
// and not a reference to the extension itself.
2791-
auto parentDecl = inheritanceDecl->ParentDecl;
2792-
RefPtr<Type> type = DeclRefType::Create(
2793-
context->getSession(),
2794-
makeDeclRef(parentDecl));
27952791

2792+
auto parentDecl = inheritanceDecl->ParentDecl;
2793+
RefPtr<Type> type;
2794+
if (auto extParentDecl = dynamic_cast<ExtensionDecl*>(parentDecl))
2795+
{
2796+
type = extParentDecl->targetType.type;
2797+
if (auto declRefType = type.As<DeclRefType>())
2798+
{
2799+
if (auto aggTypeDecl = declRefType->declRef.As<AggTypeDecl>())
2800+
parentDecl = aggTypeDecl.getDecl();
2801+
}
2802+
}
2803+
else
2804+
{
2805+
type = DeclRefType::Create(
2806+
context->getSession(),
2807+
makeDeclRef(parentDecl));
2808+
}
27962809
// What is the super-type that we have declared we inherit from?
27972810
RefPtr<Type> superType = inheritanceDecl->base.type;
27982811

27992812
// Construct the mangled name for the witness table, which depends
28002813
// on the type that is conforming, and the type that it conforms to.
2801-
String mangledName = getMangledNameForConformanceWitness(
2802-
makeDeclRef(parentDecl),
2803-
superType);
2814+
String mangledName = getMangledNameForConformanceWitness(type, superType);
28042815

28052816
// Build an IR level witness table, which will represent the
28062817
// conformance of the type to its super-type.

source/slang/syntax.h

+46-6
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,30 @@ namespace Slang
978978
{
979979
return items.Count() > 1 ? items[0].declRef.GetName() : item.declRef.GetName();
980980
}
981+
LookupResultItem* begin()
982+
{
983+
if (isValid())
984+
{
985+
if (isOverloaded())
986+
return items.begin();
987+
else
988+
return &item;
989+
}
990+
else
991+
return nullptr;
992+
}
993+
LookupResultItem* end()
994+
{
995+
if (isValid())
996+
{
997+
if (isOverloaded())
998+
return items.end();
999+
else
1000+
return &item + 1;
1001+
}
1002+
else
1003+
return nullptr;
1004+
}
9811005
};
9821006

9831007
struct SemanticsVisitor;
@@ -1085,6 +1109,27 @@ namespace Slang
10851109
return FilteredMemberRefList<T>(declRef.getDecl()->Members, declRef.substitutions);
10861110
}
10871111

1112+
inline ExtensionDecl* GetCandidateExtensions(DeclRef<AggTypeDecl> const& declRef)
1113+
{
1114+
return declRef.getDecl()->candidateExtensions;
1115+
}
1116+
1117+
template<typename T>
1118+
inline FilteredMemberRefList<T> getMembersOfTypeWithExt(DeclRef<ContainerDecl> const& declRef)
1119+
{
1120+
auto rs = getMembersOfType<T>(declRef);
1121+
if (auto aggDeclRef = declRef.As<AggTypeDecl>())
1122+
{
1123+
for (auto ext = GetCandidateExtensions(aggDeclRef); ext; ext = ext->nextCandidateExtension)
1124+
{
1125+
auto extMembers = getMembersOfType<T>(DeclRef<ContainerDecl>(ext, declRef.substitutions));
1126+
const_cast<List<RefPtr<Decl>>&>(rs.decls).AddRange(extMembers.decls);
1127+
}
1128+
}
1129+
return rs;
1130+
}
1131+
1132+
10881133
inline RefPtr<Type> GetType(DeclRef<VarDeclBase> const& declRef)
10891134
{
10901135
return declRef.Substitute(declRef.getDecl()->type.Ptr());
@@ -1099,12 +1144,7 @@ namespace Slang
10991144
{
11001145
return declRef.Substitute(declRef.getDecl()->targetType.Ptr());
11011146
}
1102-
1103-
inline ExtensionDecl* GetCandidateExtensions(DeclRef<AggTypeDecl> const& declRef)
1104-
{
1105-
return declRef.getDecl()->candidateExtensions;
1106-
}
1107-
1147+
11081148
inline FilteredMemberRefList<StructField> GetFields(DeclRef<StructDecl> const& declRef)
11091149
{
11101150
return getMembersOfType<StructField>(declRef);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
2+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
3+
4+
RWStructuredBuffer<float> outputBuffer;
5+
6+
interface IAdd
7+
{
8+
float addf(float u, float v);
9+
}
10+
11+
interface ISub
12+
{
13+
float subf(float u, float v);
14+
}
15+
16+
interface IAddAndSub
17+
{
18+
float addf(float u, float v);
19+
float subf(float u, float v);
20+
}
21+
22+
struct Simple : IAdd
23+
{
24+
float base;
25+
float addf(float u, float v)
26+
{
27+
return u+v;
28+
}
29+
};
30+
31+
__extension Simple : ISub, IAddAndSub
32+
{
33+
float subf(float u, float v)
34+
{
35+
return base+u-v;
36+
}
37+
};
38+
39+
float testAddSub<T:IAddAndSub>(T t)
40+
{
41+
return t.subf(t.addf(1.0, 1.0), 1.0);
42+
}
43+
44+
[numthreads(4, 1, 1)]
45+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
46+
{
47+
Simple s;
48+
s.base = 0.0;
49+
float outVal = testAddSub(s);
50+
outputBuffer[dispatchThreadID.x] = outVal;
51+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
3F800000
2+
3F800000
3+
3F800000
4+
3F800000

tests/compute/multi-interface.slang

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
2+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
3+
4+
RWStructuredBuffer<float> outputBuffer;
5+
6+
interface IAdd
7+
{
8+
float addf(float u, float v);
9+
}
10+
11+
interface ISub
12+
{
13+
float subf(float u, float v);
14+
}
15+
16+
interface IAddAndSub
17+
{
18+
float addf(float u, float v);
19+
float subf(float u, float v);
20+
}
21+
22+
struct Simple : IAdd, ISub, IAddAndSub
23+
{
24+
float addf(float u, float v)
25+
{
26+
return u+v;
27+
}
28+
float subf(float u, float v)
29+
{
30+
return u-v;
31+
}
32+
};
33+
34+
float testAddSub<T:IAddAndSub>(T t)
35+
{
36+
return t.subf(t.addf(1.0, 1.0), 1.0);
37+
}
38+
39+
[numthreads(4, 1, 1)]
40+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
41+
{
42+
Simple s;
43+
float outVal = testAddSub(s);
44+
outputBuffer[dispatchThreadID.x] = outVal;
45+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
3F800000
2+
3F800000
3+
3F800000
4+
3F800000

0 commit comments

Comments
 (0)