Skip to content

Commit 0a81d11

Browse files
authored
Fix crash: dynamic dispatch of generic interface method. (shader-slang#1929)
* Fix crash: dynamic dispatch of generic interface method. * Fix memory error. Co-authored-by: Yong He <yhe@nvidia.com>
1 parent cf7ddda commit 0a81d11

4 files changed

+93
-6
lines changed

source/slang/slang-ir-lower-generic-call.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,6 @@ namespace Slang
218218
auto newCall = builder->emitCallInst(calleeRetType, newCallee, args);
219219
auto callInstType = callInst->getDataType();
220220
auto unpackInst = maybeUnpackValue(builder, callInstType, calleeRetType, newCall);
221-
callInst->replaceUsesWith(unpackInst);
222-
callInst->removeAndDeallocate();
223-
224221
// Unpack other `out` arguments.
225222
for (auto& item : argsToUnpack)
226223
{
@@ -229,6 +226,8 @@ namespace Slang
229226
auto unpackedVal = builder->emitUnpackAnyValue(originalValType, packedVal);
230227
builder->emitStore(item.dstArg, unpackedVal);
231228
}
229+
callInst->replaceUsesWith(unpackInst);
230+
callInst->removeAndDeallocate();
232231
}
233232

234233
IRInst* findInnerMostSpecializingBase(IRSpecialize* inst)

source/slang/slang-lower-to-ir.cpp

+21-3
Original file line numberDiff line numberDiff line change
@@ -6702,9 +6702,18 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
67026702
{
67036703
HashSet<IRInst*> valuesToClone;
67046704
markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), returnType);
6705+
// For Function Types, we always clone all generic parameters regardless of whether
6706+
// the generic parameter appears in the function signature or not.
6707+
if (returnType->getOp() == kIROp_FuncType)
6708+
{
6709+
for (auto genericParam : parentGeneric->getParams())
6710+
{
6711+
markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), genericParam);
6712+
}
6713+
}
67056714
if (valuesToClone.Count() == 0)
67066715
{
6707-
// If returnType is independent of generic parameters, set
6716+
// If the new generic has no parameters, set
67086717
// the generic inst's type to just `returnType`.
67096718
parentGeneric->setFullType((IRType*)returnType);
67106719
}
@@ -6727,8 +6736,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
67276736
}
67286737
IRInst* clonedReturnType = nullptr;
67296738
cloneEnv.mapOldValToNew.TryGetValue(returnType, clonedReturnType);
6730-
SLANG_ASSERT(clonedReturnType);
6731-
typeBuilder.emitReturn(clonedReturnType);
6739+
if (clonedReturnType)
6740+
{
6741+
// If the type has explicit dependency on generic parameters, use
6742+
// the cloned type.
6743+
typeBuilder.emitReturn(clonedReturnType);
6744+
}
6745+
else
6746+
{
6747+
// Otherwise just use the original type value directly.
6748+
typeBuilder.emitReturn(returnType);
6749+
}
67326750
parentGeneric->setFullType((IRType*)typeGeneric);
67336751
returnType = typeGeneric;
67346752
}
+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Test using generic interface methods with dynamic dispatch.
2+
3+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -vk -output-using-type
4+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -profile sm_6_0 -use-dxil -output-using-type
5+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx11 -profile sm_5_0 -output-using-type
6+
7+
interface IReturnsZero
8+
{
9+
float get();
10+
}
11+
12+
[anyValueSize(16)]
13+
interface IInterface
14+
{
15+
float run<Z:IReturnsZero>();
16+
}
17+
18+
struct UserDefinedPackedType
19+
{
20+
float3 val;
21+
uint flags;
22+
};
23+
24+
//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=gOutputBuffer
25+
RWStructuredBuffer<float> gOutputBuffer;
26+
27+
//TEST_INPUT: set gObj = new StructuredBuffer<UserDefinedPackedType>[new UserDefinedPackedType{[1.0, 2.0, 3.0], 3}, new UserDefinedPackedType{[2.0, 3.0, 4.0], 4}];
28+
RWStructuredBuffer<UserDefinedPackedType> gObj;
29+
30+
//TEST_INPUT: type_conformance FloatVal:IInterface = 3
31+
//TEST_INPUT: type_conformance Float4Val:IInterface = 4
32+
33+
[numthreads(1, 1, 1)]
34+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
35+
{
36+
float result = 0.0;
37+
for (int i = 0; i < 2; i++)
38+
{
39+
var rawObj = gObj.Load(i);
40+
IInterface dynamicObj = createDynamicObject<IInterface, UserDefinedPackedType>(rawObj.flags, rawObj);
41+
result += dynamicObj.run<ReturnsZero>();
42+
}
43+
gOutputBuffer[0] = result;
44+
}
45+
46+
struct ReturnsZero : IReturnsZero
47+
{
48+
float get() { return 0.0; }
49+
}
50+
struct FloatVal : IInterface
51+
{
52+
float val;
53+
float run<Z:IReturnsZero>()
54+
{
55+
Z z;
56+
return val + z.get();
57+
}
58+
};
59+
struct Float4Struct { float4 val; }
60+
struct Float4Val : IInterface
61+
{
62+
Float4Struct val;
63+
float run<Z:IReturnsZero>()
64+
{
65+
Z z;
66+
return val.val.x + val.val.y + z.get();
67+
}
68+
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
type: float
2+
6.0

0 commit comments

Comments
 (0)