Skip to content

Commit f9f6a28

Browse files
authored
Support dependent generic constraints. (#4870)
* Support dependent generic constraints. * Fix warning. * Update comment. * Fix. * Add a test case to verify fix of #3804. * Address review.
1 parent 03e1e17 commit f9f6a28

10 files changed

+391
-118
lines changed

source/slang/slang-ast-decl.h

+6
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,9 @@ class GenericDecl : public ContainerDecl
550550
class GenericTypeParamDeclBase : public SimpleTypeDecl
551551
{
552552
SLANG_AST_CLASS(GenericTypeParamDeclBase)
553+
554+
// The index of the generic parameter.
555+
Index parameterIndex = -1;
553556
};
554557

555558
class GenericTypeParamDecl : public GenericTypeParamDeclBase
@@ -587,6 +590,9 @@ class GenericTypeConstraintDecl : public TypeConstraintDecl
587590
class GenericValueParamDecl : public VarDeclBase
588591
{
589592
SLANG_AST_CLASS(GenericValueParamDecl)
593+
594+
// The index of the generic parameter.
595+
Index parameterIndex = 0;
590596
};
591597

592598
// An empty declaration (which might still have modifiers attached).

source/slang/slang-check-constraint.cpp

+185-109
Large diffs are not rendered by default.

source/slang/slang-check-decl.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -2520,18 +2520,22 @@ namespace Slang
25202520
// and likely a crash.
25212521
//
25222522
// Accessing the members via index side steps the issue.
2523+
2524+
Index parameterIndex = 0;
25232525
const auto& members = genericDecl->members;
25242526
for (Index i = 0; i < members.getCount(); ++i)
25252527
{
25262528
Decl* m = members[i];
25272529

2528-
if (auto typeParam = as<GenericTypeParamDecl>(m))
2530+
if (auto typeParam = as<GenericTypeParamDeclBase>(m))
25292531
{
25302532
ensureDecl(typeParam, DeclCheckState::ReadyForReference);
2533+
typeParam->parameterIndex = parameterIndex++;
25312534
}
25322535
else if (auto valParam = as<GenericValueParamDecl>(m))
25332536
{
25342537
ensureDecl(valParam, DeclCheckState::ReadyForReference);
2538+
valParam->parameterIndex = parameterIndex++;
25352539
}
25362540
else if (auto constraint = as<GenericTypeConstraintDecl>(m))
25372541
{

source/slang/slang-check-impl.h

+3
Original file line numberDiff line numberDiff line change
@@ -2102,6 +2102,7 @@ namespace Slang
21022102
};
21032103

21042104
Type* TryJoinVectorAndScalarType(
2105+
ConstraintSystem* constraints,
21052106
VectorExpressionType* vectorType,
21062107
BasicExpressionType* scalarType);
21072108

@@ -2196,11 +2197,13 @@ namespace Slang
21962197
ConversionCost getConversionCost(Type* toType, QualType fromType);
21972198

21982199
Type* _tryJoinTypeWithInterface(
2200+
ConstraintSystem* constraints,
21992201
Type* type,
22002202
Type* interfaceType);
22012203

22022204
// Try to compute the "join" between two types
22032205
Type* TryJoinTypes(
2206+
ConstraintSystem* constraints,
22042207
QualType left,
22052208
QualType right);
22062209

source/slang/slang-check-overload.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,8 @@ namespace Slang
883883

884884
for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
885885
{
886-
DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(genericDeclRef, substArgs, constraintDecl).as<GenericTypeConstraintDecl>();
886+
DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(
887+
genericDeclRef, newArgs.getArrayView(), constraintDecl).as<GenericTypeConstraintDecl>();
887888

888889
auto sub = getSub(m_astBuilder, constraintDeclRef);
889890
auto sup = getSup(m_astBuilder, constraintDeclRef);

source/slang/slang-ir-clone.cpp

+26-5
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ static void _cloneInstDecorationsAndChildren(
152152
// require the second phase.
153153
//
154154
List<IRCloningOldNewPair> pairs;
155+
ShortList<IRCloningOldNewPair> paramPairs;
155156

156157
for( auto oldChild : oldInst->getDecorationsAndChildren() )
157158
{
@@ -172,7 +173,19 @@ static void _cloneInstDecorationsAndChildren(
172173
// on the child, and register it in our map from
173174
// old to new values.
174175
//
175-
auto newChild = cloneInstAndOperands(env, builder, oldChild);
176+
IRInst* newChild = nullptr;
177+
if (oldChild->getOp() == kIROp_Param)
178+
{
179+
// For parameters, don't clone its type just yet, since
180+
// the type might be a forward reference to things defined
181+
// later in the block that we haven't cloned and registered yet.
182+
newChild = builder->emitParam(nullptr);
183+
paramPairs.add({ oldChild, newChild });
184+
}
185+
else
186+
{
187+
newChild = cloneInstAndOperands(env, builder, oldChild);
188+
}
176189
env->mapOldValToNew.add(oldChild, newChild);
177190

178191
// If and only if the old child had decorations
@@ -181,10 +194,7 @@ static void _cloneInstDecorationsAndChildren(
181194
//
182195
if( oldChild->getFirstDecorationOrChild() )
183196
{
184-
IRCloningOldNewPair pair;
185-
pair.oldInst = oldChild;
186-
pair.newInst = newChild;
187-
pairs.add(pair);
197+
pairs.add({ oldChild, newChild });
188198
}
189199
}
190200

@@ -200,6 +210,17 @@ static void _cloneInstDecorationsAndChildren(
200210

201211
_cloneInstDecorationsAndChildren(env, module, oldChild, newChild);
202212
}
213+
214+
// For params, we can now clone their types since we have done cloning the entire block.
215+
for (auto pair : paramPairs)
216+
{
217+
auto oldParam = pair.oldInst;
218+
auto newParam = pair.newInst;
219+
220+
auto oldType = oldParam->getFullType();
221+
auto newType = (IRType*)findCloneForOperand(env, oldType);
222+
newParam->setFullType(newType);
223+
}
203224
}
204225

205226
// The public version of `cloneInstDecorationsAndChildren` is then

source/slang/slang-lower-to-ir.cpp

+33-2
Original file line numberDiff line numberDiff line change
@@ -9304,15 +9304,46 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
93049304
auto typeGeneric = typeBuilder.emitGeneric();
93059305
typeGeneric->setFullType(typeBuilder.getGenericKind());
93069306
typeBuilder.setInsertInto(typeGeneric);
9307-
typeBuilder.emitBlock();
9307+
auto block = typeBuilder.emitBlock();
93089308

9309+
struct ParamCloneInfo
9310+
{
9311+
IRParam* originalParam;
9312+
IRParam* clonedParam;
9313+
};
9314+
ShortList<ParamCloneInfo> paramCloneInfos;
9315+
93099316
for (auto child : parentGeneric->getFirstBlock()->getChildren())
93109317
{
93119318
if (valuesToClone.contains(child))
93129319
{
9313-
cloneInst(&cloneEnv, &typeBuilder, child);
9320+
if (child->getOp() == kIROp_Param)
9321+
{
9322+
// Params may have forward references in its type and
9323+
// decorations, so we just create a placeholder for it
9324+
// in this first pass.
9325+
IRParam* clonedParam = typeBuilder.emitParam(nullptr);
9326+
cloneEnv.mapOldValToNew[child] = clonedParam;
9327+
paramCloneInfos.add({ (IRParam*)child, clonedParam });
9328+
}
9329+
else
9330+
{
9331+
cloneInst(&cloneEnv, &typeBuilder, child);
9332+
}
93149333
}
93159334
}
9335+
9336+
// In a second pass, clone the types and decorations on params which may
9337+
// contain forward references.
9338+
for (auto param : paramCloneInfos)
9339+
{
9340+
typeBuilder.setInsertInto(param.clonedParam);
9341+
param.clonedParam->setFullType((IRType*)cloneInst(&cloneEnv, &typeBuilder, param.originalParam->getFullType()));
9342+
cloneInstDecorationsAndChildren(&cloneEnv, typeBuilder.getModule(), param.originalParam, param.clonedParam);
9343+
}
9344+
9345+
typeBuilder.setInsertInto(block);
9346+
93169347
IRInst* clonedReturnType = nullptr;
93179348
cloneEnv.mapOldValToNew.tryGetValue(returnType, clonedReturnType);
93189349
if (clonedReturnType)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
2+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
3+
4+
// Test that we can infer a generic type parameter from the base type of a dependent generic argument.
5+
6+
interface IFoo : IDefaultInitializable
7+
{
8+
int get();
9+
}
10+
11+
struct Foo : IFoo
12+
{
13+
int get()
14+
{
15+
return 1;
16+
}
17+
}
18+
19+
interface IBar<T : IFoo>
20+
{
21+
int getVal();
22+
}
23+
24+
struct Bar<T : IFoo> : IBar<T>
25+
{
26+
int getVal()
27+
{
28+
T t = T();
29+
return t.get();
30+
}
31+
}
32+
33+
int test<T:IFoo, B : IBar<T>>(B b)
34+
{
35+
return b.getVal();
36+
}
37+
38+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
39+
RWStructuredBuffer<int> outputBuffer;
40+
41+
[numthreads(1, 1, 1)]
42+
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
43+
{
44+
Bar<Foo> obj2;
45+
let result = test(obj2);
46+
47+
// CHECK: 1
48+
outputBuffer[0] = result;
49+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
2+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
3+
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
5+
RWStructuredBuffer<int> outputBuffer;
6+
7+
interface IOperation<T : __BuiltinFloatingPointType>
8+
{
9+
static T apply(T lhs, T rhs);
10+
};
11+
12+
T applyOp<T:__BuiltinFloatingPointType, TOp:IOperation<T>>(T lhs, T rhs)
13+
{
14+
return TOp::apply(lhs, rhs);
15+
}
16+
17+
struct AddOp<T : __BuiltinFloatingPointType> : IOperation<T>
18+
{
19+
static T apply(T lhs, T rhs)
20+
{
21+
return lhs + rhs;
22+
}
23+
}
24+
25+
[numthreads(1, 1, 1)]
26+
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
27+
{
28+
let result = applyOp<float, AddOp<float>>(1.0, 2.0);
29+
30+
// CHECK: 3
31+
outputBuffer[0] = (int)result;
32+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
2+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
3+
4+
// Test that we can define a generic where one of the type parameter conforms to an generic interface parameterized on another
5+
// type parameter.
6+
7+
interface IFoo
8+
{
9+
int get();
10+
}
11+
12+
struct Foo : IFoo
13+
{
14+
int get()
15+
{
16+
return 1;
17+
}
18+
}
19+
20+
interface IBar<T : IFoo>
21+
{
22+
int getVal(T t);
23+
}
24+
25+
struct Bar<T : IFoo> : IBar<T>
26+
{
27+
int getVal(T t)
28+
{
29+
return t.get();
30+
}
31+
}
32+
33+
int test<T:IFoo, B : IBar<T>>(B b, T t)
34+
{
35+
return b.getVal(t);
36+
}
37+
38+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
39+
RWStructuredBuffer<int> outputBuffer;
40+
41+
[numthreads(1, 1, 1)]
42+
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
43+
{
44+
Foo obj1;
45+
Bar<Foo> obj2;
46+
let result = test(obj2, obj1);
47+
48+
// CHECK: 1
49+
outputBuffer[0] = result;
50+
}

0 commit comments

Comments
 (0)