|
| 1 | + |
| 2 | +[Differentiable] |
| 3 | +float sumOfSquares(float x, float y, no_diff float4* test) |
| 4 | +{ |
| 5 | + return x * x + y * y * (test->x + test->y + test->z); |
| 6 | +} |
| 7 | + |
| 8 | +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation -emit-spirv-directly |
| 9 | + |
| 10 | +//TEST_INPUT: set ptr = ubuffer(data=[1.0 2.0 3.0], stride=4) |
| 11 | +uniform float* ptr; |
| 12 | + |
| 13 | +//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0], stride=4):out, name outputBuffer |
| 14 | +RWStructuredBuffer<float> outputBuffer; |
| 15 | + |
| 16 | +[shader("compute")] |
| 17 | +[numthreads(1, 1, 1)] |
| 18 | +void computeMain() |
| 19 | +{ |
| 20 | + float4* testPtr = (float4*)ptr; |
| 21 | + |
| 22 | + let result = sumOfSquares(2.0, 3.0, testPtr); |
| 23 | + |
| 24 | + // Use forward differentiation to compute the gradient of the output w.r.t. x only. |
| 25 | + let diffX = fwd_diff(sumOfSquares)(diffPair(2.0, 1.0), diffPair(3.0, 0.0), testPtr); |
| 26 | + |
| 27 | + // Create a differentiable pair to pass in the primal value and to receive the gradient. |
| 28 | + var dpX = diffPair(2.0); |
| 29 | + var dpY = diffPair(3.0); |
| 30 | + |
| 31 | + // Propagate the gradient of the output (1.0f) to the input parameters. |
| 32 | + bwd_diff(sumOfSquares)(dpX, dpY, testPtr, 1.0); |
| 33 | + |
| 34 | + outputBuffer[0] = result; // 2^2 + 3^2 * (1 + 2 + 3) = 58 |
| 35 | + outputBuffer[1] = diffX.d; // 2*x * dx + 2*y * dy * (1 + 2 + 3) = 4 |
| 36 | + outputBuffer[2] = diffX.p; // 2^2 + 3^2 * (1 + 2 + 3) = 58 |
| 37 | + outputBuffer[3] = dpX.d; // 2*x = 4 |
| 38 | + |
| 39 | + outputBuffer[4] = dpY.d; // 2*y * (1 + 2 +3) = 36 |
| 40 | +} |
0 commit comments