Skip to content

Commit 9913cfb

Browse files
[AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred (shader-slang#5630)
* [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred * Fix failing tests * Update custom-derivative-generic.slang
1 parent 95125f2 commit 9913cfb

5 files changed

+133
-2
lines changed

source/slang/slang-check-decl.cpp

+75-1
Original file line numberDiff line numberDiff line change
@@ -10915,7 +10915,61 @@ void checkDerivativeAttributeImpl(
1091510915
SemanticsContext::ExprLocalScope scope;
1091610916
auto ctx = visitor->withExprLocalScope(&scope);
1091710917
auto subVisitor = SemanticsVisitor(ctx);
10918-
auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx);
10918+
10919+
auto exprToCheck = attr->funcExpr;
10920+
10921+
// If this is a generic, we want to wrap the call to the derivative method
10922+
// with the generic parameters of the source.
10923+
//
10924+
if (as<GenericDecl>(funcDecl->parentDecl) && !as<GenericAppExpr>(attr->funcExpr))
10925+
{
10926+
auto genericDecl = as<GenericDecl>(funcDecl->parentDecl);
10927+
auto substArgs = getDefaultSubstitutionArgs(ctx.getASTBuilder(), visitor, genericDecl);
10928+
auto appExpr = ctx.getASTBuilder()->create<GenericAppExpr>();
10929+
10930+
Index count = 0;
10931+
for (auto member : genericDecl->members)
10932+
{
10933+
if (as<GenericTypeParamDecl>(member) || as<GenericValueParamDecl>(member) ||
10934+
as<GenericTypePackParamDecl>(member))
10935+
count++;
10936+
}
10937+
10938+
appExpr->functionExpr = attr->funcExpr;
10939+
10940+
for (auto arg : substArgs)
10941+
{
10942+
if (count == 0)
10943+
break;
10944+
10945+
if (auto declRefType = as<DeclRefType>(arg))
10946+
{
10947+
auto baseTypeExpr = ctx.getASTBuilder()->create<SharedTypeExpr>();
10948+
baseTypeExpr->base.type = declRefType;
10949+
auto baseTypeType = ctx.getASTBuilder()->getOrCreate<TypeType>(declRefType);
10950+
baseTypeExpr->type.type = baseTypeType;
10951+
10952+
appExpr->arguments.add(baseTypeExpr);
10953+
}
10954+
else if (auto genericValParam = as<GenericParamIntVal>(arg))
10955+
{
10956+
auto declRef = genericValParam->getDeclRef();
10957+
appExpr->arguments.add(
10958+
subVisitor
10959+
.ConstructDeclRefExpr(declRef, nullptr, nullptr, SourceLoc(), nullptr));
10960+
}
10961+
else
10962+
{
10963+
SLANG_UNEXPECTED("Unhandled substitution arg type");
10964+
}
10965+
10966+
count--;
10967+
}
10968+
10969+
exprToCheck = appExpr;
10970+
}
10971+
10972+
auto checkedFuncExpr = visitor->dispatchExpr(exprToCheck, ctx);
1091910973
attr->funcExpr = checkedFuncExpr;
1092010974
if (attr->args.getCount())
1092110975
attr->args[0] = attr->funcExpr;
@@ -11427,6 +11481,26 @@ void checkDerivativeOfAttributeImpl(
1142711481
calleeDeclRef = calleeDeclRefExpr->declRef;
1142811482

1142911483
auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());
11484+
11485+
if (!calleeFunc)
11486+
{
11487+
// If we couldn't find a direct function, it might be a generic.
11488+
if (auto genericDecl = as<GenericDecl>(calleeDeclRef.getDecl()))
11489+
{
11490+
calleeFunc = as<FunctionDeclBase>(genericDecl->inner);
11491+
11492+
if (as<ErrorType>(resolved->type.type))
11493+
{
11494+
// If we can't resolve a type, something went wrong. If we're working with a generic
11495+
// decl, the most likely cause is a failure of generic argument inference.
11496+
//
11497+
visitor->getSink()->diagnose(
11498+
derivativeOfAttr,
11499+
Diagnostics::cannotResolveGenericArgumentForDerivativeFunction);
11500+
}
11501+
}
11502+
}
11503+
1143011504
if (!calleeFunc)
1143111505
{
1143211506
visitor->getSink()->diagnose(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type
2+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type
3+
4+
enum MyEnum { A, B, C };
5+
6+
[BackwardDerivative(mDiff)]
7+
float m<let M : MyEnum>(float x)
8+
{
9+
switch (M)
10+
{
11+
case MyEnum.A:
12+
return x * x;
13+
case MyEnum.B:
14+
return x;
15+
case MyEnum.C:
16+
return 3 * x;
17+
default:
18+
return 0;
19+
}
20+
}
21+
22+
void mDiff<let M : MyEnum>(inout DifferentialPair<float> x, float dResult)
23+
{
24+
switch (M)
25+
{
26+
case MyEnum.A:
27+
updateDiff(x, 2 * dResult * x.p);
28+
break;
29+
case MyEnum.B:
30+
updateDiff(x, dResult);
31+
break;
32+
case MyEnum.C:
33+
updateDiff(x, 3 * dResult);
34+
break;
35+
default:
36+
updateDiff(x, 0);
37+
break;
38+
}
39+
}
40+
41+
[Differentiable]
42+
float test(float x)
43+
{
44+
return m<MyEnum.A>(x);
45+
}
46+
47+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
48+
RWStructuredBuffer<float> outputBuffer;
49+
50+
[numthreads(1, 1, 1)]
51+
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
52+
{
53+
var a = diffPair(3.0);
54+
__bwd_diff(test)(a, 1.0);
55+
outputBuffer[dispatchThreadID.x] = a.d;
56+
// CHECK: 6.0
57+
}
File renamed without changes.

tests/diagnostics/custom-derivative-generic.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ DifferentialPair<float> dd1(DifferentialPair<float> x)
3434
}
3535

3636
// CHECK-DAG: {{.*}}(37): error 31151
37-
[BackwardDerivative(f)]
37+
[BackwardDerivativeOf(f)]
3838
DifferentialPair<float> df<let N:int>(inout DifferentialPair<float> x, float dOut)
3939
{
4040
var primal = x.p * x.p;

0 commit comments

Comments
 (0)