Skip to content

Commit 6fae15c

Browse files
authored
Add diagnostic for calling non-bwd-diff func from bwd-diff func. (shader-slang#2602)
1 parent 0586f32 commit 6fae15c

8 files changed

+129
-50
lines changed

source/slang/slang-check-decl.cpp

+19-23
Original file line numberDiff line numberDiff line change
@@ -6894,38 +6894,34 @@ namespace Slang
68946894

68956895
bool SharedSemanticsContext::isDifferentiableFunc(FunctionDeclBase* func)
68966896
{
6897-
// A function is differentiable if it is marked as differentiable, or it
6898-
// has an associated derivative function.
6899-
if (func->findModifier<DifferentiableAttribute>())
6900-
return true;
6901-
for (auto assocDecl : getAssociatedDeclsForDecl(func))
6902-
{
6903-
switch (assocDecl.kind)
6904-
{
6905-
case DeclAssociationKind::ForwardDerivativeFunc:
6906-
case DeclAssociationKind::BackwardDerivativeFunc:
6907-
return true;
6908-
default:
6909-
break;
6910-
}
6911-
}
6912-
return false;
6897+
return getFuncDifferentiableLevel(func) != FunctionDifferentiableLevel::None;
69136898
}
69146899

69156900
bool SharedSemanticsContext::isBackwardDifferentiableFunc(FunctionDeclBase* func)
69166901
{
6917-
// A function is differentiable if it is marked as differentiable, or it
6918-
// has an associated derivative function.
6902+
return getFuncDifferentiableLevel(func) == FunctionDifferentiableLevel::Backward;
6903+
}
6904+
6905+
FunctionDifferentiableLevel SharedSemanticsContext::getFuncDifferentiableLevel(FunctionDeclBase* func)
6906+
{
69196907
if (func->findModifier<BackwardDifferentiableAttribute>())
6920-
return true;
6908+
return FunctionDifferentiableLevel::Backward;
69216909
if (func->findModifier<BackwardDerivativeAttribute>())
6922-
return true;
6910+
return FunctionDifferentiableLevel::Backward;
6911+
6912+
FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None;
6913+
if (func->findModifier<DifferentiableAttribute>())
6914+
diffLevel = FunctionDifferentiableLevel::Forward;
6915+
69236916
for (auto assocDecl : getAssociatedDeclsForDecl(func))
69246917
{
69256918
switch (assocDecl.kind)
69266919
{
69276920
case DeclAssociationKind::BackwardDerivativeFunc:
6928-
return true;
6921+
return FunctionDifferentiableLevel::Backward;
6922+
case DeclAssociationKind::ForwardDerivativeFunc:
6923+
diffLevel = FunctionDifferentiableLevel::Forward;
6924+
break;
69296925
default:
69306926
break;
69316927
}
@@ -6937,12 +6933,12 @@ namespace Slang
69376933
case BuiltinRequirementKind::DAddFunc:
69386934
case BuiltinRequirementKind::DMulFunc:
69396935
case BuiltinRequirementKind::DZeroFunc:
6940-
return true;
6936+
return FunctionDifferentiableLevel::Backward;
69416937
default:
69426938
break;
69436939
}
69446940
}
6945-
return false;
6941+
return diffLevel;
69466942
}
69476943

