Skip to content

Commit a8bc598

Browse files
csyongheTim Foley
and
Tim Foley
authored
Allow calling a generic function with an existential value (dynamic dispatch) (shader-slang#1508)
* Allow calling a generic function with an existential value (dynamic dispatch). * Fixes per review comments. * Clean up implementation by having `openExistential` return `ExtractExistentialType` instead of a DeclRef to the interface with a `ThisTypeSubstitution`. * More cleanups Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com> Co-authored-by: Yong He <yhe@nvidia.com>
1 parent 11748a7 commit a8bc598

8 files changed

+91
-4
lines changed

source/slang/slang-ast-type.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ bool ExtractExistentialType::_equalsImplOverride(Type* type)
676676

677677
HashCode ExtractExistentialType::_getHashCodeOverride()
678678
{
679-
return declRef.getHashCode();
679+
return combineHash(declRef.getHashCode(), interfaceDeclRef.getHashCode());
680680
}
681681

682682
Type* ExtractExistentialType::_createCanonicalTypeOverride()
@@ -688,13 +688,15 @@ Val* ExtractExistentialType::_substituteImplOverride(ASTBuilder* astBuilder, Sub
688688
{
689689
int diff = 0;
690690
auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff);
691+
auto interfaceSubstDeclRef = interfaceDeclRef.substituteImpl(astBuilder, subst, &diff);
691692
if (!diff)
692693
return this;
693694

694695
(*ioDiff)++;
695696

696697
ExtractExistentialType* substValue = astBuilder->create<ExtractExistentialType>();
697-
substValue->declRef = declRef;
698+
substValue->declRef = substDeclRef;
699+
substValue->interfaceDeclRef = interfaceSubstDeclRef;
698700
return substValue;
699701
}
700702

source/slang/slang-ast-type.h

+1
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,7 @@ class ExtractExistentialType : public Type
611611
SLANG_CLASS(ExtractExistentialType)
612612

613613
DeclRef<VarDeclBase> declRef;
614+
DeclRef<InterfaceDecl> interfaceDeclRef;
614615

615616
// Overrides should be public so base classes can access
616617
String _toStringOverride();

source/slang/slang-check-conformance.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,30 @@ namespace Slang
268268
}
269269
}
270270
}
271+
else if (auto extractExistentialType = as<ExtractExistentialType>(subType))
272+
{
273+
// An ExtractExistentialType from an existential value of type I
274+
// is a subtype of I.
275+
// We need to check and make sure the interface type of the `ExtractExistentialType`
276+
// is equal to `superType`.
277+
auto interfaceDeclRef = extractExistentialType->interfaceDeclRef;
278+
auto thisTypeSubst = findThisTypeSubstitution(interfaceDeclRef.substitutions.substitutions, interfaceDeclRef.getDecl());
279+
SLANG_ASSERT(thisTypeSubst && thisTypeSubst == interfaceDeclRef.substitutions.substitutions);
280+
// The interfaceDeclRef in `extractExistentialType` contains a `ThisTypeSubstitution`
281+
// to allow member lookup to return correct substituted types. Here we just need
282+
// to know if that interface is the same as the superType, so we need to exclude
283+
// the `ThisTypeSubstitution` from comparison.
284+
interfaceDeclRef.substitutions.substitutions = thisTypeSubst->outer;
285+
if (interfaceDeclRef.equals(superTypeDeclRef))
286+
{
287+
if (outWitness)
288+
{
289+
*outWitness = thisTypeSubst->witness;
290+
}
291+
return true;
292+
}
293+
return false;
294+
}
271295
else if(auto taggedUnionType = as<TaggedUnionType>(subType))
272296
{
273297
// A tagged union type conforms to an interface if all of

source/slang/slang-check-expr.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ namespace Slang
116116
openedThisType->witness = openedWitness;
117117

118118
DeclRef<InterfaceDecl> substDeclRef = DeclRef<InterfaceDecl>(interfaceDecl, openedThisType);
119-
auto substDeclRefType = DeclRefType::create(m_astBuilder, substDeclRef);
119+
openedType->interfaceDeclRef = substDeclRef;
120120

121121
ExtractExistentialValueExpr* openedValue = m_astBuilder->create<ExtractExistentialValueExpr>();
122122
openedValue->declRef = varDeclRef;
123-
openedValue->type = QualType(substDeclRefType);
123+
openedValue->type = QualType(openedType);
124124

125125
return openedValue;
126126
});

source/slang/slang-check-overload.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1517,6 +1517,11 @@ namespace Slang
15171517
return CreateErrorExpr(expr);
15181518
}
15191519

1520+
for (auto& arg : expr->arguments)
1521+
{
1522+
arg = maybeOpenExistential(arg);
1523+
}
1524+
15201525
context.originalExpr = expr;
15211526
context.funcLoc = funcExpr->loc;
15221527

source/slang/slang-lookup.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,10 @@ static void _lookUpMembersInSuperTypeImpl(
587587

588588
_lookUpMembersInSuperTypeDeclImpl(astBuilder, name, leafType, superType, leafIsSuperWitness, declRef, request, ioResult, inBreadcrumbs);
589589
}
590+
if (auto extractExistentialType = as<ExtractExistentialType>(superType))
591+
{
592+
_lookUpMembersInSuperTypeDeclImpl(astBuilder, name, leafType, superType, leafIsSuperWitness, extractExistentialType->interfaceDeclRef, request, ioResult, inBreadcrumbs);
593+
}
590594
}
591595

592596
/// Perform lookup for `name` in the context of `type`.
+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code
2+
//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -allow-dynamic-code
3+
4+
// Test dynamic dispatch code gen for specializing a generic with
5+
// an existential value.
6+
7+
[anyValueSize(16)]
8+
interface IInterface
9+
{
10+
int Compute(int inVal);
11+
};
12+
13+
int GenericCompute0(IInterface obj, int inVal)
14+
{
15+
return GenericCompute1(obj, obj, inVal);
16+
}
17+
18+
int GenericCompute1<T:IInterface>(T obj, IInterface obj2, int inVal)
19+
{
20+
return obj.Compute(inVal) + obj2.Compute(inVal);
21+
}
22+
23+
24+
struct Impl : IInterface
25+
{
26+
int base;
27+
int Compute(int inVal) { return base + inVal * inVal; }
28+
};
29+
30+
int test(int inVal)
31+
{
32+
Impl obj;
33+
obj.base = 1;
34+
return GenericCompute0(obj, inVal);
35+
}
36+
37+
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
38+
RWStructuredBuffer<int> outputBuffer : register(u0);
39+
40+
[numthreads(4, 1, 1)]
41+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
42+
{
43+
uint tid = dispatchThreadID.x;
44+
int inVal = outputBuffer[tid];
45+
int outVal = test(inVal);
46+
outputBuffer[tid] = outVal;
47+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2
2+
4
3+
A
4+
14

0 commit comments

Comments
 (0)