Skip to content

Commit e9bf8de

Browse files
author
Tim Foley
authored
Enable simple extensions of interface types (shader-slang#1521)
The big picture here is that an `extension` can now apply to an interface type and provide convenience methods for all types that implement that interface. Suppose you have an interface for counters: interface ICounter { [mutating] void add(int val); } and a type that implements it: struct SimpleCounter : ICounter { int _state = 0; ... } If a common operation in your codebase is to increment a counter by adding one, you would be faced with the problem of either: * Add the `increment()` operation to `ICounter`, and force every implementation to implement the new requirement * Add the `increment()` operation to concrete counter types as needed, and thus not be able to use it in generic code * Make `increment()` a global ("free") function, and force clients of counters to have to know which operations use member syntax (`c.add(...)`) and which use global function call syntax (`increment(c)`). The whole idea of `extension`s is to allow for another option that is better than all of the above: extension ICounter { [mutating] void increment() { this.add(1); } } The core of the implementation is relatively straightforward, and consists of two complementary pieces. The first piece is that when emitting a concrete IR entity (function/type/whatever) we treat any enclosing `interface` type (or `extension` thereof) a bit like an enclosing `GenericDecl`, and introduce an `IRGeneric` to wrap things. The generic `IRGeneric` has parameters representing the `This` type for the interface, along with the witness table that shows how `This` conforms to the interface itself. We thus end up with an IR version of `increment()` something like: void increment<This : ICounter>(This this) { this.add(1); } The second (complementary) fix is that when there is code that references this `increment()` operation, we don't treat it like an interface requirement (look up based on its key), and instead treat it like a generic (since that is how it is lowered now) and speciaize it to the information we can glean from the `ThisTypeSubstitution`. A related fix that is required here is that within the body of `increment`, when we perform `this.add`, we need to ensure that the lookup of `add` in the base interface properly takes into account the subtype relationship (`This : ICounter`) and encodes it into the lookup result, so that we get `((ICounter) this).add`, and properly generate code that looks up the `add` method in the witness table for `This`.
1 parent bc0d0f9 commit e9bf8de

6 files changed

+229
-43
lines changed

source/slang/slang-ast-val.h

+1
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ class TaggedUnionSubtypeWitness : public SubtypeWitness
200200
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
201201
};
202202

203+
/// A witness of the fact that `ThisType(someInterface) : someInterface`
203204
class ThisTypeSubtypeWitness : public SubtypeWitness
204205
{
205206
SLANG_CLASS(ThisTypeSubtypeWitness)

source/slang/slang-legalize-types.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1175,7 +1175,7 @@ LegalType legalizeTypeImpl(
11751175
else if( auto existentialPtrType = as<IRExistentialBoxType>(type))
11761176
{
11771177
// We want to transform an `ExistentialBox<T>` into just
1178-
// a `T`, with an `iplicitDeref` to make sure that any
1178+
// a `T`, with an `implicitDeref` to make sure that any
11791179
// pointer-related operations on the box Just Work.
11801180
//
11811181
// Note: the logic here doesn't have to deal with moving

source/slang/slang-lookup.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ static void _lookUpMembersInSuperTypeImpl(
634634
interfaceType,
635635
superIsInterfaceWitness);
636636

637-
_lookUpMembersInSuperTypeDeclImpl(astBuilder, name, leafType, interfaceType, leafIsInterfaceWitness, thisType->interfaceDeclRef, request, ioResult, inBreadcrumbs);
637+
_lookUpMembersInSuperType(astBuilder, name, leafType, interfaceType, leafIsInterfaceWitness, request, ioResult, inBreadcrumbs);
638638
}
639639
}
640640

source/slang/slang-lower-to-ir.cpp

+172-41
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,9 @@ struct IRGenContext
404404
// The IRType value to lower into for `ThisType`.
405405
IRInst* thisType = nullptr;
406406

407+
// The IR witness value to use for `ThisType`
408+
IRInst* thisTypeWitness = nullptr;
409+
407410
explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder)
408411
: shared(inShared)
409412
, astBuilder(inAstBuilder)
@@ -1416,6 +1419,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
14161419
return LoweredValInfo::simple(irWitnessTable);
14171420
}
14181421

