From 2579fbea69a4e710521cafbb9967c4ec7cbc261c Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Thu, 20 Feb 2025 11:18:14 -0500 Subject: [PATCH] Fix gradient behavior for min() and max() functions at boundaries. When input values are equal, the gradient is split evenly between both inputs. --- source/slang/diff.meta.slang | 12 +- tests/autodiff-dstdlib/dstdlib-max-min.slang | 112 ++++++++++++++++++ .../dstdlib-max-min.slang.expected.txt | 21 ++++ tests/autodiff-dstdlib/dstdlib-max.slang | 52 -------- .../dstdlib-max.slang.expected.txt | 11 -- 5 files changed, 139 insertions(+), 69 deletions(-) create mode 100644 tests/autodiff-dstdlib/dstdlib-max-min.slang create mode 100644 tests/autodiff-dstdlib/dstdlib-max-min.slang.expected.txt delete mode 100644 tests/autodiff-dstdlib/dstdlib-max.slang delete mode 100644 tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 38a3220be7..790dfaa798 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -2074,7 +2074,7 @@ DifferentialPair __d_max(DifferentialPair dpx, DifferentialPair dpy) { return DifferentialPair( max(dpx.p, dpy.p), - dpx.p > dpy.p ? dpx.d : dpy.d + dpx.p > dpy.p ? dpx.d : (dpx.p < dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d))) ); } @@ -2084,8 +2084,8 @@ __generic [BackwardDerivativeOf(max)] void __d_max(inout DifferentialPair dpx, inout DifferentialPair dpy, T.Differential dOut) { - dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero()); - dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero()); + dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : (dpx.p < dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); + dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : (dpy.p < dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); } VECTOR_MATRIX_BINARY_DIFF_IMPL(max) @@ -2099,7 +2099,7 @@ DifferentialPair __d_min(DifferentialPair dpx, DifferentialPair dpy) { return DifferentialPair( min(dpx.p, dpy.p), - dpx.p < dpy.p ? 
dpx.d : dpy.d + dpx.p < dpy.p ? dpx.d : (dpx.p > dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d))) ); } @@ -2109,8 +2109,8 @@ __generic [BackwardDerivativeOf(min)] void __d_min(inout DifferentialPair dpx, inout DifferentialPair dpy, T.Differential dOut) { - dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero()); - dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero()); + dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : (dpx.p > dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); + dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : (dpy.p > dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); } VECTOR_MATRIX_BINARY_DIFF_IMPL(min) diff --git a/tests/autodiff-dstdlib/dstdlib-max-min.slang b/tests/autodiff-dstdlib/dstdlib-max-min.slang new file mode 100644 index 0000000000..f37083706e --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-max-min.slang @@ -0,0 +1,112 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +typedef DifferentialPair dpfloat; +typedef DifferentialPair dpfloat2; + +[BackwardDifferentiable] +float diffMax(float x, float y) +{ + return max(x, y); +} + +[BackwardDifferentiable] +float2 diffMax(float2 x, float2 y) +{ + return max(x, y); +} + +[BackwardDifferentiable] +float diffMin(float x, float y) +{ + return min(x, y); +} + +[BackwardDifferentiable] +float2 diffMin(float2 x, float2 y) +{ + return min(x, y); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + // Test max() with x < y + { + dpfloat dpx = dpfloat(2.0, 1.0); + dpfloat dpy = dpfloat(5.0, -2.0); + dpfloat res = __fwd_diff(diffMax)(dpx, dpy); + outputBuffer[0] = res.p; // Expect: 5.000000 + outputBuffer[1] = res.d; // Expect: -2.000000 + } + + // Test max() with x == y + { + dpfloat dpx = dpfloat(3.0, 
1.0); + dpfloat dpy = dpfloat(3.0, -2.0); + dpfloat res = __fwd_diff(diffMax)(dpx, dpy); + outputBuffer[2] = res.p; // Expect: 3.000000 + outputBuffer[3] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0) + } + + // Test min() with x > y + { + dpfloat dpx = dpfloat(5.0, 1.0); + dpfloat dpy = dpfloat(2.0, -2.0); + dpfloat res = __fwd_diff(diffMin)(dpx, dpy); + outputBuffer[4] = res.p; // Expect: 2.000000 + outputBuffer[5] = res.d; // Expect: -2.000000 + } + + // Test min() with x == y + { + dpfloat dpx = dpfloat(3.0, 1.0); + dpfloat dpy = dpfloat(3.0, -2.0); + dpfloat res = __fwd_diff(diffMin)(dpx, dpy); + outputBuffer[6] = res.p; // Expect: 3.000000 + outputBuffer[7] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0) + } + + // Test backward-mode max() with x == y + { + dpfloat dpx = dpfloat(3.0, 0.0); + dpfloat dpy = dpfloat(3.0, 0.0); + __bwd_diff(diffMax)(dpx, dpy, 1.0); + outputBuffer[8] = dpx.d; // Expect: 0.500000 (half of gradient) + outputBuffer[9] = dpy.d; // Expect: 0.500000 (half of gradient) + } + + // Test backward-mode min() with x == y + { + dpfloat dpx = dpfloat(3.0, 0.0); + dpfloat dpy = dpfloat(3.0, 0.0); + __bwd_diff(diffMin)(dpx, dpy, 1.0); + outputBuffer[10] = dpx.d; // Expect: 0.500000 (half of gradient) + outputBuffer[11] = dpy.d; // Expect: 0.500000 (half of gradient) + } + + // Test vector max(): component 0 is a tie (x[0] == y[0]), component 1 has x[1] > y[1] + { + dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0)); + dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0)); + dpfloat2 res = __fwd_diff(diffMax)(dpx, dpy); + outputBuffer[12] = res.p[0]; // Expect: 3.000000 + outputBuffer[13] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0) + outputBuffer[14] = res.p[1]; // Expect: 4.000000 + outputBuffer[15] = res.d[1]; // Expect: 2.000000 + } + + // Test vector min(): component 0 is a tie (x[0] == y[0]), component 1 has x[1] > y[1] + { + dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0)); + dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0)); + dpfloat2 res = __fwd_diff(diffMin)(dpx, dpy); + 
outputBuffer[16] = res.p[0]; // Expect: 3.000000 + outputBuffer[17] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0) + outputBuffer[18] = res.p[1]; // Expect: 2.000000 + outputBuffer[19] = res.d[1]; // Expect: -3.000000 + } +} diff --git a/tests/autodiff-dstdlib/dstdlib-max-min.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-max-min.slang.expected.txt new file mode 100644 index 0000000000..504343b586 --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-max-min.slang.expected.txt @@ -0,0 +1,21 @@ +type: float +5.000000 +-2.000000 +3.000000 +-0.500000 +2.000000 +-2.000000 +3.000000 +-0.500000 +0.500000 +0.500000 +0.500000 +0.500000 +3.000000 +-0.500000 +4.000000 +2.000000 +3.000000 +-0.500000 +2.000000 +-3.000000 \ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-max.slang b/tests/autodiff-dstdlib/dstdlib-max.slang deleted file mode 100644 index 026914c8c2..0000000000 --- a/tests/autodiff-dstdlib/dstdlib-max.slang +++ /dev/null @@ -1,52 +0,0 @@ -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type - -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer -RWStructuredBuffer outputBuffer; - -typedef DifferentialPair dpfloat; -typedef DifferentialPair dpfloat2; - -[BackwardDifferentiable] -float diffMax(float x, float y) -{ - return max(x, y); -} - -[BackwardDifferentiable] -float2 diffMax(float2 x, float2 y) -{ - return max(x, y); -} - -[numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) -{ - { - dpfloat dpx = dpfloat(2.0, 1.0); - dpfloat dpy = dpfloat(5.0, -2.0); - dpfloat res = __fwd_diff(diffMax)(dpx, dpy); - outputBuffer[0] = res.p; // Expect: 5.000000 - outputBuffer[1] = res.d; // Expect: -2.000000 - } - - { - dpfloat2 dpx = dpfloat2(float2(-3.0, 4.0), float2(-1.0, -1.0)); - dpfloat2 dpy = dpfloat2(float2(1.0, 2.0), float2(2.0, 2.0)); - dpfloat2 res = 
__fwd_diff(diffMax)(dpx, dpy); - outputBuffer[2] = res.p[0]; // Expect: 1.000000 - outputBuffer[3] = res.d[0]; // Expect: 2.000000 - outputBuffer[4] = res.p[1]; // Expect: 4.000000 - outputBuffer[5] = res.d[1]; // Expect: -1.000000 - } - - { - dpfloat2 dpx = dpfloat2(float2(2.0, 3.0), float2(0.0, 0.0)); - dpfloat2 dpy = dpfloat2(float2(5.0, 1.0), float2(0.0, 0.0)); - __bwd_diff(diffMax)(dpx, dpy, float2(1.0, 2.0)); - outputBuffer[6] = dpx.d[0]; // Expect: 0.000000 - outputBuffer[7] = dpx.d[1]; // Expect: 2.000000 - outputBuffer[8] = dpy.d[0]; // Expect: 1.000000 - outputBuffer[9] = dpy.d[1]; // Expect: 0.000000 - } -} diff --git a/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt deleted file mode 100644 index 4cc1e9533c..0000000000 --- a/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt +++ /dev/null @@ -1,11 +0,0 @@ -type: float -5.000000 --2.000000 -1.000000 -2.000000 -4.000000 --1.000000 -0.000000 -2.000000 -1.000000 -0.000000 \ No newline at end of file