Skip to content

Commit 62079c5

Browse files
authored
Support associatedtype local variables and return values in dynamic dispatch code (shader-slang#1444)
* Refactor lower-generics pass into separate subpasses. * IR pass to generate witness table wrappers. * Support associatedtype local variables and return values in dynamic dispatch code.
1 parent 5758d16 commit 62079c5

7 files changed

+116
-5
lines changed

source/slang/slang-ir-generics-lowering-context.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ namespace Slang
1515
case kIROp_ThisType:
1616
case kIROp_AssociatedType:
1717
case kIROp_InterfaceType:
18+
case kIROp_lookup_interface_method:
1819
return true;
1920
case kIROp_Specialize:
2021
{

source/slang/slang-ir-lower-generic-call.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,21 @@ namespace Slang
127127
translateCallInst(callInst, funcType, loweredFunc, specializeInst);
128128
}
129129

130+
void lowerCallToInterfaceMethod(IRCall* callInst, IRLookupWitnessMethod* lookupInst)
131+
{
132+
// If we see a call(lookup_interface_method(...), ...), we need to translate
133+
// all occurences of associatedtypes.
134+
auto funcType = cast<IRFuncType>(lookupInst->getDataType());
135+
auto loweredFunc = lookupInst;
136+
translateCallInst(callInst, funcType, loweredFunc, nullptr);
137+
}
138+
130139
void lowerCall(IRCall* callInst)
131140
{
132141
if (auto specializeInst = as<IRSpecialize>(callInst->getCallee()))
133142
lowerCallToSpecializedFunc(callInst, specializeInst);
143+
else if (auto lookupInst = as<IRLookupWitnessMethod>(callInst->getCallee()))
144+
lowerCallToInterfaceMethod(callInst, lookupInst);
134145
}
135146

136147
void processInst(IRInst* inst)

source/slang/slang-ir-lower-generic-function.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ namespace Slang
122122
auto paramType = param->getDataType();
123123
if (auto ptrType = as<IRPtrTypeBase>(paramType))
124124
paramType = ptrType->getValueType();
125-
if (isPointerOfType(paramType->getDataType(), kIROp_RTTIType))
125+
if (isPointerOfType(paramType->getDataType(), kIROp_RTTIType) ||
126+
paramType->op == kIROp_lookup_interface_method)
126127
{
127128
// Lower into a function parameter of raw pointer type.
128129
param->setFullType(builder.getRawPointerType());
@@ -277,6 +278,7 @@ namespace Slang
277278
// Update the type of lookupInst to the lowered type of the corresponding interface requirement val.
278279

279280
// If the requirement is a function, interfaceRequirementVal will be the lowered function type.
281+
// If the requirement is an associatedtype, interfaceRequirementVal will be Ptr<RTTIObject>.
280282
IRInst* interfaceRequirementVal = nullptr;
281283
auto witnessTableType = cast<IRWitnessTableType>(lookupInst->getWitnessTable()->getDataType());
282284
auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType()));

source/slang/slang-ir-lower-generic-var.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@ namespace Slang
1919
void processVarInst(IRInst* varInst)
2020
{
2121
// We process only var declarations that have type
22-
// `Ptr<IRParam>`.
22+
// `Ptr<IRParam>` or `Ptr<IRLookupInterfaceMethod>`.
23+
//
2324
// Due to the processing of `lowerGenericFunction`,
2425
// A local variable of generic type now appears as
25-
// `var X:Ptr<irParam:Ptr<RTTIType>>`
26+
// `var X:Ptr<y:Ptr<RTTIType>>`,
27+
// where y can be an IRParam if it is a generic type,
28+
// or an `lookup_interface_method` if it is an associated type.
2629
// We match this pattern and turn this inst into
27-
// `X:RawPtr = alloca(rtti_extract_size(irParam))`
30+
// `X:RTTIPtr(irParam) = alloca(irParam)`
2831
auto varTypeInst = varInst->getDataType();
2932
if (!varTypeInst)
3033
return;
@@ -34,7 +37,7 @@ namespace Slang
3437

3538
// `varTypeParam` represents a pointer to the RTTI object.
3639
auto varTypeParam = ptrType->getValueType();
37-
if (varTypeParam->op != kIROp_Param)
40+
if (varTypeParam->op != kIROp_Param && varTypeParam->op != kIROp_lookup_interface_method)
3841
return;
3942
if (!varTypeParam->getDataType())
4043
return;

source/slang/slang-lower-to-ir.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -3437,6 +3437,11 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
34373437
UNREACHABLE_RETURN(LoweredValInfo());
34383438
}
34393439

3440+
LoweredValInfo visitAssocTypeDecl(AssocTypeDecl* decl)
3441+
{
3442+
return LoweredValInfo::simple(context->irBuilder->getAssociatedType());
3443+
}
3444+
34403445
LoweredValInfo visitAssignExpr(AssignExpr* expr)
34413446
{
34423447
// Because our representation of lowered "values"
+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code
2+
//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -allow-dynamic-code
3+
4+
// Test dynamic dispatch code gen for associated-typed return values
5+
// and local variables.
6+
// TODO: test arguments of associated type.
7+
8+
interface IAssoc
9+
{
10+
int Compute();
11+
}
12+
13+
interface IInterface
14+
{
15+
associatedtype TAssoc : IAssoc;
16+
17+
[mutating]
18+
void SetVal(int inVal);
19+
20+
TAssoc GetAssoc();
21+
};
22+
23+
T.TAssoc CreateT_Assoc_Inner<T:IInterface>(int inVal)
24+
{
25+
T obj;
26+
obj.SetVal(inVal);
27+
return obj.GetAssoc();
28+
}
29+
30+
T.TAssoc CreateT_Assoc<T:IInterface>(int inVal)
31+
{
32+
return CreateT_Assoc_Inner<T>(inVal);
33+
}
34+
35+
T CreateT<T:IInterface>(int inVal)
36+
{
37+
T obj;
38+
obj.SetVal(inVal);
39+
return obj;
40+
}
41+
42+
struct Impl : IInterface
43+
{
44+
struct TAssoc : IAssoc
45+
{
46+
int base;
47+
int Compute()
48+
{
49+
return base;
50+
}
51+
};
52+
53+
TAssoc assoc;
54+
[mutating]
55+
void SetVal(int inVal)
56+
{
57+
assoc.base = inVal;
58+
}
59+
60+
TAssoc GetAssoc()
61+
{
62+
return assoc;
63+
}
64+
};
65+
66+
int test()
67+
{
68+
var obj = CreateT<Impl>(2);
69+
var obj2 = CreateT_Assoc<Impl>(1);
70+
// TODO: compiler crash if type parameter is missing.
71+
// (hitting lowering logic of TypeEqualityWitness)
72+
var obj3 = CreateT_Assoc_Inner<Impl>(1);
73+
return obj.GetAssoc().Compute() + obj2.Compute() + obj3.Compute();
74+
}
75+
76+
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
77+
RWStructuredBuffer<int> outputBuffer : register(u0);
78+
79+
[numthreads(4, 1, 1)]
80+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
81+
{
82+
uint tid = dispatchThreadID.x;
83+
int outVal = test();
84+
outputBuffer[tid] = outVal;
85+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
4
2+
4
3+
4
4+
4

0 commit comments

Comments
 (0)