Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with clamp's derivatives at the boundary. #6403

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions source/slang/diff.meta.slang
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

/// `[ForwardDerivative(fwdFn)]` attribute can be used to provide a forward-mode
/// derivative implementation.
/// Invoking `fwd_diff(decoratedFn)` will place a call to `fwdFn` instead of synthesizing
Expand Down Expand Up @@ -80,7 +79,6 @@ attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
/// For member functions, or functions nested inside namespaces, `bwdFn` may need to be a fully qualified
/// name.
///
///
__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;

Expand Down Expand Up @@ -2150,17 +2148,19 @@ DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin
{
return DifferentialPair<T>(
clamp(dpx.p, dpMin.p, dpMax.p),
dpx.p < dpMin.p ? dpMin.d : (dpx.p > dpMax.p ? dpMax.d : dpx.d));
(dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dpx.d : (dpx.p < dpMin.p ? dpMin.d : dpMax.d));
}
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
[PreferRecompute]
[BackwardDerivativeOf(clamp)]
void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, inout DifferentialPair<T> dpMax, T.Differential dOut)
{
dpx = diffPair(dpx.p, dpx.p > dpMin.p && dpx.p < dpMax.p ? dOut : T.dzero());
dpMin = diffPair(dpMin.p, dpx.p <= dpMin.p ? dOut : T.dzero());
dpMax = diffPair(dpMax.p, dpx.p >= dpMax.p ? dOut : T.dzero());
// Propagate the derivative to x if x is within [min, max] (including the boundaries).
dpx = diffPair(dpx.p, (dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dOut : T.dzero());
// If x is strictly below min or above max, the gradient is instead applied to the clamp bounds
dpMin = diffPair(dpMin.p, dpx.p < dpMin.p ? dOut : T.dzero());
dpMax = diffPair(dpMax.p, dpx.p > dpMax.p ? dOut : T.dzero());
}
VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp)

Expand Down
44 changes: 43 additions & 1 deletion tests/autodiff-dstdlib/dstdlib-clamp.slang
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

typedef DifferentialPair<float> dpfloat;
Expand Down Expand Up @@ -178,4 +178,46 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
outputBuffer[28] = dpmax.d.y; // Expected: 0.0
outputBuffer[29] = dpmax.d.z; // Expected: 0.3
}

// New tests: Forward-mode tests for derivative propagation at the edges with clamp(x, 0, 1)
{
// Lower edge: x exactly = 0
dpfloat dpx = dpfloat(0.0, 0.4);
dpfloat dpmin = dpfloat(0.0, 0.8);
dpfloat dpmax = dpfloat(1.0, 0.5);
dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax);
outputBuffer[30] = res.d; // Expected: 0.4 (propagated from x)
}

{
// Upper edge: x exactly = 1
dpfloat dpx = dpfloat(1.0, 0.7);
dpfloat dpmin = dpfloat(0.0, 0.8);
dpfloat dpmax = dpfloat(1.0, 0.9);
dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax);
outputBuffer[31] = res.d; // Expected: 0.7 (propagated from x)
}

// Reverse-mode tests for derivative propagation at the edges with clamp(x, 0, 1)
{
// Lower edge: x exactly = 0
dpfloat dpx = dpfloat(0.0, 0.0);
dpfloat dpmin = dpfloat(0.0, 0.0);
dpfloat dpmax = dpfloat(1.0, 0.0);
bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0);
outputBuffer[32] = dpx.d; // Expected: 1.0 (propagated from x)
outputBuffer[33] = dpmin.d; // Expected: 0.0
outputBuffer[34] = dpmax.d; // Expected: 0.0
}

{
// Upper edge: x exactly = 1
dpfloat dpx = dpfloat(1.0, 0.0);
dpfloat dpmin = dpfloat(0.0, 0.0);
dpfloat dpmax = dpfloat(1.0, 0.0);
bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0);
outputBuffer[35] = dpx.d; // Expected: 1.0 (propagated from x)
outputBuffer[36] = dpmin.d; // Expected: 0.0
outputBuffer[37] = dpmax.d; // Expected: 0.0
}
}
8 changes: 8 additions & 0 deletions tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,11 @@ type: float
0.000000
0.000000
0.300000
0.400000
0.700000
1.000000
0.000000
0.000000
1.000000
0.000000
0.000000
Loading