Skip to content

Commit ca592d2

Browse files
Fix gradient behavior for min() and max() functions at boundaries. When input values are equal, the gradient is split evenly between both inputs. (#6411)
Co-authored-by: Yong He <yonghe@outlook.com>
1 parent 4d286aa commit ca592d2

File tree

5 files changed

+139
-69
lines changed

5 files changed

+139
-69
lines changed

source/slang/diff.meta.slang

+6-6
Original file line numberDiff line numberDiff line change
@@ -2074,7 +2074,7 @@ DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
20742074
{
20752075
return DifferentialPair<T>(
20762076
max(dpx.p, dpy.p),
2077-
dpx.p > dpy.p ? dpx.d : dpy.d
2077+
dpx.p > dpy.p ? dpx.d : (dpx.p < dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d)))
20782078
);
20792079
}
20802080

@@ -2084,8 +2084,8 @@ __generic<T : __BuiltinFloatingPointType>
20842084
[BackwardDerivativeOf(max)]
20852085
void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
20862086
{
2087-
dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero());
2088-
dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero());
2087+
dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : (dpx.p < dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
2088+
dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : (dpy.p < dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
20892089
}
20902090

20912091
VECTOR_MATRIX_BINARY_DIFF_IMPL(max)
@@ -2099,7 +2099,7 @@ DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
20992099
{
21002100
return DifferentialPair<T>(
21012101
min(dpx.p, dpy.p),
2102-
dpx.p < dpy.p ? dpx.d : dpy.d
2102+
dpx.p < dpy.p ? dpx.d : (dpx.p > dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d)))
21032103
);
21042104
}
21052105

@@ -2109,8 +2109,8 @@ __generic<T : __BuiltinFloatingPointType>
21092109
[BackwardDerivativeOf(min)]
21102110
void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
21112111
{
2112-
dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero());
2113-
dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero());
2112+
dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : (dpx.p > dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
2113+
dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : (dpy.p > dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
21142114
}
21152115

21162116
VECTOR_MATRIX_BINARY_DIFF_IMPL(min)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
2+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
3+
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
5+
RWStructuredBuffer<float> outputBuffer;
6+
7+
typedef DifferentialPair<float> dpfloat;
8+
typedef DifferentialPair<float2> dpfloat2;
9+
10+
[BackwardDifferentiable]
11+
float diffMax(float x, float y)
12+
{
13+
return max(x, y);
14+
}
15+
16+
[BackwardDifferentiable]
17+
float2 diffMax(float2 x, float2 y)
18+
{
19+
return max(x, y);
20+
}
21+
22+
[BackwardDifferentiable]
23+
float diffMin(float x, float y)
24+
{
25+
return min(x, y);
26+
}
27+
28+
[BackwardDifferentiable]
29+
float2 diffMin(float2 x, float2 y)
30+
{
31+
return min(x, y);
32+
}
33+
34+
[numthreads(1, 1, 1)]
35+
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
36+
{
37+
// Test max() with x < y
38+
{
39+
dpfloat dpx = dpfloat(2.0, 1.0);
40+
dpfloat dpy = dpfloat(5.0, -2.0);
41+
dpfloat res = __fwd_diff(diffMax)(dpx, dpy);
42+
outputBuffer[0] = res.p; // Expect: 5.000000
43+
outputBuffer[1] = res.d; // Expect: -2.000000
44+
}
45+
46+
// Test max() with x == y
47+
{
48+
dpfloat dpx = dpfloat(3.0, 1.0);
49+
dpfloat dpy = dpfloat(3.0, -2.0);
50+
dpfloat res = __fwd_diff(diffMax)(dpx, dpy);
51+
outputBuffer[2] = res.p; // Expect: 3.000000
52+
outputBuffer[3] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0)
53+
}
54+
55+
// Test min() with x > y
56+
{
57+
dpfloat dpx = dpfloat(5.0, 1.0);
58+
dpfloat dpy = dpfloat(2.0, -2.0);
59+
dpfloat res = __fwd_diff(diffMin)(dpx, dpy);
60+
outputBuffer[4] = res.p; // Expect: 2.000000
61+
outputBuffer[5] = res.d; // Expect: -2.000000
62+
}
63+
64+
// Test min() with x == y
65+
{
66+
dpfloat dpx = dpfloat(3.0, 1.0);
67+
dpfloat dpy = dpfloat(3.0, -2.0);
68+
dpfloat res = __fwd_diff(diffMin)(dpx, dpy);
69+
outputBuffer[6] = res.p; // Expect: 3.000000
70+
outputBuffer[7] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0)
71+
}
72+
73+
// Test backward-mode max() with x == y
74+
{
75+
dpfloat dpx = dpfloat(3.0, 0.0);
76+
dpfloat dpy = dpfloat(3.0, 0.0);
77+
__bwd_diff(diffMax)(dpx, dpy, 1.0);
78+
outputBuffer[8] = dpx.d; // Expect: 0.500000 (half of gradient)
79+
outputBuffer[9] = dpy.d; // Expect: 0.500000 (half of gradient)
80+
}
81+
82+
// Test backward-mode min() with x == y
83+
{
84+
dpfloat dpx = dpfloat(3.0, 0.0);
85+
dpfloat dpy = dpfloat(3.0, 0.0);
86+
__bwd_diff(diffMin)(dpx, dpy, 1.0);
87+
outputBuffer[10] = dpx.d; // Expect: 0.500000 (half of gradient)
88+
outputBuffer[11] = dpy.d; // Expect: 0.500000 (half of gradient)
89+
}
90+
91+
// Test vector max() with x == y
92+
{
93+
dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0));
94+
dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0));
95+
dpfloat2 res = __fwd_diff(diffMax)(dpx, dpy);
96+
outputBuffer[12] = res.p[0]; // Expect: 3.000000
97+
outputBuffer[13] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0)
98+
outputBuffer[14] = res.p[1]; // Expect: 4.000000
99+
outputBuffer[15] = res.d[1]; // Expect: 2.000000
100+
}
101+
102+
// Test vector min() with x == y
103+
{
104+
dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0));
105+
dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0));
106+
dpfloat2 res = __fwd_diff(diffMin)(dpx, dpy);
107+
outputBuffer[16] = res.p[0]; // Expect: 3.000000
108+
outputBuffer[17] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0)
109+
outputBuffer[18] = res.p[1]; // Expect: 2.000000
110+
outputBuffer[19] = res.d[1]; // Expect: -3.000000
111+
}
112+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
type: float
2+
5.000000
3+
-2.000000
4+
3.000000
5+
-0.500000
6+
2.000000
7+
-2.000000
8+
3.000000
9+
-0.500000
10+
0.500000
11+
0.500000
12+
0.500000
13+
0.500000
14+
3.000000
15+
-0.500000
16+
4.000000
17+
2.000000
18+
3.000000
19+
-0.500000
20+
2.000000
21+
-3.000000

tests/autodiff-dstdlib/dstdlib-max.slang

-52
This file was deleted.

tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt

-11
This file was deleted.

0 commit comments

Comments
 (0)