1422+
LoweredValInfo visitThisTypeSubtypeWitness(ThisTypeSubtypeWitness* val)
1423+
{
1424+
SLANG_UNUSED(val);
1425+
return LoweredValInfo::simple(context->thisTypeWitness);
1426+
}
1427+
14191428
LoweredValInfo visitConstantIntVal(ConstantIntVal* val)
14201429
{
14211430
// TODO: it is a bit messy here that the `ConstantIntVal` representation
@@ -2233,6 +2242,31 @@ DeclRef<D> createDefaultSpecializedDeclRef(IRGenContext* context, D* decl)
22332242
return declRef.as<D>();
22342243
}
22352244

2245+
static Type* _findReplacementThisParamType(
2246+
IRGenContext* context,
2247+
DeclRef<Decl> parentDeclRef)
2248+
{
2249+
if( auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>() )
2250+
{
2251+
auto targetType = getTargetType(context->astBuilder, extensionDeclRef);
2252+
if(auto targetDeclRefType = as<DeclRefType>(targetType))
2253+
{
2254+
if(auto replacementType = _findReplacementThisParamType(context, targetDeclRefType->declRef))
2255+
return replacementType;
2256+
}
2257+
return targetType;
2258+
}
2259+
2260+
if (auto interfaceDeclRef = parentDeclRef.as<InterfaceDecl>())
2261+
{
2262+
auto thisType = context->astBuilder->create<ThisType>();
2263+
thisType->interfaceDeclRef = interfaceDeclRef;
2264+
return thisType;
2265+
}
2266+
2267+
return nullptr;
2268+
}
2269+
22362270
/// Get the type of the `this` parameter introduced by `parentDeclRef`, or null.
22372271
///
22382272
/// E.g., if `parentDeclRef` is a `struct` declaration, then this will
@@ -2247,20 +2281,13 @@ Type* getThisParamTypeForContainer(
22472281
IRGenContext* context,
22482282
DeclRef<Decl> parentDeclRef)
22492283
{
2250-
if (auto interfaceDeclRef = parentDeclRef.as<InterfaceDecl>())
2251-
{
2252-
auto thisType = context->astBuilder->create<ThisType>();
2253-
thisType->interfaceDeclRef = interfaceDeclRef;
2254-
return thisType;
2255-
}
2256-
else if( auto aggTypeDeclRef = parentDeclRef.as<AggTypeDecl>() )
2284+
if(auto replacementType = _findReplacementThisParamType(context, parentDeclRef))
2285+
return replacementType;
2286+
2287+
if( auto aggTypeDeclRef = parentDeclRef.as<AggTypeDecl>() )
22572288
{
22582289
return DeclRefType::create(context->astBuilder, aggTypeDeclRef);
22592290
}
2260-
else if( auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>() )
2261-
{
2262-
return getTargetType(context->astBuilder, extensionDeclRef);
2263-
}
22642291

22652292
return nullptr;
22662293
}
@@ -5692,6 +5719,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
56925719
subContextStorage.env = &subEnvStorage;
56935720

56945721
subContextStorage.thisType = outerContext->thisType;
5722+
subContextStorage.thisTypeWitness = outerContext->thisTypeWitness;
56955723
}
56965724

56975725
IRBuilder* getBuilder() { return &subBuilderStorage; }
@@ -5962,6 +5990,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
59625990
auto thisType = getBuilder()->getThisType(irInterface);
59635991
subContext->thisType = thisType;
59645992

5993+
// TODO: Need to add an appropriate stand-in witness here.
5994+
subContext->thisTypeWitness = nullptr;
5995+
59655996
// Lower associated types first, so they can be referred to when lowering functions.
59665997
for (auto assocTypeDecl : decl->getMembersOfType<AssocTypeDecl>())
59675998
{
@@ -6303,6 +6334,45 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
63036334
return irGeneric;
63046335
}
63056336

