|
| 1 | +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type |
| 2 | +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type |
| 3 | + |
| 4 | +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer |
| 5 | +RWStructuredBuffer<float> outputBuffer; |
| 6 | + |
| 7 | +[Differentiable] |
| 8 | +[PreferRecompute] |
| 9 | +float3 diffRayIntersectTriangle(no_diff float3 rayOrigin, float3 rayDir, no_diff float3 p[3]) |
| 10 | +{ |
| 11 | + float3 e1 = p[1] - p[0]; |
| 12 | + float3 e2 = p[2] - p[0]; |
| 13 | + float3 pVec = cross(rayDir, e2); |
| 14 | + float divisor = dot(pVec, e1); |
| 15 | + float3 s = rayOrigin - p[0]; |
| 16 | + float u = dot(s, pVec) / divisor; |
| 17 | + float3 qVec = cross(s, e1); |
| 18 | + float v = dot(rayDir, qVec) / divisor; |
| 19 | + float t = dot(e2, qVec) / divisor; |
| 20 | + return float3(u, v, t); |
| 21 | +} |
| 22 | + |
| 23 | +[Differentiable] |
| 24 | +[PreferRecompute] |
| 25 | +float3 diffRayIntersectTriangle2(no_diff float3 rayOrigin, float3 rayTarget, no_diff float3 p[3]) |
| 26 | +{ |
| 27 | + float3 rayDir = normalize(rayTarget - rayOrigin); |
| 28 | + float3 uvt = diffRayIntersectTriangle(rayOrigin, rayDir, p); |
| 29 | + float3 result = (1.f - uvt.x - uvt.y) * p[0] + uvt.x * p[1] + uvt.y * p[2]; |
| 30 | + return result; |
| 31 | +} |
| 32 | + |
| 33 | +[numthreads(1, 1, 1)] |
| 34 | +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) |
| 35 | +{ |
| 36 | + float3 shadePos = float3(0.674034, 0.0, 0.123171); |
| 37 | + float3 targetPos = float3(0.5, 0.2, -1.0); |
| 38 | + float3 triPos[3] = { float3(0.0, 1.0, -1.0), float3(1.0, 1.0, 0.0), float3(0.0, 1.0, 0.0) }; |
| 39 | + |
| 40 | + // Forward-mode |
| 41 | + DifferentialPair<float3> dpIsectPos = fwd_diff(diffRayIntersectTriangle2)( |
| 42 | + shadePos, |
| 43 | + DifferentialPair<float3>(targetPos, float3(1.0, 0.0, 0.0)), |
| 44 | + triPos |
| 45 | + ); |
| 46 | + |
| 47 | + outputBuffer[0] = dpIsectPos.d[0]; // Expect: 5.0 |
| 48 | + outputBuffer[1] = dpIsectPos.d[1]; // Expect: 0.0 |
| 49 | + outputBuffer[2] = dpIsectPos.d[2]; // Expect: 0.0 |
| 50 | + |
| 51 | + // Reverse-mode |
| 52 | + DifferentialPair<float3> dpTargetPos = diffPair(targetPos, float3(0.f)); |
| 53 | + bwd_diff(diffRayIntersectTriangle2)( |
| 54 | + shadePos, |
| 55 | + dpTargetPos, |
| 56 | + triPos, |
| 57 | + float3(1.f, 1.f, 1.f) |
| 58 | + ); |
| 59 | + |
| 60 | + outputBuffer[3] = dpTargetPos.d[0]; // Expect: 5.0 |
| 61 | + outputBuffer[4] = dpTargetPos.d[1]; // Expect: 32.4301 |
| 62 | + outputBuffer[5] = dpTargetPos.d[2]; // Expect: 5.0 |
| 63 | +} |
0 commit comments