Skip to content

Commit 0d92068

Browse files
winmadcsyonghe
andauthored
Fix a bug in the forward derivative of cross product (shader-slang#4006)
* Fix a bug in fwd-diff for cross product * Also add a test for the reverse-mode AD --------- Co-authored-by: Yong He <yonghe@outlook.com>
1 parent 9f892c9 commit 0d92068

File tree

3 files changed

+73
-3
lines changed

3 files changed

+73
-3
lines changed

source/slang/diff.meta.slang

+3-3
Original file line numberDiff line numberDiff line change
@@ -1018,17 +1018,17 @@ DifferentialPair<vector<T, 3>> __d_cross(DifferentialPair<vector<T, 3>> a, Diffe
10181018
T aybz = a.p.y * b.p.z;
10191019
T azby = a.p.z * b.p.y;
10201020
T px = aybz - azby;
1021-
T dx = (b.p.z - azby) * a.d.y + (a.p.y - azby) * b.d.z + (aybz - b.p.y) * a.d.z + (aybz - a.p.z) * b.d.y;
1021+
T dx = a.d.y * b.p.z + a.p.y * b.d.z - a.d.z * b.p.y - a.p.z * b.d.y;
10221022

10231023
T azbx = a.p.z * b.p.x;
10241024
T axbz = a.p.x * b.p.z;
10251025
T py = azbx - axbz;
1026-
T dy = (b.p.x - axbz) * a.d.z + (a.p.z - axbz) * b.d.x + (azbx - b.p.z) * a.d.x + (azbx - a.p.x) * b.d.z;
1026+
T dy = a.d.z * b.p.x + a.p.z * b.d.x - a.d.x * b.p.z - a.p.x * b.d.z;
10271027

10281028
T axby = a.p.x * b.p.y;
10291029
T aybx = a.p.y * b.p.x;
10301030
T pz = axby - aybx;
1031-
T dz = (b.p.y - aybx) * a.d.x + (a.p.x - aybx) * b.d.y + (axby - b.p.x) * a.d.y + (axby - a.p.y) * b.d.x;
1031+
T dz = a.d.x * b.p.y + a.p.x * b.d.y - a.d.y * b.p.x - a.p.y * b.d.x;
10321032

10331033
return DifferentialPair<vector<T, 3>>(vector<T, 3>(px, py, pz), vector<T, 3>.Differential(dx, dy, dz));
10341034
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
2+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
3+
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
5+
RWStructuredBuffer<float> outputBuffer;
6+
7+
[Differentiable]
8+
[PreferRecompute]
9+
float3 diffRayIntersectTriangle(no_diff float3 rayOrigin, float3 rayDir, no_diff float3 p[3])
10+
{
11+
float3 e1 = p[1] - p[0];
12+
float3 e2 = p[2] - p[0];
13+
float3 pVec = cross(rayDir, e2);
14+
float divisor = dot(pVec, e1);
15+
float3 s = rayOrigin - p[0];
16+
float u = dot(s, pVec) / divisor;
17+
float3 qVec = cross(s, e1);
18+
float v = dot(rayDir, qVec) / divisor;
19+
float t = dot(e2, qVec) / divisor;
20+
return float3(u, v, t);
21+
}
22+
23+
[Differentiable]
24+
[PreferRecompute]
25+
float3 diffRayIntersectTriangle2(no_diff float3 rayOrigin, float3 rayTarget, no_diff float3 p[3])
26+
{
27+
float3 rayDir = normalize(rayTarget - rayOrigin);
28+
float3 uvt = diffRayIntersectTriangle(rayOrigin, rayDir, p);
29+
float3 result = (1.f - uvt.x - uvt.y) * p[0] + uvt.x * p[1] + uvt.y * p[2];
30+
return result;
31+
}
32+
33+
[numthreads(1, 1, 1)]
34+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
35+
{
36+
float3 shadePos = float3(0.674034, 0.0, 0.123171);
37+
float3 targetPos = float3(0.5, 0.2, -1.0);
38+
float3 triPos[3] = { float3(0.0, 1.0, -1.0), float3(1.0, 1.0, 0.0), float3(0.0, 1.0, 0.0) };
39+
40+
// Forward-mode
41+
DifferentialPair<float3> dpIsectPos = fwd_diff(diffRayIntersectTriangle2)(
42+
shadePos,
43+
DifferentialPair<float3>(targetPos, float3(1.0, 0.0, 0.0)),
44+
triPos
45+
);
46+
47+
outputBuffer[0] = dpIsectPos.d[0]; // Expect: 5.0
48+
outputBuffer[1] = dpIsectPos.d[1]; // Expect: 0.0
49+
outputBuffer[2] = dpIsectPos.d[2]; // Expect: 0.0
50+
51+
// Reverse-mode
52+
DifferentialPair<float3> dpTargetPos = diffPair(targetPos, float3(0.f));
53+
bwd_diff(diffRayIntersectTriangle2)(
54+
shadePos,
55+
dpTargetPos,
56+
triPos,
57+
float3(1.f, 1.f, 1.f)
58+
);
59+
60+
outputBuffer[3] = dpTargetPos.d[0]; // Expect: 5.0
61+
outputBuffer[4] = dpTargetPos.d[1]; // Expect: 32.4301
62+
outputBuffer[5] = dpTargetPos.d[2]; // Expect: 5.0
63+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
type: float
2+
5.0
3+
0.0
4+
0.0
5+
5.0
6+
32.43011856
7+
5.0

0 commit comments

Comments
 (0)