6337+
IRGeneric* emitOuterInterfaceGeneric(
6338+
IRGenContext* subContext,
6339+
ContainerDecl* parentDecl,
6340+
DeclRefType* interfaceType,
6341+
Decl* leafDecl)
6342+
{
6343+
auto subBuilder = subContext->irBuilder;
6344+
6345+
// Of course, a generic might itself be nested inside of other generics...
6346+
emitOuterGenerics(subContext, parentDecl, leafDecl);
6347+
6348+
// We need to create an IR generic
6349+
6350+
auto irGeneric = subBuilder->emitGeneric();
6351+
subBuilder->setInsertInto(irGeneric);
6352+
6353+
auto irBlock = subBuilder->emitBlock();
6354+
subBuilder->setInsertInto(irBlock);
6355+
6356+
// The generic needs two parameters: one to represent the
6357+
// `ThisType`, and one to represent a witness that the
6358+
// `ThisType` conforms to the interface itself.
6359+
//
6360+
auto irThisTypeParam = subBuilder->emitParam(subBuilder->getTypeType());
6361+
6362+
auto irInterfaceType = lowerType(context, interfaceType);
6363+
auto irWitnessTableParam = subBuilder->emitParam(subBuilder->getWitnessTableType(irInterfaceType));
6364+
subBuilder->addTypeConstraintDecoration(irThisTypeParam, irInterfaceType);
6365+
6366+
// Now we need to wire up the IR parameters
6367+
// we created to be used as the `ThisType` in
6368+
// the body of the code.
6369+
//
6370+
subContext->thisType = irThisTypeParam;
6371+
subContext->thisTypeWitness = irWitnessTableParam;
6372+
6373+
return irGeneric;
6374+
}
6375+
63066376
// If the given `decl` is enclosed in any generic declarations, then
63076377
// emit IR-level generics to represent them.
63086378
// The `leafDecl` represents the inner-most declaration we are actually
@@ -6316,6 +6386,23 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
63166386
{
63176387
return emitOuterGeneric(subContext, genericAncestor, leafDecl);
63186388
}
6389+
6390+
// We introduce IR generics in one other case, where the input
6391+
// code wasn't visibly using generics: when a concrete member
6392+
// is defined on an interface type. In that case, the resulting
6393+
// definition needs to be generic on a parameter to represent
6394+
// the `ThisType` of the interface.
6395+
//
6396+
if(auto extensionAncestor = as<ExtensionDecl>(pp))
6397+
{
6398+
if(auto targetDeclRefType = as<DeclRefType>(extensionAncestor->targetType))
6399+
{
6400+
if(auto interfaceDeclRef = targetDeclRefType->declRef.as<InterfaceDecl>())
6401+
{
6402+
return emitOuterInterfaceGeneric(subContext, extensionAncestor, targetDeclRefType, leafDecl);
6403+
}
6404+
}
6405+
}
63196406
}
63206407

63216408
return nullptr;
@@ -7112,6 +7199,20 @@ bool canDeclLowerToAGeneric(Decl* decl)
71127199
return false;
71137200
}
71147201

