|
| 1 | +//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none |
| 2 | + |
| 3 | +// This test just needs to compile without crashing. |
| 4 | + |
| 5 | +// CUDA: __device__ void s_bwd_myKernel_0 |
| 6 | + |
| 7 | +[Differentiable] |
| 8 | +uint3 load_foo(int32_t g_idx, TensorView<int32_t> indices) |
| 9 | +{ |
| 10 | + return { |
| 11 | + indices[uint2(g_idx, 0)], |
| 12 | + indices[uint2(g_idx, 1)], |
| 13 | + indices[uint2(g_idx, 2)] |
| 14 | + }; |
| 15 | +} |
| 16 | + |
| 17 | +[Differentiable] |
| 18 | +Triangle load_triangle(DiffTensorView vertices, uint3 index_set) |
| 19 | +{ |
| 20 | + float3[3] triangle = { |
| 21 | + read_t3_float3(index_set.x, vertices), |
| 22 | + read_t3_float3(index_set.y, vertices), |
| 23 | + read_t3_float3(index_set.z, vertices) |
| 24 | + }; |
| 25 | + return { triangle }; |
| 26 | +} |
| 27 | + |
| 28 | +struct Triangle : IDifferentiable |
| 29 | +{ |
| 30 | + float3[3] verts; |
| 31 | +} |
| 32 | + |
| 33 | +[Differentiable] |
| 34 | +float3 read_t3_float3(uint32_t idx, DiffTensorView t3) |
| 35 | +{ |
| 36 | + return float3(t3[uint2(idx, 0)], |
| 37 | + t3[uint2(idx, 1)], |
| 38 | + t3[uint2(idx, 2)]); |
| 39 | +} |
| 40 | + |
| 41 | +[Differentiable] TriangleFoo load_triangle_foo(int32_t g_idx, DiffTensorView vertices, TensorView<int32_t> indices, DiffTensorView vertex_color) |
| 42 | +{ |
| 43 | + uint3 index_set = load_foo(g_idx, indices); |
| 44 | + Triangle triangle = load_triangle(vertices, index_set); |
| 45 | + float3 c0 = read_t3_float3(index_set.x, vertex_color); |
| 46 | + float3 c1 = read_t3_float3(index_set.y, vertex_color); |
| 47 | + float3 c2 = read_t3_float3(index_set.z, vertex_color); |
| 48 | + TriangleFoo result = { triangle, 0.f, {c0, c1, c2} }; |
| 49 | + return result; |
| 50 | +} |
| 51 | + |
| 52 | +struct TriangleFoo : IDifferentiable |
| 53 | +{ |
| 54 | + Triangle triangle; |
| 55 | + float density; |
| 56 | + float3 vertex_color[3]; |
| 57 | +}; |
| 58 | + |
| 59 | +[AutoPyBindCUDA] |
| 60 | +[Differentiable] |
| 61 | +[CudaKernel] |
| 62 | +void myKernel(DiffTensorView vertices, TensorView<int32_t> indices, DiffTensorView vertex_color) |
| 63 | +{ |
| 64 | + if (cudaThreadIdx().x > 0) |
| 65 | + return; |
| 66 | + |
| 67 | + TriangleFoo.Differential dp_g = TriangleFoo.dzero(); |
| 68 | + |
| 69 | + bwd_diff(load_triangle_foo)(cudaThreadIdx().x, vertices, indices, vertex_color, dp_g); |
| 70 | +} |
0 commit comments