69486944
List<ExtensionDecl*> const& getCandidateExtensions(

source/slang/slang-check-expr.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -1967,6 +1967,10 @@ namespace Slang
19671967

19681968
if (m_parentDifferentiableAttr)
19691969
{
1970+
FunctionDifferentiableLevel callerDiffLevel = FunctionDifferentiableLevel::None;
1971+
if (m_parentFunc)
1972+
callerDiffLevel = getShared()->getFuncDifferentiableLevel(m_parentFunc);
1973+
19701974
if (auto checkedInvokeExpr = as<InvokeExpr>(checkedExpr))
19711975
{
19721976
// Register types for final resolved invoke arguments again.
@@ -1978,7 +1982,8 @@ namespace Slang
19781982
{
19791983
if (auto calleeDecl = as<FunctionDeclBase>(calleeExpr->declRef.getDecl()))
19801984
{
1981-
if (getShared()->isDifferentiableFunc(calleeDecl))
1985+
auto calleeDiffLevel = getShared()->getFuncDifferentiableLevel(calleeDecl);
1986+
if (calleeDiffLevel >= callerDiffLevel)
19821987
{
19831988
if (!m_treatAsDifferentiableExpr)
19841989
{

source/slang/slang-check-impl.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212
namespace Slang
1313
{
14-
14+
enum class FunctionDifferentiableLevel
15+
{
16+
None,
17+
Forward,
18+
Backward
19+
};
1520
/// Should the given `decl` be treated as a static rather than instance declaration?
1621
bool isEffectivelyStatic(
1722
Decl* decl);
@@ -292,6 +297,7 @@ namespace Slang
292297

293298
bool isDifferentiableFunc(FunctionDeclBase* func);
294299
bool isBackwardDifferentiableFunc(FunctionDeclBase* func);
300+
FunctionDifferentiableLevel getFuncDifferentiableLevel(FunctionDeclBase* func);
295301

296302
private:
297303
/// Mapping from type declarations to the known extensiosn that apply to them

source/slang/slang-diagnostic-defs.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'vo
576576
DIAGNOSTIC(41011, Error, typeDoesNotFitAnyValueSize, "type '$0' does not fit in the size required by its conforming interface.")
577577
DIAGNOSTIC(41012, Note, typeAndLimit, "sizeof($0) is $1, limit is $2")
578578
DIAGNOSTIC(41012, Error, typeCannotBePackedIntoAnyValue, "type '$0' contains fields that cannot be packed into an AnyValue.")
579-
DIAGNOSTIC(41020, Error, lossOfDerivativeDueToCallOfNonDifferentiableFunction, "derivative cannot be propagated through call to non-differentiable function `$0`, use 'no_diff' to clarify intention.")
579+
DIAGNOSTIC(41020, Error, lossOfDerivativeDueToCallOfNonDifferentiableFunction, "derivative cannot be propagated through call to non-$1-differentiable function `$0`, use 'no_diff' to clarify intention.")
580580
DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable function must have at least one differentiable output.")
581581
DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.")
582582
DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can only be called when argument is statically resolvable to a string literal")

source/slang/slang-ir-check-differentiability.cpp

+59-23
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@ struct CheckDifferentiabilityPassContext : public InstPassBase
1212
DiagnosticSink* sink;
1313
AutoDiffSharedContext sharedContext;
1414

15-
HashSet<IRInst*> differentiableFunctions;
15+
enum DifferentiableLevel
16+
{
17+
Forward, Backward
18+
};
19+
Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions;
1620

1721
CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink)
1822
: InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst())
@@ -59,7 +63,7 @@ struct CheckDifferentiabilityPassContext : public InstPassBase
5963
}
6064

6165

62-
bool _isDifferentiableFuncImpl(IRInst* func)
66+
bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level)
6367
{
6468
func = getLeafFunc(func);
6569
if (!func)
@@ -71,32 +75,41 @@ struct CheckDifferentiabilityPassContext : public InstPassBase
7175
{
7276
case kIROp_ForwardDerivativeDecoration:
7377
case kIROp_ForwardDifferentiableDecoration:
78+
if (level == DifferentiableLevel::Forward)
79+
return true;
80+
break;
7481
case kIROp_UserDefinedBackwardDerivativeDecoration:
7582
case kIROp_BackwardDerivativeDecoration:
7683
case kIROp_BackwardDifferentiableDecoration:
7784
return true;
85+
default:
86+
break;
7887
}
7988
}
8089
return false;
8190
}
8291

83-
bool isDifferentiableFunc(IRInst* func)
92+
bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level)
8493
{
85-
switch (func->getOp())
94+
if (level == DifferentiableLevel::Forward)
8695
{
87-
case kIROp_ForwardDifferentiate:
88-
case kIROp_BackwardDifferentiate:
89-
return true;
90-
default:
91-
break;
96+
switch (func->getOp())
97+
{
98+
case kIROp_ForwardDifferentiate:
99+
case kIROp_BackwardDifferentiate:
100+
return true;
101+
default:
102+
break;
103+
}
92104
}
93105

94-
func = getSpecializedVal(func);
106+
func = getLeafFunc(func);
95107
if (!func)
96108
return false;
97109

98-
if (differentiableFunctions.Contains(func))
99-
return true;
110+
111+
if (auto existingLevel = differentiableFunctions.TryGetValue(func))
112+
return *existingLevel >= level;
100113

101114
if (func->findDecoration<IRTreatAsDifferentiableDecoration>())
102115
return true;
@@ -125,7 +138,10 @@ struct CheckDifferentiabilityPassContext : public InstPassBase
125138
{
126139
if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey())
127140
{
128-
return true;
141+
if (as<IRBackwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Backward)
142+
return true;
143+
if (as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Forward)
144+
return true;
129145
}
130146
}
131147
}
@@ -135,7 +151,11 @@ struct CheckDifferentiabilityPassContext : public InstPassBase
135151
{
136152
if (as<IRGeneric>(func))
137153
{
138-
return differentiableFunctions.Contains(func);
154+
if (auto existingLevel = differentiableFunctions.TryGetValue(func))
155+
{
156+
if (*existingLevel >= level)
157+
return true;
158+
}
139159
}
140160
}
141161
return false;
@@ -235,14 +255,18 @@ struct CheckDifferentiabilityPassContext : public InstPassBase
235255
if (differentiableInputs == 0)
236256
sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveInput);
237257

