forked from shader-slang/slang
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgh-5781.slang
57 lines (44 loc) · 1.26 KB
/
gh-5781.slang
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
//TEST:SIMPLE(filecheck=CHECK): -target spirv
// CHECK: OpEntryPoint
module test;
// Strongly typed material identifier with `uint` as its underlying type.
// `invalid` (0xffffffff) is the only named enumerator; other values are
// created by explicit conversion from a uint, as main() does with input[0].
public enum class MaterialID : uint { invalid = 0xffffffff };
// Material record that conforms to IDifferentiable, so it can be returned
// from the [Differentiable] functions declared in this file.
public struct Material : IDifferentiable
{
// The single material parameter. trace() returns this field, and
// Scene.accumulate adds incoming derivative values into it.
float x;
}
// Plain (non-IDifferentiable) struct that carries a material identifier.
// Scene.get_material(Hit) reads `material` to pick the Material to return.
public struct Hit
{
// Identifier of the material associated with this hit.
MaterialID material;
}
// Bundles the material buffer with a per-material gradient buffer and exposes
// differentiable material lookups. Scene itself does not conform to
// IDifferentiable; derivatives leave the differentiable call chain through the
// custom backward derivative of get_material(MaterialID), which writes to `grads`.
public struct Scene
{
// Read-only material storage, indexed by uint(id) in load().
StructuredBuffer<Material> materials;
// Read-write storage that accumulate() adds derivative values into,
// indexed the same way as `materials`.
RWStructuredBuffer<Material> grads;
// Returns materials[uint(id)]. Marked [Differentiable] so it can be called
// from the differentiable get_material overloads.
// NOTE(review): no guard against MaterialID.invalid or other out-of-range
// ids is visible here - confirm callers only pass valid indices.
[Differentiable]
Material load(MaterialID id) { return materials[uint(id)]; }
// Adds d.x to grads[uint(id)].x using a plain `+=` on the buffer element.
void accumulate(MaterialID id, Material d) { grads[uint(id)].x += d.x; }
// Looks up a material by id. The forward computation is load(id); the
// backward pass is overridden via BackwardDerivative to call
// _get_material_bwd instead of an automatically generated derivative.
[Differentiable, BackwardDerivative(_get_material_bwd)]
public Material get_material(MaterialID id) { return load(id); }
// Custom backward derivative for get_material(MaterialID): receives the
// same id plus the derivative `d` of the returned Material and forwards
// both to accumulate().
public void _get_material_bwd(MaterialID id, Material d) { accumulate(id, d); }
// Overload taking a Hit: forwards hit.material to get_material(MaterialID),
// so differentiating this overload goes through the custom derivative above.
[Differentiable]
public Material get_material(Hit hit) { return get_material(hit.material); }
}
// Differentiable function under test: fetches the material for `hit` through
// the Hit overload of Scene.get_material and returns its `x` field.
// Neither parameter type (Scene, Hit) is declared IDifferentiable in this file.
[Differentiable]
float trace(const Scene scene, Hit hit)
{
// Resolves to Scene.get_material(Hit), which in turn calls the
// MaterialID overload that carries the custom backward derivative.
Material m = scene.get_material(hit);
return m.x;
}
// Compute entry point that drives both the primal and backward versions of trace().
//   scene  - material and gradient buffers used by trace().
//   input  - input[0] supplies the raw material index.
//   output - output[0] receives the primal result of trace().
//   grads  - grads[0] is read and passed as the last argument of the backward
//            call. This is a float buffer, distinct from Scene.grads
//            (which holds Material elements).
[shader("compute")]
void main(
uniform Scene scene,
uniform StructuredBuffer<uint> input,
uniform RWStructuredBuffer<float> output,
uniform RWStructuredBuffer<float> grads
)
{
Hit hit;
// Explicitly convert the raw uint from the input buffer into a MaterialID.
hit.material = MaterialID(input[0]);
// Primal evaluation.
output[0] = trace(scene, hit);
// Backward evaluation: scene and hit are passed unchanged and grads[0] is
// supplied as the derivative of trace's float result. Its only visible
// effect is the write performed by Scene.accumulate via _get_material_bwd.
bwd_diff(trace)(scene, hit, grads[0]);
}