Skip to content

Commit 2579fbe

Browse files
committed
Fix gradient behavior for min() and max() functions at boundaries. When input values are equal, the gradient is split evenly between both inputs.
1 parent 9580e31 commit 2579fbe

File tree

5 files changed

+139
-69
lines changed

5 files changed

+139
-69
lines changed

source/slang/diff.meta.slang

+6-6
Original file line numberDiff line numberDiff line change
@@ -2074,7 +2074,7 @@ DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
20742074
{
20752075
return DifferentialPair<T>(
20762076
max(dpx.p, dpy.p),
2077-
dpx.p > dpy.p ? dpx.d : dpy.d
2077+
dpx.p > dpy.p ? dpx.d : (dpx.p < dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d)))
20782078
);
20792079
}
20802080

@@ -2084,8 +2084,8 @@ __generic<T : __BuiltinFloatingPointType>
20842084
[BackwardDerivativeOf(max)]
20852085
void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
20862086
{
2087-
dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero());
2088-
dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero());
2087+
dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : (dpx.p < dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
2088+
dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : (dpy.p < dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
20892089
}
20902090

20912091
VECTOR_MATRIX_BINARY_DIFF_IMPL(max)
@@ -2099,7 +2099,7 @@ DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
20992099
{
21002100
return DifferentialPair<T>(
21012101
min(dpx.p, dpy.p),
2102-
dpx.p < dpy.p ? dpx.d : dpy.d
2102+
dpx.p < dpy.p ? dpx.d : (dpx.p > dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d)))
21032103
);
21042104
}
21052105

@@ -2109,8 +2109,8 @@ __generic<T : __BuiltinFloatingPointType>
21092109
[BackwardDerivativeOf(min)]
21102110
void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
21112111
{
2112-
dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero());
2113-
dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero());
2112+
dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : (dpx.p > dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
2113+
dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : (dpy.p > dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
21142114
}
21152115

21162116
VECTOR_MATRIX_BINARY_DIFF_IMPL(min)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
2+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
3+
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
5+
RWStructuredBuffer<float> outputBuffer;
6+
7+
typedef DifferentialPair<float> dpfloat;
8+
typedef DifferentialPair<float2> dpfloat2;
9+
10+
[BackwardDifferentiable]
11+
float diffMax(float x, float y)
12+
{
13+
return max(x, y);
14+
}
15+
16+
[BackwardDifferentiable]
17+
float2 diffMax(float2 x, float2 y)
18+
{
19+
return max(x, y);
20+
}
21+
22+
[BackwardDifferentiable]
23+
float diffMin(float x, float y)
24+
{
25+
return min(x, y);
26+
}
27+
28+
[BackwardDifferentiable]
29+
float2 diffMin(float2 x, float2 y)
30+
{
31+
return min(x, y);
32+
}
33+
34+
[numthreads(1, 1, 1)]
35+
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
36+
{
37+
// Test max() with x < y
38+
{
39+
dpfloat dpx = dpfloat(2.0, 1.0);
40+
dpfloat dpy = dpfloat(5.0, -2.0);
41+
dpfloat res = __fwd_diff(diffMax)(dpx, dpy);
42+
outputBuffer[0] = res.p; // Expect: 5.000000
43+
outputBuffer[1] = res.d; // Expect: -2.000000
44+
}
45+
46+
// Test max() with x == y
47+
{
48+
dpfloat dpx = dpfloat(3.0, 1.0);
49+
dpfloat dpy = dpfloat(3.0, -2.0);
50+
dpfloat res = __fwd_diff(diffMax)(dpx, dpy);
51+
outputBuffer[2] = res.p; // Expect: 3.000000
52+
outputBuffer[3] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0)
53+
}
54+
55+
// Test min() with x > y
56+
{
57+
dpfloat dpx = dpfloat(5.0, 1.0);
58+
dpfloat dpy = dpfloat(2.0, -2.0);
59+
dpfloat res = __fwd_diff(diffMin)(dpx, dpy);
60+
outputBuffer[4] = res.p; // Expect: 2.000000
61+
outputBuffer[5] = res.d; // Expect: -2.000000
62+
}
63+
64+
// Test min() with x == y
65+
{
66+
dpfloat dpx = dpfloat(3.0, 1.0);
67+
dpfloat dpy = dpfloat(3.0, -2.0);
68+
dpfloat res = __fwd_diff(diffMin)(dpx, dpy);
69+
outputBuffer[6] = res.p; // Expect: 3.000000
70+
outputBuffer[7] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0)
71+
}
72+
73+
// Test backward-mode max() with x == y
74+
{
75+
dpfloat dpx = dpfloat(3.0, 0.0);
76+
dpfloat dpy = dpfloat(3.0, 0.0);
77+
__bwd_diff(diffMax)(dpx, dpy, 1.0);
78+
outputBuffer[8] = dpx.d; // Expect: 0.500000 (half of gradient)
79+
outputBuffer[9] = dpy.d; // Expect: 0.500000 (half of gradient)
80+
}
81+
82+
// Test backward-mode min() with x == y
83+
{
84+
dpfloat dpx = dpfloat(3.0, 0.0);
85+
dpfloat dpy = dpfloat(3.0, 0.0);
86+
__bwd_diff(diffMin)(dpx, dpy, 1.0);
87+
outputBuffer[10] = dpx.d; // Expect: 0.500000 (half of gradient)
88+
outputBuffer[11] = dpy.d; // Expect: 0.500000 (half of gradient)
89+
}
90+
91+
// Test vector max() with x == y
92+
{
93+
dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0));
94+
dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0));
95+
dpfloat2 res = __fwd_diff(diffMax)(dpx, dpy);
96+
outputBuffer[12] = res.p[0]; // Expect: 3.000000
97+
outputBuffer[13] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0)
98+
outputBuffer[14] = res.p[1]; // Expect: 4.000000
99+
outputBuffer[15] = res.d[1]; // Expect: 2.000000
100+
}
101+
102+
// Test vector min() with x == y
103+
{
104+
dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0));
105+
dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0));
106+
dpfloat2 res = __fwd_diff(diffMin)(dpx, dpy);
107+
outputBuffer[16] = res.p[0]; // Expect: 3.000000
108+
outputBuffer[17] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0)
109+
outputBuffer[18] = res.p[1]; // Expect: 2.000000
110+
outputBuffer[19] = res.d[1]; // Expect: -3.000000
111+
}
112+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
type: float
2+
5.000000
3+
-2.000000
4+
3.000000
5+
-0.500000
6+
2.000000
7+
-2.000000
8+
3.000000
9+
-0.500000
10+
0.500000
11+
0.500000
12+
0.500000
13+
0.500000
14+
3.000000
15+
-0.500000
16+
4.000000
17+
2.000000
18+
3.000000
19+
-0.500000
20+
2.000000
21+
-3.000000

tests/autodiff-dstdlib/dstdlib-max.slang

-52
This file was deleted.

tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt

-11
This file was deleted.

0 commit comments

Comments
 (0)