Skip to content

Commit 81c015e

Browse files
Fix diagnostics for [PreferRecompute] (shader-slang#5159)
* Fix diagnostics for [PreferRecompute] * Update dont-warn-on-simple-prefer-recompute.slang * Update slang-ir-autodiff.cpp * Update dont-warn-on-simple-prefer-recompute.slang * Update warn-on-prefer-recompute-side-effects.slang
1 parent a5d67ad commit 81c015e

4 files changed

+36
-4
lines changed

source/slang/slang-diagnostic-defs.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ DIAGNOSTIC(41904, Error, unableToAlignOf, "alignof could not be performed for ty
790790

791791
DIAGNOSTIC(42001, Error, invalidUseOfTorchTensorTypeInDeviceFunc, "invalid use of TorchTensor type in device/kernel functions. use `TensorView` instead.")
792792

793-
DIAGNOSTIC(42050, Warning, potentialIssuesWithPreferRecomputeOnSideEffectMethod, "$0 has [PreferRecompute] and may have side effects. side effects may execute multiple times. use [PreferRecompute(SideEffectBehavior.Allow)], or mark function with [NoSideEffect]")
793+
DIAGNOSTIC(42050, Warning, potentialIssuesWithPreferRecomputeOnSideEffectMethod, "$0 has [PreferRecompute] and may have side effects. side effects may execute multiple times. use [PreferRecompute(SideEffectBehavior.Allow)], or mark function with [__NoSideEffect]")
794794

795795
DIAGNOSTIC(45001, Error, unresolvedSymbol, "unresolved external symbol '$0'.")
796796

source/slang/slang-ir-autodiff.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -2416,9 +2416,15 @@ void checkAutodiffPatterns(
24162416
if (auto func = as<IRFunc>(inst))
24172417
{
24182418
if (func->sourceLoc.isValid() && // Don't diagnose for synthesized functions
2419-
func->findDecoration<IRPreferRecomputeDecoration>() &&
2420-
!func->findDecoration<IRNoSideEffectDecoration>())
2419+
func->findDecoration<IRPreferRecomputeDecoration>())
24212420
{
2421+
// If we don't have any side-effect behavior, we should warn (note: read-none is a stronger
2422+
// guarantee than no-side-effect)
2423+
//
2424+
if (func->findDecoration<IRNoSideEffectDecoration>() ||
2425+
func->findDecoration<IRReadNoneDecoration>())
2426+
continue;
2427+
24222428
auto preferRecomputeDecor = func->findDecoration<IRPreferRecomputeDecoration>();
24232429
auto sideEffectBehavior = as<IRIntLit>(preferRecomputeDecor->getOperand(0))->getValue();
24242430

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -line-directive-mode none -stage compute -entry computeMain
2+
3+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
4+
RWStructuredBuffer<float> outputBuffer;
5+
6+
[BackwardDifferentiable]
7+
[PreferRecompute]
8+
float comp(float a, float b)
9+
{
10+
// CHECK: standard error = {
11+
// CHECK-NEXT: }
12+
return a * b;
13+
}
14+
15+
16+
[shader("compute")]
17+
[numthreads(128, 1, 1)]
18+
void computeMain(uint3 group_thread_id: SV_GroupThreadID, uint3 dispatch_thread_id: SV_DispatchThreadID)
19+
{
20+
DifferentialPair<float> value1 = diffPair(3.f, 0.f);
21+
DifferentialPair<float> value2 = diffPair(3.f, 0.f);
22+
23+
bwd_diff(comp)(value1, value2, 1.f);
24+
25+
outputBuffer[dispatch_thread_id.x] = value1.d;
26+
}

tests/autodiff/warn-on-prefer-recompute-side-effects.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ float get_thread_5_value(float v, uint group_thread_id)
1212
if(group_thread_id == 5)
1313
{
1414
s_shared = detach(v);
15-
// CHECK: tests/autodiff/warn-on-prefer-recompute-side-effects.slang(10): warning 42050: get_thread_5_value has [PreferRecompute] and may have side effects. side effects may execute multiple times. use [PreferRecompute(SideEffectBehavior.Allow)], or mark function with [NoSideEffect]
15+
// CHECK: tests/autodiff/warn-on-prefer-recompute-side-effects.slang(10): warning 42050: get_thread_5_value has [PreferRecompute] and may have side effects. side effects may execute multiple times. use [PreferRecompute(SideEffectBehavior.Allow)], or mark function with [__NoSideEffect]
1616
// CHECK: float get_thread_5_value(float v, uint group_thread_id)
1717
// CHECK: ^~~~~~~~~~~~~~~~~~
1818
}

0 commit comments

Comments
 (0)