From 40ddf150b56bc84c6256ca2bc7116c912cbb5288 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Wed, 19 Feb 2025 11:43:01 -0500 Subject: [PATCH] Fix issue with `clamp`'s derivatives at the boundary. --- source/slang/diff.meta.slang | 12 ++--- tests/autodiff-dstdlib/dstdlib-clamp.slang | 44 ++++++++++++++++++- .../dstdlib-clamp.slang.expected.txt | 8 ++++ 3 files changed, 57 insertions(+), 7 deletions(-) diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 6f2bd2cd45..38a3220be7 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1,4 +1,3 @@ - /// `[ForwardDerivative(fwdFn)]` attribute can be used to provide a forward-mode /// derivative implementation. /// Invoking `fwd_diff(decoratedFn)` will place a call to `fwdFn` instead of synthesizing @@ -80,7 +79,6 @@ attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; /// For member functions, or functions nested inside namespaces, `bwdFn` may need to be a fully qualified /// name. /// -/// __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; @@ -2150,7 +2148,7 @@ DifferentialPair __d_clamp(DifferentialPair dpx, DifferentialPair dpMin { return DifferentialPair( clamp(dpx.p, dpMin.p, dpMax.p), - dpx.p < dpMin.p ? dpMin.d : (dpx.p > dpMax.p ? dpMax.d : dpx.d)); + (dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dpx.d : (dpx.p < dpMin.p ? dpMin.d : dpMax.d)); } __generic [BackwardDifferentiable] @@ -2158,9 +2156,11 @@ __generic [BackwardDerivativeOf(clamp)] void __d_clamp(inout DifferentialPair dpx, inout DifferentialPair dpMin, inout DifferentialPair dpMax, T.Differential dOut) { - dpx = diffPair(dpx.p, dpx.p > dpMin.p && dpx.p < dpMax.p ? dOut : T.dzero()); - dpMin = diffPair(dpMin.p, dpx.p <= dpMin.p ? dOut : T.dzero()); - dpMax = diffPair(dpMax.p, dpx.p >= dpMax.p ? dOut : T.dzero()); + // Propagate the derivative to x if x is within [min, max] (including the boundaries). 
+ dpx = diffPair(dpx.p, (dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dOut : T.dzero()); + // If x is strictly below min or above max, the gradient is instead applied to the clamp bounds + dpMin = diffPair(dpMin.p, dpx.p < dpMin.p ? dOut : T.dzero()); + dpMax = diffPair(dpMax.p, dpx.p > dpMax.p ? dOut : T.dzero()); } VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp) diff --git a/tests/autodiff-dstdlib/dstdlib-clamp.slang b/tests/autodiff-dstdlib/dstdlib-clamp.slang index 32b1cc8eb0..3af12907a0 100644 --- a/tests/autodiff-dstdlib/dstdlib-clamp.slang +++ b/tests/autodiff-dstdlib/dstdlib-clamp.slang @@ -1,7 +1,7 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; typedef DifferentialPair dpfloat; @@ -178,4 +178,46 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) outputBuffer[28] = dpmax.d.y; // Expected: 0.0 outputBuffer[29] = dpmax.d.z; // Expected: 0.3 } + + // New tests: Forward-mode tests for derivative propagation at the edges with clamp(x, 0, 1) + { + // Lower edge: x exactly = 0 + dpfloat dpx = dpfloat(0.0, 0.4); + dpfloat dpmin = dpfloat(0.0, 0.8); + dpfloat dpmax = dpfloat(1.0, 0.5); + dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax); + outputBuffer[30] = res.d; // Expected: 0.4 (propagated from x) + } + + { + // Upper edge: x exactly = 1 + dpfloat dpx = dpfloat(1.0, 0.7); + dpfloat dpmin = dpfloat(0.0, 0.8); + dpfloat dpmax = dpfloat(1.0, 0.9); + dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax); + outputBuffer[31] = res.d; // Expected: 0.7 (propagated from x) + } + + // Reverse-mode tests for derivative propagation at the edges with clamp(x, 0, 1) + { + 
+ // Lower edge: x exactly = 0 + dpfloat dpx = dpfloat(0.0, 0.0); + dpfloat dpmin = dpfloat(0.0, 0.0); + dpfloat dpmax = dpfloat(1.0, 0.0); + bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0); + outputBuffer[32] = dpx.d; // Expected: 1.0 (output gradient propagated to x) + outputBuffer[33] = dpmin.d; // Expected: 0.0 + outputBuffer[34] = dpmax.d; // Expected: 0.0 + } + + { + // Upper edge: x exactly = 1 + dpfloat dpx = dpfloat(1.0, 0.0); + dpfloat dpmin = dpfloat(0.0, 0.0); + dpfloat dpmax = dpfloat(1.0, 0.0); + bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0); + outputBuffer[35] = dpx.d; // Expected: 1.0 (output gradient propagated to x) + outputBuffer[36] = dpmin.d; // Expected: 0.0 + outputBuffer[37] = dpmax.d; // Expected: 0.0 + } } diff --git a/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt index b00b0060bf..b18853e902 100644 --- a/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt @@ -29,3 +29,11 @@ type: float 0.000000 0.000000 0.300000 +0.400000 +0.700000 +1.000000 +0.000000 +0.000000 +1.000000 +0.000000 +0.000000