Skip to content

Commit 3d75525

Browse files
Fix bug with overload resolution under nested generics (shader-slang#3107)
* Add test for generic param inference bug for nested generics * Change description & simplify test * Add expected file * Check parent decl before unifying type parameters
1 parent 00bd481 commit 3d75525

File tree

3 files changed

+48
-4
lines changed

3 files changed

+48
-4
lines changed

source/slang/slang-check-constraint.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -705,14 +705,16 @@ namespace Slang
705705
auto fstDeclRef = fstDeclRefType->getDeclRef();
706706

707707
if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
708-
return TryUnifyTypeParam(constraints, typeParamDecl, snd);
708+
if (typeParamDecl->parentDecl == constraints.genericDecl)
709+
return TryUnifyTypeParam(constraints, typeParamDecl, snd);
709710

710711
if (auto sndDeclRefType = as<DeclRefType>(snd))
711712
{
712713
auto sndDeclRef = sndDeclRefType->getDeclRef();
713714

714715
if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl()))
715-
return TryUnifyTypeParam(constraints, typeParamDecl, fst);
716+
if (typeParamDecl->parentDecl == constraints.genericDecl)
717+
return TryUnifyTypeParam(constraints, typeParamDecl, fst);
716718

717719
// can't be unified if they refer to different declarations.
718720
if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) return false;
@@ -816,7 +818,7 @@ namespace Slang
816818

817819
if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
818820
{
819-
if(typeParamDecl->parentDecl == constraints.genericDecl )
821+
if(typeParamDecl->parentDecl == constraints.genericDecl)
820822
return TryUnifyTypeParam(constraints, typeParamDecl, snd);
821823
}
822824
}
@@ -827,7 +829,7 @@ namespace Slang
827829

828830
if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl()))
829831
{
830-
if(typeParamDecl->parentDecl == constraints.genericDecl )
832+
if(typeParamDecl->parentDecl == constraints.genericDecl)
831833
return TryUnifyTypeParam(constraints, typeParamDecl, fst);
832834
}
833835
}
+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//TEST(smoke,compute):COMPARE_COMPUTE: -shaderobj
2+
//TEST(smoke,compute):COMPARE_COMPUTE:-cpu -shaderobj
3+
4+
// Test overload resolution for nested generic definitions
5+
6+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
7+
RWStructuredBuffer<float> outputBuffer;
8+
9+
10+
__generic<T : __BuiltinFloatingPointType>
11+
struct Foo
12+
{
13+
T test(uint index, T x)
14+
{
15+
return __realCast<T, float>(1.f);
16+
}
17+
18+
__generic<let N: int>
19+
T test(vector<uint, N> index, T x)
20+
{
21+
return __realCast<T, float>(2.f);
22+
}
23+
};
24+
25+
26+
[numthreads(1, 1, 1)]
27+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
28+
{
29+
uint tid = dispatchThreadID.x + 2;
30+
31+
Foo<float> obj;
32+
33+
float outVal = obj.test(tid, 0.f);
34+
outputBuffer[0] = outVal; // Expect: 1
35+
36+
float outVal2 = obj.test(uint2(tid, tid), 0.f);
37+
outputBuffer[1] = outVal2; // Expect: 2
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
3F800000
2+
40000000
3+
0
4+
0

0 commit comments

Comments
 (0)