|
1 | 1 | //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
|
2 | 2 | //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
|
3 | 3 |
|
4 |
| -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer |
| 4 | +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer |
5 | 5 | RWStructuredBuffer<float> outputBuffer;
|
6 | 6 |
|
7 | 7 | typedef DifferentialPair<float> dpfloat;
|
@@ -178,4 +178,46 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
|
178 | 178 | outputBuffer[28] = dpmax.d.y; // Expected: 0.0
|
179 | 179 | outputBuffer[29] = dpmax.d.z; // Expected: 0.3
|
180 | 180 | }
|
| 181 | + |
| 182 | + // New tests: Forward-mode tests for derivative propagation at the edges with clamp(x, 0, 1) |
| 183 | + { |
| 184 | + // Lower edge: x exactly = 0 |
| 185 | + dpfloat dpx = dpfloat(0.0, 0.4); |
| 186 | + dpfloat dpmin = dpfloat(0.0, 0.8); |
| 187 | + dpfloat dpmax = dpfloat(1.0, 0.5); |
| 188 | + dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax); |
| 189 | + outputBuffer[30] = res.d; // Expected: 0.4 (propagated from x) |
| 190 | + } |
| 191 | + |
| 192 | + { |
| 193 | + // Upper edge: x exactly = 1 |
| 194 | + dpfloat dpx = dpfloat(1.0, 0.7); |
| 195 | + dpfloat dpmin = dpfloat(0.0, 0.8); |
| 196 | + dpfloat dpmax = dpfloat(1.0, 0.9); |
| 197 | + dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax); |
| 198 | + outputBuffer[31] = res.d; // Expected: 0.7 (propagated from x) |
| 199 | + } |
| 200 | + |
| 201 | + // Reverse-mode tests for derivative propagation at the edges with clamp(x, 0, 1) |
| 202 | + { |
| 203 | + // Lower edge: x exactly = 0 |
| 204 | + dpfloat dpx = dpfloat(0.0, 0.0); |
| 205 | + dpfloat dpmin = dpfloat(0.0, 0.0); |
| 206 | + dpfloat dpmax = dpfloat(1.0, 0.0); |
| 207 | + bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0); |
| 208 | + outputBuffer[32] = dpx.d; // Expected: 1.0 (propagated from x) |
| 209 | + outputBuffer[33] = dpmin.d; // Expected: 0.0 |
| 210 | + outputBuffer[34] = dpmax.d; // Expected: 0.0 |
| 211 | + } |
| 212 | + |
| 213 | + { |
| 214 | + // Upper edge: x exactly = 1 |
| 215 | + dpfloat dpx = dpfloat(1.0, 0.0); |
| 216 | + dpfloat dpmin = dpfloat(0.0, 0.0); |
| 217 | + dpfloat dpmax = dpfloat(1.0, 0.0); |
| 218 | + bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0); |
| 219 | + outputBuffer[35] = dpx.d; // Expected: 1.0 (propagated from x) |
| 220 | + outputBuffer[36] = dpmin.d; // Expected: 0.0 |
| 221 | + outputBuffer[37] = dpmax.d; // Expected: 0.0 |
| 222 | + } |
181 | 223 | }
|
0 commit comments