7202+
static bool isInterfaceRequirement(Decl* decl)
7203+
{
7204+
auto ancestor = decl->parentDecl;
7205+
for(; ancestor; ancestor = ancestor->parentDecl )
7206+
{
7207+
if(as<InterfaceDecl>(ancestor))
7208+
return true;
7209+
7210+
if(as<ExtensionDecl>(ancestor))
7211+
return false;
7212+
}
7213+
return false;
7214+
}
7215+
71157216
LoweredValInfo emitDeclRef(
71167217
IRGenContext* context,
71177218
Decl* decl,
@@ -7204,36 +7305,66 @@ LoweredValInfo emitDeclRef(
72047305
return lowerType(context, thisTypeSubst->witness->sub);
72057306
}
72067307

7207-
// Somebody is trying to look up an interface requirement
7208-
// "through" some concrete type. We need to lower this decl-ref
7209-
// as a lookup of the corresponding member in a witness table.
7210-
//
7211-
// The witness table itself is referenced by the this-type
7212-
// substitution, so we can just lower that.
7213-
//
7214-
// Note: unlike the case for generics above, in the interface-lookup
7215-
// case, we don't end up caring about any further outer substitutions.
7216-
// That is because even if we are naming `ISomething<Foo>.doIt()`,
7217-
// a method inside a generic interface, we don't actually care
7218-
// about the substitution of `Foo` for the parameter `T` of
7219-
// `ISomething<T>`. That is because we really care about the
7220-
// witness table for the concrete type that conforms to `ISomething<Foo>`.
7221-
//
7222-
auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->witness);
7223-
//
7224-
// The key to use for looking up the interface member is
7225-
// derived from the declaration.
7226-
//
7227-
auto irRequirementKey = getInterfaceRequirementKey(context, decl);
7228-
//
7229-
// Those two pieces of information tell us what we need to
7230-
// do in order to look up the value that satisfied the requirement.
7231-
//
7232-
auto irSatisfyingVal = context->irBuilder->emitLookupInterfaceMethodInst(
7233-
type,
7234-
irWitnessTable,
7235-
irRequirementKey);
7236-
return LoweredValInfo::simple(irSatisfyingVal);
7308+
if(isInterfaceRequirement(decl))
7309+
{
7310+
// Somebody is trying to look up an interface requirement
7311+
// "through" some concrete type. We need to lower this decl-ref
7312+
// as a lookup of the corresponding member in a witness table.
7313+
//
7314+
// The witness table itself is referenced by the this-type
7315+
// substitution, so we can just lower that.
7316+
//
7317+
// Note: unlike the case for generics above, in the interface-lookup
7318+
// case, we don't end up caring about any further outer substitutions.
7319+
// That is because even if we are naming `ISomething<Foo>.doIt()`,
7320+
// a method inside a generic interface, we don't actually care
7321+
// about the substitution of `Foo` for the parameter `T` of
7322+
// `ISomething<T>`. That is because we really care about the
7323+
// witness table for the concrete type that conforms to `ISomething<Foo>`.
7324+
//
7325+
auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->witness);
7326+
//
7327+
// The key to use for looking up the interface member is
7328+
// derived from the declaration.
7329+
//
7330+
auto irRequirementKey = getInterfaceRequirementKey(context, decl);
7331+
//
7332+
// Those two pieces of information tell us what we need to
7333+
// do in order to look up the value that satisfied the requirement.
7334+
//
7335+
auto irSatisfyingVal = context->irBuilder->emitLookupInterfaceMethodInst(
7336+
type,
7337+
irWitnessTable,
7338+
irRequirementKey);
7339+
return LoweredValInfo::simple(irSatisfyingVal);
7340+
}
7341+
else
7342+
{
7343+
// This case is a reference to a member declaration of the interface
7344+
// (or added by an extension of the interface) that does *not*
7345+
// represent a requirement of the interface.
7346+
//
7347+
// Our policy is that concrete methods/members on an interface type
7348+
// are lowered as generics, where the generic parameter represents
7349+
// the `ThisType`.
7350+
//
7351+
auto genericVal = emitDeclRef(context, decl, thisTypeSubst->outer, context->irBuilder->getGenericKind());
7352+
auto irGenericVal = getSimpleVal(context, genericVal);
7353+
7354+
// In order to reference the member for a particular type, we
7355+
// specialize the generic for that type.
7356+
//
7357+
IRInst* irSubType = lowerType(context, thisTypeSubst->witness->sub);
7358+
IRInst* irSubTypeWitness = lowerSimpleVal(context, thisTypeSubst->witness);
7359+
7360+
IRInst* irSpecializeArgs[] = { irSubType, irSubTypeWitness };
7361+
auto irSpecializedVal = context->irBuilder->emitSpecializeInst(
7362+
type,
7363+
irGenericVal,
7364+
2,
7365+
irSpecializeArgs);
7366+
return LoweredValInfo::simple(irSpecializedVal);
7367+
}
72377368
}
72387369
else
72397370
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// interface-extension.slang
2+
3+
// Test that an `extension` applied to an interface type works as users expect
4+
5+
//TEST(compute):COMPARE_COMPUTE:
6+
7+
interface ICounter
8+
{
9+
[mutating] void add(int value);
10+
}
11+
12+
struct MyCounter : ICounter
13+
{
14+
int _state = 0;
15+
16+
[mutating] void add(int value) { _state += value; }
17+
}
18+
19+
extension ICounter
20+
{
21+
[mutating] void increment()
22+
{
23+
this.add(1);
24+
}
25+
}
26+
27+
void helper<T : ICounter>(in out T counter)
28+
{
29+
counter.increment();
30+
}
31+
32+
int test(int value)
33+
{
34+
MyCounter counter = { value };
35+
counter.increment();
36+
helper(counter);
37+
return counter._state;
38+
}
39+
40+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
41+
RWStructuredBuffer<int> outputBuffer;
42+
43+
[numthreads(4, 1, 1)]
44+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
45+
{
46+
uint tid = dispatchThreadID.x;
47+
int inVal = tid;
48+
int outVal = test(inVal);
49+
outputBuffer[tid] = outVal;
50+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2
2+
3
3+
4
4+
5

0 commit comments

Comments
 (0)