Skip to content

Commit 1863fe1

Browse files
authored
Support generic constraints that are dependent on another generic param. (shader-slang#4091)
1 parent 7ef980f commit 1863fe1

File tree

3 files changed

+98
-6
lines changed

3 files changed

+98
-6
lines changed

source/slang/slang-ir-link.cpp

+27-2
Original file line numberDiff line numberDiff line change
@@ -746,16 +746,41 @@ void cloneGlobalValueWithCodeCommon(
746746
{
747747
IRBlock* ob = originalValue->getFirstBlock();
748748
IRBlock* cb = clonedValue->getFirstBlock();
749+
struct ParamCloneInfo
750+
{
751+
IRParam* originalParam;
752+
IRParam* clonedParam;
753+
};
754+
ShortList<ParamCloneInfo> paramCloneInfos;
749755
while (ob)
750756
{
751757
SLANG_ASSERT(cb);
752758

753759
builder->setInsertInto(cb);
754760
for (auto oi = ob->getFirstInst(); oi; oi = oi->getNextInst())
755761
{
756-
cloneInst(context, builder, oi);
762+
if (oi->getOp() == kIROp_Param)
763+
{
764+
// Params may have forward references in their types and
765+
// decorations, so we just create a placeholder for it
766+
// in this first pass.
767+
IRParam* clonedParam = builder->emitParam(nullptr);
768+
registerClonedValue(context, clonedParam, oi);
769+
paramCloneInfos.add({ (IRParam*)oi, clonedParam });
770+
}
771+
else
772+
{
773+
cloneInst(context, builder, oi);
774+
}
775+
}
776+
// Clone the type and decorations of parameters after all instructions in the block
777+
// have been cloned.
778+
for (auto param : paramCloneInfos)
779+
{
780+
builder->setInsertInto(param.clonedParam);
781+
param.clonedParam->setFullType((IRType*)cloneValue(context, param.originalParam->getFullType()));
782+
cloneDecorations(context, param.clonedParam, param.originalParam);
757783
}
758-
759784
ob = ob->getNextBlock();
760785
cb = cb->getNextBlock();
761786
}

source/slang/slang-lower-to-ir.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -8823,7 +8823,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
88238823
IRGenContext* subContext,
88248824
GenericTypeConstraintDecl* constraintDecl)
88258825
{
8826-
auto supType = lowerType(context, constraintDecl->sup.type);
8826+
auto supType = lowerType(subContext, constraintDecl->sup.type);
88278827
auto value = emitGenericConstraintValue(subContext, constraintDecl, supType);
88288828
subContext->setValue(constraintDecl, LoweredValInfo::simple(value));
88298829
}
@@ -8972,9 +8972,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
89728972
auto operand = value->getOperand(i);
89738973
markInstsToClone(valuesToClone, parentBlock, operand);
89748974
}
8975+
if (value->getFullType())
8976+
markInstsToClone(valuesToClone, parentBlock, value->getFullType());
8977+
for (auto child : value->getDecorationsAndChildren())
8978+
markInstsToClone(valuesToClone, parentBlock, child);
89758979
}
8976-
for (auto child : value->getChildren())
8977-
markInstsToClone(valuesToClone, parentBlock, child);
89788980
auto parent = parentBlock->getParent();
89798981
while (parent && parent != parentBlock)
89808982
{
@@ -9025,7 +9027,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
90259027
markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), returnType);
90269028
// For Function Types, we always clone all generic parameters regardless of whether
90279029
// the generic parameter appears in the function signature or not.
9028-
if (returnType->getOp() == kIROp_FuncType)
9030+
if (returnType->getOp() == kIROp_FuncType ||
9031+
returnType->getOp() == kIROp_Generic)
90299032
{
90309033
for (auto genericParam : parentGeneric->getParams())
90319034
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
2+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
3+
4+
// Test that we can compile a generic function with a generic type constraint that is dependent on an
5+
// outer generic type parameter.
6+
7+
namespace ns{
8+
9+
public interface IBinaryElementWiseFunction<T>
10+
{
11+
public static T call(const in T lhs, const in T rhs);
12+
}
13+
public struct AddOp<T : IArithmetic> : IBinaryElementWiseFunction<T>
14+
{
15+
public static T call(const in T lhs, const in T rhs)
16+
{
17+
return lhs + rhs;
18+
}
19+
}
20+
public struct BinaryElementWiseInputData<T : IArithmetic>
21+
{
22+
T lhs;
23+
T rhs;
24+
25+
// Note: `U` is constrained by `IBinaryElementWiseFunction<T>`, which is dependent on `T`,
26+
// that is another generic type parameter defined on the outer type.
27+
// This eventually leads to an IRGeneric where one param has a type that is dependent on
28+
// another param.
29+
// In this case, the IR for `test` after generic flattening will be:
30+
// ```
31+
// %g_test = IRGeneric
32+
// {
33+
// IRBlock
34+
// {
35+
// %T = IRParam : Type;
36+
// %T_w = IRParam : IRWitnessTableType<IArithmetic>;
37+
// %U = IRParam : Type;
38+
// %U_w = IRParam : IRWitnessTableType<%s>; // note that the type here is a forward reference to %s
39+
// %s = specialize(%IBinaryElementWiseFunction, %T) // %s is dependent on %T.
40+
// ...
41+
// }
42+
// }
43+
// ```
44+
public T test<U : IBinaryElementWiseFunction<T>>(U x)
45+
{
46+
return x.call(lhs ,rhs);
47+
}
48+
}
49+
}
50+
51+
52+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
53+
RWStructuredBuffer<int> outputBuffer;
54+
55+
[shader("compute")]
56+
[numthreads(1,1,1)]
57+
void computeMain(uint3 threadId: SV_DispatchThreadID)
58+
{
59+
ns::BinaryElementWiseInputData<int> cb;
60+
cb.lhs = threadId.x + 1;
61+
cb.rhs = 2;
62+
// CHECK: 3
63+
outputBuffer[0] = cb.test(ns::AddOp<int>());
64+
}

0 commit comments

Comments
 (0)