Skip to content

Commit a023792

Browse files
Fix issue with clamp's derivatives at the boundary. (#6403)
1 parent 0959d7e commit a023792

File tree

3 files changed

+57
-7
lines changed

3 files changed

+57
-7
lines changed

source/slang/diff.meta.slang

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
/// `[ForwardDerivative(fwdFn)]` attribute can be used to provide a forward-mode
32
/// derivative implementation.
43
/// Invoking `fwd_diff(decoratedFn)` will place a call to `fwdFn` instead of synthesizing
@@ -80,7 +79,6 @@ attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
8079
/// For member functions, or functions nested inside namespaces, `bwdFn` may need to be a fully qualified
8180
/// name.
8281
///
83-
///
8482
__attributeTarget(FunctionDeclBase)
8583
attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;
8684

@@ -2150,17 +2148,19 @@ DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin
21502148
{
21512149
return DifferentialPair<T>(
21522150
clamp(dpx.p, dpMin.p, dpMax.p),
2153-
dpx.p < dpMin.p ? dpMin.d : (dpx.p > dpMax.p ? dpMax.d : dpx.d));
2151+
(dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dpx.d : (dpx.p < dpMin.p ? dpMin.d : dpMax.d));
21542152
}
21552153
__generic<T : __BuiltinFloatingPointType>
21562154
[BackwardDifferentiable]
21572155
[PreferRecompute]
21582156
[BackwardDerivativeOf(clamp)]
21592157
void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, inout DifferentialPair<T> dpMax, T.Differential dOut)
21602158
{
2161-
dpx = diffPair(dpx.p, dpx.p > dpMin.p && dpx.p < dpMax.p ? dOut : T.dzero());
2162-
dpMin = diffPair(dpMin.p, dpx.p <= dpMin.p ? dOut : T.dzero());
2163-
dpMax = diffPair(dpMax.p, dpx.p >= dpMax.p ? dOut : T.dzero());
2159+
// Propagate the derivative to x if x is within [min, max] (including the boundaries).
2160+
dpx = diffPair(dpx.p, (dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dOut : T.dzero());
2161+
// If x is strictly below min or above max, the gradient is instead applied to the clamp bounds
2162+
dpMin = diffPair(dpMin.p, dpx.p < dpMin.p ? dOut : T.dzero());
2163+
dpMax = diffPair(dpMax.p, dpx.p > dpMax.p ? dOut : T.dzero());
21642164
}
21652165
VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp)
21662166

tests/autodiff-dstdlib/dstdlib-clamp.slang

+43-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
22
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
33

4-
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
55
RWStructuredBuffer<float> outputBuffer;
66

77
typedef DifferentialPair<float> dpfloat;
@@ -178,4 +178,46 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
178178
outputBuffer[28] = dpmax.d.y; // Expected: 0.0
179179
outputBuffer[29] = dpmax.d.z; // Expected: 0.3
180180
}
181+
182+
// New tests: Forward-mode tests for derivative propagation at the edges with clamp(x, 0, 1)
183+
{
184+
// Lower edge: x exactly = 0
185+
dpfloat dpx = dpfloat(0.0, 0.4);
186+
dpfloat dpmin = dpfloat(0.0, 0.8);
187+
dpfloat dpmax = dpfloat(1.0, 0.5);
188+
dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax);
189+
outputBuffer[30] = res.d; // Expected: 0.4 (propagated from x)
190+
}
191+
192+
{
193+
// Upper edge: x exactly = 1
194+
dpfloat dpx = dpfloat(1.0, 0.7);
195+
dpfloat dpmin = dpfloat(0.0, 0.8);
196+
dpfloat dpmax = dpfloat(1.0, 0.9);
197+
dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax);
198+
outputBuffer[31] = res.d; // Expected: 0.7 (propagated from x)
199+
}
200+
201+
// Reverse-mode tests for derivative propagation at the edges with clamp(x, 0, 1)
202+
{
203+
// Lower edge: x exactly = 0
204+
dpfloat dpx = dpfloat(0.0, 0.0);
205+
dpfloat dpmin = dpfloat(0.0, 0.0);
206+
dpfloat dpmax = dpfloat(1.0, 0.0);
207+
bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0);
208+
outputBuffer[32] = dpx.d; // Expected: 1.0 (propagated from x)
209+
outputBuffer[33] = dpmin.d; // Expected: 0.0
210+
outputBuffer[34] = dpmax.d; // Expected: 0.0
211+
}
212+
213+
{
214+
// Upper edge: x exactly = 1
215+
dpfloat dpx = dpfloat(1.0, 0.0);
216+
dpfloat dpmin = dpfloat(0.0, 0.0);
217+
dpfloat dpmax = dpfloat(1.0, 0.0);
218+
bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0);
219+
outputBuffer[35] = dpx.d; // Expected: 1.0 (propagated from x)
220+
outputBuffer[36] = dpmin.d; // Expected: 0.0
221+
outputBuffer[37] = dpmax.d; // Expected: 0.0
222+
}
181223
}

tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt

+8
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,11 @@ type: float
2929
0.000000
3030
0.000000
3131
0.300000
32+
0.400000
33+
0.700000
34+
1.000000
35+
0.000000
36+
0.000000
37+
1.000000
38+
0.000000
39+
0.000000

0 commit comments

Comments
 (0)