258+
DifferentiableLevel requiredDiffLevel = DifferentiableLevel::Forward;
259+
if (isBackwardDifferentiableFunc(funcInst))
260+
requiredDiffLevel = DifferentiableLevel::Backward;
261+
238262
auto isInstProducingDiff = [&](IRInst* inst) -> bool
239263
{
240264
switch (inst->getOp())
241265
{
242266
case kIROp_FloatLit:
243267
return true;
244268
case kIROp_Call:
245-
return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee());
269+
return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel);
246270
case kIROp_Load:
247271
// We don't have more knowledge on whether diff is available at the destination address.
248272
// Just assume it is producing diff.
@@ -310,7 +334,7 @@ struct CheckDifferentiabilityPassContext : public InstPassBase
310334
switch (inst->getOp())
311335
{
312336
case kIROp_Call:
313-
if (isDifferentiableFunc(as<IRCall>(inst)->getCallee()))
337+
if (isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel))
314338
{
315339
addToExpectDiffWorkList(inst);
316340
}
@@ -349,7 +373,11 @@ struct CheckDifferentiabilityPassContext : public InstPassBase
349373
{
350374
if (auto call = as<IRCall>(inst))
351375
{
352-
sink->diagnose(inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, getLeafFunc(call->getCallee()));
376+
sink->diagnose(
377+
inst,
378+
Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
379+
getLeafFunc(call->getCallee()),
380+
requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward");
353381
}
354382
}
355383
switch (inst->getOp())
@@ -395,22 +423,30 @@ struct CheckDifferentiabilityPassContext : public InstPassBase
395423
void processModule()
396424
{
397425
// Collect set of differentiable functions.
398-
HashSet<UnownedStringSlice> differentiableSymbolNames;
426+
HashSet<UnownedStringSlice> fwdDifferentiableSymbolNames, bwdDifferentiableSymbolNames;
399427
for (auto inst : module->getGlobalInsts())
400428
{
401-
if (_isDifferentiableFuncImpl(inst))
429+
if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Backward))
430+
{
431+
if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
432+
bwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName());
433+
differentiableFunctions.Add(inst, DifferentiableLevel::Backward);
434+
}
435+
else if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Forward))
402436
{
403437
if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
404-
differentiableSymbolNames.Add(linkageDecor->getMangledName());
405-
differentiableFunctions.Add(inst);
438+
fwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName());
439+
differentiableFunctions.Add(inst, DifferentiableLevel::Forward);
406440
}
407441
}
408442
for (auto inst : module->getGlobalInsts())
409443
{
410444
if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
411445
{
412-
if (differentiableSymbolNames.Contains(linkageDecor->getMangledName()))
413-
differentiableFunctions.Add(inst);
446+
if (bwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName()))
447+
differentiableFunctions[inst] = DifferentiableLevel::Backward;
448+
else if (fwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName()))
449+
differentiableFunctions.AddIfNotExists(inst, DifferentiableLevel::Forward);
414450
}
415451
}
416452

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//DIAGNOSTIC_TEST:SIMPLE:
2+
3+
float nonDiff(float x)
4+
{
5+
return x;
6+
}
7+
8+
[ForwardDifferentiable]
9+
float f(float x)
10+
{
11+
return x * x;
12+
}
13+
14+
[BackwardDifferentiable]
15+
float g(float x)
16+
{
17+
float val = f(x + 1); // Error: f must also be backward-differentiable
18+
return val;
19+
}
20+
21+
[BackwardDifferentiable]
22+
float h(float x)
23+
{
24+
float val = 0;
25+
// no diagnostic by clarifying intention.
26+
val = no_diff(f(x + 1));
27+
return val;
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
result code = -1
2+
standard error = {
3+
tests/diagnostics/autodiff-data-flow-2.slang(18): error 41020: derivative cannot be propagated through call to non-backward-differentiable function `f`, use 'no_diff' to clarify intention.
4+
float val = f(x + 1); // Error: f must also be backward-differentiable
5+
^
6+
}
7+
standard output = {
8+
}

tests/diagnostics/autodiff-data-flow.slang.expected

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
result code = -1
22
standard error = {
3-
tests/diagnostics/autodiff-data-flow.slang(15): error 41020: derivative cannot be propagated through call to non-differentiable function `nonDiff`, use 'no_diff' to clarify intention.
3+
tests/diagnostics/autodiff-data-flow.slang(15): error 41020: derivative cannot be propagated through call to non-forward-differentiable function `nonDiff`, use 'no_diff' to clarify intention.
44
val = nonDiff(x * 2.0f);
55
^
66
tests/diagnostics/autodiff-data-flow.slang(22): error 41021: a differentiable function must have at least one differentiable output.

0 commit comments

Comments
 (0)