Skip to content

Commit ddd2905

Browse files
authored
Fix extension override behavior, and disallow extension on interface types. (#4977)
* Add a test to ensure extension does not override existing conformance. * Fix doc. * Update documentation. * Fix doc. * Add diagnostic test.
1 parent 56a3c02 commit ddd2905

8 files changed

+195
-68
lines changed

docs/user-guide/06-interfaces-generics.md

+28-23
Original file line numberDiff line numberDiff line change
@@ -742,17 +742,17 @@ See [if-let syntax](convenience-features.html#if_let-syntax) for more details.
742742
Extensions to Interfaces
743743
-----------------------------
744744

745-
In addition to extending ordinary types, you can define extensions on interfaces as well:
745+
In addition to extending ordinary types, you can define extensions on all types that conforms to some interface:
746+
746747
```csharp
747748
// An example interface.
748749
interface IFoo
749750
{
750751
int foo();
751752
}
752753

753-
// Extending `IFoo` with a new method requirement
754-
// with a default implementation.
755-
extension IFoo
754+
// Extend any type `T` that conforms to `IFoo` with a `bar` method.
755+
extension<T:IFoo> T
756756
{
757757
int bar() { return 0; }
758758
}
@@ -765,42 +765,47 @@ int use(IFoo foo)
765765
}
766766
```
767767

768-
Although the syntax of above listing suggests that we are extending an interface with additional requirements, this interpretation does not make logical sense in many ways. Consider a type `MyType` that exists before the extension is defined:
769-
```csharp
770-
struct MyType : IFoo
771-
{
772-
int foo() { return 0; }
773-
}
774-
```
768+
Note that `interface` types cannot be extended, because extending an `interface` with new requirements would make all existing types that conforms
769+
to the interface no longer valid.
775770

776-
If we extend the `IFoo` with new requirements, the existing `MyType` definition would become invalid since `MyType` no longer provides implementations to all interface requirements. Instead, what an `extension` on an interface `IFoo` means is that for all types that conforms to the `IFoo` interface and does not have a `bar` method defined, add a `bar` method defined in this extension to that type so that all `IFoo` typed values have a `bar` method defined. If a type already defines a matching `bar` method, then the existing method will always override the default method provided in the extension:
771+
In the presence of extensions, it is possible for a type to have multiple ways to
772+
conform to an interface. In this case, Slang will always prefer the more specific conformance
773+
over the generic one. For example, the following code illustrates this behavior:
777774

778775
```csharp
776+
interface IBase{}
779777
interface IFoo
780778
{
781779
int foo();
782780
}
783-
struct MyFoo1 : IFoo
781+
782+
// MyObject directly implements IBase:
783+
struct MyObject : IBase, IFoo
784784
{
785785
int foo() { return 0; }
786786
}
787-
extension IFoo
787+
788+
// Generic extension that applies to all types that conforms to `IBase`:
789+
extension<T:IBase> T : IFoo
788790
{
789-
int bar() { return 0; }
791+
int foo() { return 1; }
790792
}
791-
struct MyFoo2 : IFoo
793+
794+
int helper<T:IFoo>(T obj)
792795
{
793-
int foo() { return 0; }
794-
int bar() { return 1; }
796+
return obj.foo();
795797
}
796-
void test()
798+
799+
int test()
797800
{
798-
MyFoo1 f1;
799-
MyFoo2 f2;
800-
int a = f1.bar(); // a == 0, calling the method in the extension.
801-
int b = f2.bar(); // b == 1, calling the existing method in `MyFoo2`.
801+
MyObject obj;
802+
803+
// Returns 0, the conformance defined directly by the type
804+
// is preferred.
805+
return helper(obj);
802806
}
803807
```
808+
804809
This feature is similar to extension traits in Rust.
805810

806811

source/slang/slang-check-constraint.cpp

+25-15
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,22 @@ namespace Slang
8383
Type* interfaceType)
8484
{
8585
// The most basic test here should be: does the type declare conformance to the trait.
86-
if (isSubtype(type, interfaceType, constraints->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None))
87-
return type;
88-
89-
// If additional subtype witnesses are provided for `type` in `constraints`,
90-
// try to use them to see if the interface is satisfied.
86+
9187
if (constraints->subTypeForAdditionalWitnesses == type)
9288
{
89+
// If additional subtype witnesses are provided for `type` in `constraints`,
90+
// try to use them to see if the interface is satisfied.
9391
if (constraints->additionalSubtypeWitnesses->containsKey(interfaceType))
9492
return type;
9593
}
94+
else
95+
{
96+
if (isSubtype(
97+
type,
98+
interfaceType,
99+
constraints->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None))
100+
return type;
101+
}
96102

97103
// Just because `type` doesn't conform to the given `interfaceDeclRef`, that
98104
// doesn't necessarily indicate a failure. It is possible that we have a call
@@ -653,18 +659,22 @@ namespace Slang
653659
}
654660

655661
// Search for a witness that shows the constraint is satisfied.
656-
auto subTypeWitness = isSubtype(
657-
sub,
658-
sup,
659-
system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None);
660-
if (!subTypeWitness)
662+
SubtypeWitness* subTypeWitness = nullptr;
663+
if (sub == system->subTypeForAdditionalWitnesses)
661664
{
662-
if (sub == system->subTypeForAdditionalWitnesses)
663-
{
664-
// If no witness was found, try to find the witness from additional witness.
665-
system->additionalSubtypeWitnesses->tryGetValue(sup, subTypeWitness);
666-
}
665+
// If we are trying to find the subtype info for a type whose inheritance info is
666+
// being calculated, use what we have already known about the type.
667+
system->additionalSubtypeWitnesses->tryGetValue(sup, subTypeWitness);
667668
}
669+
else
670+
{
671+
// The general case is to initiate a subtype query.
672+
subTypeWitness = isSubtype(
673+
sub,
674+
sup,
675+
system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None);
676+
}
677+
668678
if(subTypeWitness)
669679
{
670680
// We found a witness, so it will become an (implicit) argument.

source/slang/slang-check-decl.cpp

+9-3
Original file line numberDiff line numberDiff line change
@@ -8246,7 +8246,12 @@ namespace Slang
82468246
if (auto targetDeclRefType = as<DeclRefType>(decl->targetType))
82478247
{
82488248
// Attach our extension to that type as a candidate...
8249-
if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>())
8249+
if (targetDeclRefType->getDeclRef().as<InterfaceDecl>())
8250+
{
8251+
getSink()->diagnose(decl->targetType.exp, Diagnostics::invalidExtensionOnInterface, decl->targetType);
8252+
return;
8253+
}
8254+
else if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>())
82508255
{
82518256
auto aggTypeDecl = aggTypeDeclRef.getDecl();
82528257

@@ -8303,6 +8308,7 @@ namespace Slang
83038308
// to extend.
83048309
//
83058310
decl->targetType = CheckProperType(decl->targetType);
8311+
83068312
_validateExtensionDeclTargetType(decl);
83078313

83088314
_validateExtensionDeclMembers(decl);
@@ -9188,13 +9194,13 @@ namespace Slang
91889194
// look up extensions based on what would be visible to that
91899195
// module.
91909196
//
9191-
// We need to consider the extensions declared in the module itself,
9197+
// Extensions declared in the module itself should have already
9198+
// been registered when we check them, but we still need to bring
91929199
// along with everything the module imported.
91939200
//
91949201
// Note: there is an implicit assumption here that the `importedModules`
91959202
// member on the `SharedSemanticsContext` is accurate in this case.
91969203
//
9197-
_addCandidateExtensionsFromModule(m_module->getModuleDecl());
91989204
for( auto moduleDecl : this->importedModulesList )
91999205
{
92009206
_addCandidateExtensionsFromModule(moduleDecl);

source/slang/slang-check-inheritance.cpp

+27-25
Original file line numberDiff line numberDiff line change
@@ -422,40 +422,42 @@ namespace Slang
422422
{
423423
considerExtension(directAggTypeDeclRef, nullptr);
424424
}
425-
HashSet<Type*> supTypesConsideredForExtensionApplication;
426-
Dictionary<Type*, SubtypeWitness*> additionalSubtypeWitnesses;
427-
for (;;)
425+
if (!declRef.as<ExtensionDecl>())
428426
{
429-
// After we flatten the list of bases, we may discover additional opportunities
430-
// to apply extensions.
431-
List<DeclRef<AggTypeDecl>> supTypeWorkList;
432-
for (auto curFacet : directBaseFacets)
427+
HashSet<Type*> supTypesConsideredForExtensionApplication;
428+
Dictionary<Type*, SubtypeWitness*> additionalSubtypeWitnesses;
429+
for (;;)
433430
{
434-
if (!curFacet->subtypeWitness)
435-
continue;
436-
auto inheritanceInfo = getInheritanceInfo(curFacet->subtypeWitness->getSup(), circularityInfo);
437-
for (auto facet : inheritanceInfo.facets)
431+
// After we flatten the list of bases, we may discover additional opportunities
432+
// to apply extensions.
433+
List<DeclRef<AggTypeDecl>> supTypeWorkList;
434+
auto base = directBases.begin();
435+
for (auto baseFacet = directBaseFacets.getHead(); baseFacet.getImpl(); baseFacet = baseFacet->next)
438436
{
439-
if (auto interfaceDeclRef = facet->origin.declRef.as<InterfaceDecl>())
437+
for (auto facet : (*base)->facets)
440438
{
441-
SubtypeWitness* transitiveWitness = curFacet->subtypeWitness;
442-
transitiveWitness = astBuilder->getTransitiveSubtypeWitness(curFacet->subtypeWitness, facet->subtypeWitness);
443-
additionalSubtypeWitnesses.addIfNotExists(facet->origin.type, transitiveWitness);
444-
if (supTypesConsideredForExtensionApplication.add(facet->origin.type))
439+
if (auto interfaceDeclRef = facet->origin.declRef.as<InterfaceDecl>())
445440
{
446-
supTypeWorkList.add(interfaceDeclRef);
441+
SubtypeWitness* transitiveWitness = baseFacet->subtypeWitness;
442+
transitiveWitness = astBuilder->getTransitiveSubtypeWitness(baseFacet->subtypeWitness, facet->subtypeWitness);
443+
additionalSubtypeWitnesses.addIfNotExists(facet->origin.type, transitiveWitness);
444+
if (supTypesConsideredForExtensionApplication.add(facet->origin.type))
445+
{
446+
supTypeWorkList.add(interfaceDeclRef);
447+
}
447448
}
448449
}
450+
++base;
449451
}
452+
bool canExit = true;
453+
for (auto baseItem : supTypeWorkList)
454+
{
455+
if (considerExtension(baseItem, &additionalSubtypeWitnesses))
456+
canExit = false;
457+
}
458+
if (canExit)
459+
break;
450460
}
451-
bool canExit = true;
452-
for (auto baseItem : supTypeWorkList)
453-
{
454-
if (considerExtension(baseItem, &additionalSubtypeWitnesses))
455-
canExit = false;
456-
}
457-
if (canExit)
458-
break;
459461
}
460462

461463
// At this point, the list of direct bases (each with its own linearization)

source/slang/slang-check-overload.cpp

+30-2
Original file line numberDiff line numberDiff line change
@@ -1187,6 +1187,16 @@ namespace Slang
11871187
return CountParameters(parentGeneric).required;
11881188
}
11891189

1190+
DeclRef<Decl> getParentDeclRef(DeclRef<Decl> declRef)
1191+
{
1192+
auto parent = declRef.getParent();
1193+
while (parent.as<GenericDecl>())
1194+
{
1195+
parent = parent.getParent();
1196+
}
1197+
return parent;
1198+
}
1199+
11901200
int SemanticsVisitor::CompareLookupResultItems(
11911201
LookupResultItem const& left,
11921202
LookupResultItem const& right)
@@ -1204,13 +1214,31 @@ namespace Slang
12041214
// directly (it is only visible through the requirement witness
12051215
// information for inheritance declarations).
12061216
//
1207-
auto leftDeclRefParent = left.declRef.getParent();
1208-
auto rightDeclRefParent = right.declRef.getParent();
1217+
auto leftDeclRefParent = getParentDeclRef(left.declRef);
1218+
auto rightDeclRefParent = getParentDeclRef(right.declRef);
12091219
bool leftIsInterfaceRequirement = isInterfaceRequirement(left.declRef.getDecl());
12101220
bool rightIsInterfaceRequirement = isInterfaceRequirement(right.declRef.getDecl());
12111221
if(leftIsInterfaceRequirement != rightIsInterfaceRequirement)
12121222
return int(leftIsInterfaceRequirement) - int(rightIsInterfaceRequirement);
12131223

1224+
// Prefer non-extension declarations over extension declarations.
1225+
bool leftIsExtension = as<ExtensionDecl>(leftDeclRefParent.getDecl()) != nullptr;
1226+
bool rightIsExtension = as<ExtensionDecl>(rightDeclRefParent.getDecl()) != nullptr;
1227+
if (leftIsExtension != rightIsExtension)
1228+
{
1229+
return int(leftIsExtension) - int(rightIsExtension);
1230+
}
1231+
else if (leftIsExtension)
1232+
{
1233+
// If both are declared in extensions, prefer the one that is least generic.
1234+
bool leftIsGeneric = leftDeclRefParent.getParent().as<GenericDecl>() != nullptr;
1235+
bool rightIsGeneric = rightDeclRefParent.getParent().as<GenericDecl>() != nullptr;
1236+
if (leftIsGeneric != rightIsGeneric)
1237+
{
1238+
return int(leftIsGeneric) - int(rightIsGeneric);
1239+
}
1240+
}
1241+
12141242
// Any decl is strictly better than a module decl.
12151243
bool leftIsModule = (as<ModuleDeclarationDecl>(left.declRef) != nullptr);
12161244
bool rightIsModule = (as<ModuleDeclarationDecl>(right.declRef) != nullptr);

source/slang/slang-diagnostic-defs.h

+1
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ DIAGNOSTIC(30832, Error, invalidTypeForInheritance, "type '$0' cannot be used fo
554554

555555
DIAGNOSTIC(30850, Error, invalidExtensionOnType, "type '$0' cannot be extended. `extension` can only be used to extend a nominal type.")
556556
DIAGNOSTIC(30851, Error, invalidMemberTypeInExtension, "$0 cannot be a part of an `extension`")
557+
DIAGNOSTIC(30852, Error, invalidExtensionOnInterface, "cannot extend interface type '$0'. consider using a generic extension: `extension<T:$0> T {...}`.")
557558

558559
// 309xx: subscripts
559560
DIAGNOSTIC(30900, Error, multiDimensionalArrayNotSupported, "multi-dimensional array is not supported.")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):-target cpp -stage compute -entry main -disable-specialization
2+
3+
interface IFoo{}
4+
5+
6+
// CHECK: ([[# @LINE+1]]): error 30852
7+
extension IFoo
8+
{
9+
int f() { return 0; }
10+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Test that the override behavior around extensions and generic extensions works as expected.
2+
3+
// When there are multiple ways for a type to conform to an interface, then the expected behavior
4+
// is that:
5+
// 1. If the type directly implements an interface, use that conformance.
6+
// 2. Otherwise, if there is a direct extension on the type that makes it conform to the interface, use that
7+
// extension.
8+
// 3. Otherwise, if there is a generic extension that makes the type conform to the interface, use that.
9+
10+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj
11+
interface IFoo
12+
{
13+
int getVal();
14+
}
15+
16+
interface IBar
17+
{
18+
int getValPlusOne();
19+
}
20+
21+
interface IBaz
22+
{
23+
int getValPlusTwo();
24+
}
25+
26+
struct MyInt
27+
{
28+
int v;
29+
}
30+
31+
extension MyInt : IFoo
32+
{
33+
int getVal() { return v; }
34+
}
35+
36+
extension MyInt : IBar
37+
{
38+
int getValPlusOne() { return this.getVal() + 2; }
39+
}
40+
41+
extension<T: IFoo> T : IBar
42+
{
43+
int getValPlusOne() { return this.getVal() + 1; }
44+
}
45+
46+
int helper1<T:IBar>(T v){ return v.getValPlusOne();}
47+
int helper2<T:IFoo>(T v){ return v.getValPlusOne();}
48+
49+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
50+
RWStructuredBuffer<int> outputBuffer;
51+
52+
[numthreads(1,1,1)]
53+
void computeMain()
54+
{
55+
MyInt v = {1};
56+
57+
// CHECK: 3
58+
outputBuffer[0] = v.getValPlusOne(); // should call MyInt::ext::getValPlusOne();
59+
60+
// CHECK: 3
61+
outputBuffer[1] = helper1(v); // should call MyInt::ext::getValPlusOne();
62+
63+
// CHECK: 2
64+
outputBuffer[2] = helper2(v); // should call T::ext::getValPlusOne();
65+
}

0 commit comments

Comments
 (0)