Skip to content

Commit 714ee76

Browse files
Fix crash when swizzling non-differentiable types (shader-slang#6613)
* Fix crash when swizzling non-differentiable types * Update slang-ir-autodiff-fwd.cpp
1 parent 98ff419 commit 714ee76

File tree

2 files changed

+89
-11
lines changed

2 files changed

+89
-11
lines changed

source/slang/slang-ir-autodiff-fwd.cpp

+19-11
Original file line numberDiff line numberDiff line change
@@ -955,20 +955,28 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
955955
InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
956956
{
957957
IRInst* primalSwizzle = maybeCloneForPrimalInst(builder, origSwizzle);
958-
959958
if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr))
960959
{
961-
List<IRInst*> swizzleIndices;
962-
for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
963-
swizzleIndices.add(origSwizzle->getElementIndex(ii));
960+
// `diffBase` may exist even if the type is non-differentiable (e.g. IRCall inst that
961+
// creates other differentiable outputs).
962+
//
963+
// We'll check to see if we can get a differential for the type in order to determine
964+
// whether to generate a differential swizzle inst.
965+
//
966+
if (auto diffType = differentiateType(builder, primalSwizzle->getDataType()))
967+
{
968+
List<IRInst*> swizzleIndices;
969+
for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
970+
swizzleIndices.add(origSwizzle->getElementIndex(ii));
964971

965-
return InstPair(
966-
primalSwizzle,
967-
builder->emitSwizzle(
968-
differentiateType(builder, primalSwizzle->getDataType()),
969-
diffBase,
970-
origSwizzle->getElementCount(),
971-
swizzleIndices.getBuffer()));
972+
return InstPair(
973+
primalSwizzle,
974+
builder->emitSwizzle(
975+
diffType,
976+
diffBase,
977+
origSwizzle->getElementCount(),
978+
swizzleIndices.getBuffer()));
979+
}
972980
}
973981

974982
return InstPair(primalSwizzle, nullptr);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none
2+
3+
// This test just needs to compile without crashing.
4+
5+
// CUDA: __device__ void s_bwd_myKernel_0
6+
7+
[Differentiable]
8+
uint3 load_foo(int32_t g_idx, TensorView<int32_t> indices)
9+
{
10+
return {
11+
indices[uint2(g_idx, 0)],
12+
indices[uint2(g_idx, 1)],
13+
indices[uint2(g_idx, 2)]
14+
};
15+
}
16+
17+
[Differentiable]
18+
Triangle load_triangle(DiffTensorView vertices, uint3 index_set)
19+
{
20+
float3[3] triangle = {
21+
read_t3_float3(index_set.x, vertices),
22+
read_t3_float3(index_set.y, vertices),
23+
read_t3_float3(index_set.z, vertices)
24+
};
25+
return { triangle };
26+
}
27+
28+
struct Triangle : IDifferentiable
29+
{
30+
float3[3] verts;
31+
}
32+
33+
[Differentiable]
34+
float3 read_t3_float3(uint32_t idx, DiffTensorView t3)
35+
{
36+
return float3(t3[uint2(idx, 0)],
37+
t3[uint2(idx, 1)],
38+
t3[uint2(idx, 2)]);
39+
}
40+
41+
[Differentiable] TriangleFoo load_triangle_foo(int32_t g_idx, DiffTensorView vertices, TensorView<int32_t> indices, DiffTensorView vertex_color)
42+
{
43+
uint3 index_set = load_foo(g_idx, indices);
44+
Triangle triangle = load_triangle(vertices, index_set);
45+
float3 c0 = read_t3_float3(index_set.x, vertex_color);
46+
float3 c1 = read_t3_float3(index_set.y, vertex_color);
47+
float3 c2 = read_t3_float3(index_set.z, vertex_color);
48+
TriangleFoo result = { triangle, 0.f, {c0, c1, c2} };
49+
return result;
50+
}
51+
52+
struct TriangleFoo : IDifferentiable
53+
{
54+
Triangle triangle;
55+
float density;
56+
float3 vertex_color[3];
57+
};
58+
59+
[AutoPyBindCUDA]
60+
[Differentiable]
61+
[CudaKernel]
62+
void myKernel(DiffTensorView vertices, TensorView<int32_t> indices, DiffTensorView vertex_color)
63+
{
64+
if (cudaThreadIdx().x > 0)
65+
return;
66+
67+
TriangleFoo.Differential dp_g = TriangleFoo.dzero();
68+
69+
bwd_diff(load_triangle_foo)(cudaThreadIdx().x, vertices, indices, vertex_color, dp_g);
70+
}

0 commit comments

Comments
 (0)