Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gradient behavior for min() and max() functions at boundaries. #6411

Merged
merged 2 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions source/slang/diff.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -2074,7 +2074,7 @@ DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
return DifferentialPair<T>(
max(dpx.p, dpy.p),
dpx.p > dpy.p ? dpx.d : dpy.d
dpx.p > dpy.p ? dpx.d : (dpx.p < dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d)))
);
}

Expand All @@ -2084,8 +2084,8 @@ __generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(max)]
void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero());
dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero());
dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : (dpx.p < dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : (dpy.p < dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
}

VECTOR_MATRIX_BINARY_DIFF_IMPL(max)
Expand All @@ -2099,7 +2099,7 @@ DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
return DifferentialPair<T>(
min(dpx.p, dpy.p),
dpx.p < dpy.p ? dpx.d : dpy.d
dpx.p < dpy.p ? dpx.d : (dpx.p > dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d)))
);
}

Expand All @@ -2109,8 +2109,8 @@ __generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(min)]
void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero());
dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero());
dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : (dpx.p > dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : (dpy.p > dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
}

VECTOR_MATRIX_BINARY_DIFF_IMPL(min)
Expand Down
112 changes: 112 additions & 0 deletions tests/autodiff-dstdlib/dstdlib-max-min.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

typedef DifferentialPair<float> dpfloat;
typedef DifferentialPair<float2> dpfloat2;

[BackwardDifferentiable]
float diffMax(float x, float y)
{
return max(x, y);
}

[BackwardDifferentiable]
float2 diffMax(float2 x, float2 y)
{
return max(x, y);
}

[BackwardDifferentiable]
float diffMin(float x, float y)
{
return min(x, y);
}

[BackwardDifferentiable]
float2 diffMin(float2 x, float2 y)
{
return min(x, y);
}

[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
// Test max() with x < y
{
dpfloat dpx = dpfloat(2.0, 1.0);
dpfloat dpy = dpfloat(5.0, -2.0);
dpfloat res = __fwd_diff(diffMax)(dpx, dpy);
outputBuffer[0] = res.p; // Expect: 5.000000
outputBuffer[1] = res.d; // Expect: -2.000000
}

// Test max() with x == y
{
dpfloat dpx = dpfloat(3.0, 1.0);
dpfloat dpy = dpfloat(3.0, -2.0);
dpfloat res = __fwd_diff(diffMax)(dpx, dpy);
outputBuffer[2] = res.p; // Expect: 3.000000
outputBuffer[3] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0)
}

// Test min() with x > y
{
dpfloat dpx = dpfloat(5.0, 1.0);
dpfloat dpy = dpfloat(2.0, -2.0);
dpfloat res = __fwd_diff(diffMin)(dpx, dpy);
outputBuffer[4] = res.p; // Expect: 2.000000
outputBuffer[5] = res.d; // Expect: -2.000000
}

// Test min() with x == y
{
dpfloat dpx = dpfloat(3.0, 1.0);
dpfloat dpy = dpfloat(3.0, -2.0);
dpfloat res = __fwd_diff(diffMin)(dpx, dpy);
outputBuffer[6] = res.p; // Expect: 3.000000
outputBuffer[7] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0)
}

// Test backward-mode max() with x == y
{
dpfloat dpx = dpfloat(3.0, 0.0);
dpfloat dpy = dpfloat(3.0, 0.0);
__bwd_diff(diffMax)(dpx, dpy, 1.0);
outputBuffer[8] = dpx.d; // Expect: 0.500000 (half of gradient)
outputBuffer[9] = dpy.d; // Expect: 0.500000 (half of gradient)
}

// Test backward-mode min() with x == y
{
dpfloat dpx = dpfloat(3.0, 0.0);
dpfloat dpy = dpfloat(3.0, 0.0);
__bwd_diff(diffMin)(dpx, dpy, 1.0);
outputBuffer[10] = dpx.d; // Expect: 0.500000 (half of gradient)
outputBuffer[11] = dpy.d; // Expect: 0.500000 (half of gradient)
}

// Test vector max() with x == y
{
dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0));
dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0));
dpfloat2 res = __fwd_diff(diffMax)(dpx, dpy);
outputBuffer[12] = res.p[0]; // Expect: 3.000000
outputBuffer[13] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0)
outputBuffer[14] = res.p[1]; // Expect: 4.000000
outputBuffer[15] = res.d[1]; // Expect: 2.000000
}

// Test vector min() with x == y
{
dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0));
dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0));
dpfloat2 res = __fwd_diff(diffMin)(dpx, dpy);
outputBuffer[16] = res.p[0]; // Expect: 3.000000
outputBuffer[17] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0)
outputBuffer[18] = res.p[1]; // Expect: 2.000000
outputBuffer[19] = res.d[1]; // Expect: -3.000000
}
}
21 changes: 21 additions & 0 deletions tests/autodiff-dstdlib/dstdlib-max-min.slang.expected.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
type: float
5.000000
-2.000000
3.000000
-0.500000
2.000000
-2.000000
3.000000
-0.500000
0.500000
0.500000
0.500000
0.500000
3.000000
-0.500000
4.000000
2.000000
3.000000
-0.500000
2.000000
-3.000000
52 changes: 0 additions & 52 deletions tests/autodiff-dstdlib/dstdlib-max.slang

This file was deleted.

11 changes: 0 additions & 11 deletions tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt

This file was deleted.

Loading