Commit 359e96c

Proposal: A simpler and more flexible IDifferentiable system (#4865)
Designed to work for both value types and ref/ptr/buffer types.

Co-authored-by: Yong He <yonghe@outlook.com>
1 parent f9f6a28 commit 359e96c

File tree

1 file changed

+283
-0
lines changed

@@ -0,0 +1,283 @@
## Problem
Our current `IDifferentiable` system has some flaws. It works fine for value types, since we can assume that every input gets a corresponding output or 'return' value. It works poorly for buffer/pointer types: we don't 'return' a buffer, we only want its getters/setters to be differentiable and the resulting type to carry a second buffer/pointer for the differential data.

Here's a demonstrative example from our current codebase using value types (like `float`):
```csharp
[Differentiable]
float add(float a, float b)
{
    return a + b;
}

// Synthesized derivative:
[Differentiable]
void s_bwd_add(inout DifferentialPair<float> dpa, inout DifferentialPair<float> dpb, float.Differential d_out)
{
    // A backward derivative method is currently responsible for 'setting' the differential values.
    dpa = DifferentialPair<float>(dpa.p, d_out);
    dpb = DifferentialPair<float>(dpb.p, d_out);
}
```

Unfortunately, this makes little sense if we decide to use buffer or pointer types:
```csharp
struct DiffPtr<T> : IDifferentiable
{
    StructuredBuffer<T> bufferRef;
    uint64_t offset;

    [Differentiable] T get() { ... }
    [Differentiable] void set(T t) { ... }
    /*
        Problem 1:
        We use custom derivatives for get() and set() to backprop and
        read gradients. If DiffPtr<T> is differentiable, then get() and
        set() need to operate on the *pair* type and not this struct type.
        There is no proper way to do this currently.
    */
};

[Differentiable]
void add(DiffPtr<float> a, DiffPtr<float> b, DiffPtr<float> output)
{
    output.set(a.get() + b.get());
}

// Synthesized derivative:
[Differentiable]
void s_bwd_add(
    inout DifferentialPair<DiffPtr<float>> a,
    inout DifferentialPair<DiffPtr<float>> b,
    inout DifferentialPair<DiffPtr<float>> output)
{
    /*
        Problem 2:

        Current backward mode semantics require that the method assume that the differentials
        a.d and b.d are empty/zero, and it is the backward method's job to populate the result.

        It doesn't make sense to 'set' the differential part since it is a buffer ref.
        Rather, we want the user to provide the differential pointer, and use custom derivatives of
        the getters/setters to propagate derivatives.

        This also means methods like dzero(), dadd() and dmul() make no sense
        in the context of pointer types. They cannot be initialized within a derivative method.
    */
}
```

## Workarounds
At the moment the primary workaround is to use a **non-differentiable buffer type** with differentiable methods, and always initialize the object with two pointers, one for the primal buffer and one for the differential buffer. This is how our `DiffTensorView<T>` object works.
Unfortunately, this is a rather hacky workaround with several drawbacks:
1. `DiffTensorView<T>` does not conform to `IDifferentiable`, but is used for derivatives. This makes our type system less useful, because applications that check `is_subtype` through reflection need workarounds to account for corner cases like this one.
2. `DiffTensorView<T>` always has two buffer pointers, even when used in non-differentiable methods. This is extra data in the struct, and potentially extra tensor allocations (we explicitly handle this case in `slangtorch` by leaving the diff part uninitialized if a primal method is invoked).
3. Higher-order derivatives don't work well with this workaround. Differentiating a method twice needs a set of 4 pointers, and we have to account for this ahead of time with new types like `DiffDiffTensorView`, which worsen the problem of carrying around extra data where it's not required.


## Solution

We'll need to make the following 4 additions/changes:
### 1. `[deriv_method]` function decorator
Intended for easy definition of custom derivatives for struct methods. It has the following properties:
1. Accesses to `this` within a `[deriv_method]` are differential pairs.
2. Methods decorated with `[deriv_method]` cannot be called as regular methods (they can still be explicitly invoked with `bwd_diff(obj.method)`), and do not show up in the auto-complete list.

See the next section for example uses of `[deriv_method]`.

### 2. Split `IDifferentiable` interface: `IDifferentiableValueType` and `IDifferentiablePtrType`
This approach moves away from "type-driven" derivative semantics and towards more "function-driven" derivative semantics.
We no longer have `dadd`, `dzero`, `dmul`, etc. Instead, we use default initialization in place of `dzero`, and the backward derivative of the `use` method in place of `dadd`.

Further, `IDifferentiablePtrType` types don't have any of these properties. They do not need a way to 'add', and it is especially important that there is no default initializer. We never want the compiler to be able to create a new object of an `IDifferentiablePtrType`, since we want it to use the user-provided pointers.

Additionally, `IDifferentiableValueType` can stand in for the current `IDifferentiable` for backwards compatibility (it should just work in 95% of cases, since no one really defines `dadd`/`dzero`/`dmul` explicitly anyway).

Here's the new set of base interfaces:
```csharp
interface __IDifferentiableBase { } // Helper type for our implementation.
interface IDifferentiableValueType : __IDifferentiableBase
{
    associatedtype Differential : IDifferentiableValueType & IDefaultInitializable;
    [Differentiable] This use(); // auto-synthesized
}

interface IDifferentiablePtrType : __IDifferentiableBase
{
    associatedtype Differential : IDifferentiablePtrType;
}
```
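
To make the contrast between the two interfaces concrete, here is a small Python model. It is only an illustration and not part of the proposal: `ValueType`, `FloatValue`, `BufferPtr` and `make_pair` are hypothetical names, and Python stands in for Slang because the proposed interfaces do not exist yet. The sketch assumes that a value type can always produce a default-initialized differential, while a pointer type must receive its differential from the caller.

```python
from dataclasses import dataclass


class ValueType:
    """Hypothetical stand-in for IDifferentiableValueType."""

    @classmethod
    def zero_differential(cls):
        raise NotImplementedError

    def use(self):
        # Trivial copy; accumulation happens in the backward derivative of use().
        return self


@dataclass
class FloatValue(ValueType):
    value: float = 0.0

    @classmethod
    def zero_differential(cls):
        # Default initialization takes over the role of dzero().
        return FloatValue()


@dataclass
class BufferPtr:
    """Hypothetical stand-in for IDifferentiablePtrType: no default constructor."""
    buffer: list
    offset: int


def make_pair(primal, differential=None):
    """Build a (primal, differential) pair the way a compiler could in this model."""
    if differential is None:
        if not isinstance(primal, ValueType):
            # Pointer types: the differential must be supplied by the user.
            raise TypeError("pointer types need a user-supplied differential")
        differential = type(primal).zero_differential()
    return primal, differential
```

In this model `make_pair(FloatValue(3.0))` succeeds with a zero differential, while `make_pair(BufferPtr([1.0], 0))` raises, which mirrors the rule that the compiler never creates a pointer-type differential on its own.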

Some extras in stdlib allow us to constrain the diffpair type for things like `IArithmetic`:
```csharp
// --- STDLIB EXTRAS ---

interface ISelfDifferentiableValueType : IDifferentiableValueType
{
    // Force arithmetic types to be a differential pair of the same two types.
    // Makes it simple to define derivatives of arithmetic operations.
    associatedtype Differential : This;
}

extension IFloat : ISelfDifferentiableValueType
{ }

extension float
{
    // trivial auto-synthesis (maybe we even prevent the user from overriding this)
    float use() { return this; }

    // trivial auto-synthesis (maybe we even prevent the user from overriding this).
    // `this` is a differential pair inside a [deriv_method], so the pair is returned as is.
    [ForwardDerivativeOf(use)]
    [deriv_method] DifferentialPair<float> use_fwd() { return this; }

    // auto-synthesized if necessary by invoking the use_bwd for all fields.
    // we need to provide implementation for 'leaf' types.
    [BackwardDerivativeOf(use)]
    [deriv_method] [mutating] void use_bwd(float d) { this.d += d; }
}

// The new system lets us define differentiable pointers easily.
// IDifferentiablePtrType'd values are simply treated as references, so they can be freely
// duplicated without requiring a `use()` for correctness.
struct DPtr<T : IDifferentiableValueType> : IDifferentiablePtrType
{
    typealias Differential = DPtr<T.Differential>;

    Buffer<T> buffer;
    uint64_t offset;

    [BackwardDerivative(get_bwd)]
    [ForwardDerivative(get_fwd)]
    T get() { return this.buffer[offset]; }

    [deriv_method] DifferentialPair<T> get_fwd()
    {
        return diffPair(this.p.buffer[offset], this.d.buffer[offset]);
    }

    // `d` is the gradient of the value returned by get().
    [deriv_method] void get_bwd(T.Differential d)
    {
        // Schematic: atomically accumulate `d` into the differential buffer.
        this.d.InterlockedAdd(offset, d);
    }

    DPtr<T> operator+(uint o) { return DPtr<T>{buffer, offset + o}; }
}

// Or we can define a fancier differentiable pointer that does a hashgrid
struct DHashGridPtr<T : IDifferentiableValueType, let N: int> : IDifferentiablePtrType
{
    typealias Differential = DPtr<T.Differential>;

    Buffer<T> buffer;
    uint64_t offset;

    [BackwardDerivative(get_bwd)]
    [ForwardDerivative(get_fwd)]
    T get() { return this.buffer[offset]; }

    [deriv_method] DifferentialPair<T> get_fwd()
    {
        return diffPair(this.p.buffer[offset], this.d.buffer[offset]);
    }

    // `d` is the gradient of the value returned by get().
    [deriv_method] void get_bwd(T.Differential d)
    {
        // Schematic: spread atomic accumulation over N slots per element.
        this.d.InterlockedAdd(offset * N + hash(get_thread_id()), d);
    }
}
```
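
The behaviour of `DPtr` can be mimicked in a few lines of Python. This is a hedged model rather than the proposed implementation: the free functions `get_fwd` and `get_bwd` take the primal and differential halves explicitly (playing the role of `this.p` and `this.d` inside a `[deriv_method]`), and a plain `+=` stands in for the atomic `InterlockedAdd`.

```python
class DPtr:
    """Hypothetical Python stand-in for DPtr<T>: a buffer reference plus an offset."""

    def __init__(self, buffer, offset=0):
        self.buffer = buffer
        self.offset = offset

    def get(self):
        return self.buffer[self.offset]

    def __add__(self, o):
        # Pointer arithmetic only moves the offset; the buffer is shared.
        return DPtr(self.buffer, self.offset + o)


def get_fwd(primal, differential):
    """Forward derivative of get(): read both halves at the same offset."""
    return primal.get(), differential.get()


def get_bwd(primal, differential, d):
    """Backward derivative of get(): accumulate d into the user-provided gradient buffer."""
    differential.buffer[differential.offset] += d  # non-atomic stand-in for InterlockedAdd


values = [1.0, 2.0, 3.0]
grads = [0.0, 0.0, 0.0]        # supplied by the user, never created by the derivative code
p, d = DPtr(values) + 1, DPtr(grads) + 1

get_bwd(p, d, 0.5)
get_bwd(p, d, 0.25)
print(grads)                    # [0.0, 0.75, 0.0]
```

Because the backward pass accumulates into a buffer that the caller owns, the pointer can be copied or offset freely without any `use()` bookkeeping.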

### 3. Every time we 'reuse' an object that conforms to `IDifferentiableValueType`, we split it with `use()`, and we use `__init__()` where necessary to initialize an accumulator.
Example:
```csharp
float f(float a)
{
    return add(a, a);
}
float add(float a, float b)
{
    return a + b;
}

// Synthesized derivatives
void add_bwd(inout DiffPair<float> dpa, inout DiffPair<float> dpb, float d_out)
{
    dpa = diffPair(dpa.p, d_out);
    dpb = diffPair(dpb.p, d_out);
}

// Preprocessed-f (before derivative generation)
float f_with_use_expansion(float a)
{
    float a_extra = a.use();
    return add(a, a_extra);
}

// After fwd-mode:
DiffPair<float> f_fwd(DiffPair<float> dpa)
{
    DiffPair<float> dpa_extra = dpa.use_fwd();
    return add_fwd(dpa, dpa_extra);
}

// bwd-mode:
void f_bwd(inout DiffPair<float> dpa, float d_out)
{
    // fwd-pass

    // split
    DiffPair<float> dpa_extra = dpa.use_fwd();
    // -------

    // bwd-pass
    // reset the differential of the split copy so that it acts as a fresh accumulator
    dpa_extra = DiffPair<float>(dpa_extra.p, float.Differential::__init__());
    add_bwd(dpa, dpa_extra, d_out);

    // merge
    dpa.use_bwd(dpa_extra.d);
}
```
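
The split/merge scheme can be checked with a small Python model. This is a sketch under stated assumptions, not the compiler's output: `Pair` is a hypothetical mutable stand-in for a differential pair, `add_bwd` keeps the 'set' semantics shown above, and `f_bwd_without_split` is added purely to illustrate what goes wrong when the same pair is passed twice.

```python
class Pair:
    """Hypothetical mutable stand-in for a differential pair of floats."""

    def __init__(self, p, d=0.0):
        self.p = p
        self.d = d   # default-initialized differential (plays the role of __init__())


def add_bwd(dpa, dpb, d_out):
    # 'Set' semantics: each parameter's differential is overwritten.
    dpa.d = d_out
    dpb.d = d_out


def use_bwd(dp, d):
    # Backward derivative of use(): accumulate instead of overwrite.
    dp.d += d


def f_bwd(dpa, d_out):
    dpa_extra = Pair(dpa.p)            # split: fresh accumulator for the second use of a
    add_bwd(dpa, dpa_extra, d_out)
    use_bwd(dpa, dpa_extra.d)          # merge both contributions back into a


def f_bwd_without_split(dpa, d_out):
    add_bwd(dpa, dpa, d_out)           # second write overwrites the first one


split = Pair(3.0)
f_bwd(split, 1.0)
print(split.d)      # 2.0, the derivative of f(a) = a + a

naive = Pair(3.0)
f_bwd_without_split(naive, 1.0)
print(naive.d)      # 1.0, one contribution is lost
```

The comparison shows why every reuse of a value-type object has to go through `use()`: with overwrite semantics, aliasing the same pair silently drops a gradient contribution.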

### 4. Objects that conform to `IDifferentiablePtrType` are used without splitting. They are simply not 'transposed' at all, because there is nothing to transpose. The fwd-mode pair is used as is.
Here's the same example as above, but with the `DPtr` type defined earlier.

```csharp
void f(DPtr<float> a, DPtr<float> output)
{
    add(a, a, output);
}

void add(DPtr<float> a, DPtr<float> b, DPtr<float> output)
{
    output.set(a.get() + b.get());
}

// Synthesized derivatives
// (note: no inout req'd for IDifferentiablePtrType)
// important difference is that `ptr` types don't get transposed, only
// methods on the objects are.
// they DO NOT have a default initializer (the user must supply the differential part)
void add_bwd(
    DifferentialPair<DPtr<float>> dpa,
    DifferentialPair<DPtr<float>> dpb,
    DifferentialPair<DPtr<float>> output)
{
    // forward pass.
    var a_p = dpa.p.get();
    var b_p = dpb.p.get();
    // ----

    // backward pass.
    float.Differential d_val = DPtr<float>::set_bwd(output); // set_bwd works on the entire pair.
    DifferentialPair<float> a_get_bwd = diffPair(a_p, float.Differential::__init__());
    DifferentialPair<float> b_get_bwd = diffPair(b_p, float.Differential::__init__());
    operator_float_add_bwd(a_get_bwd, b_get_bwd, d_val);
    // propagate the gradient of each get() result into the user-provided differential pointers.
    DPtr<float>::get_bwd(dpa, a_get_bwd.d);
    DPtr<float>::get_bwd(dpb, b_get_bwd.d);
}
```
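
The same flow can be traced in Python. This is an illustrative model only: `Ptr` is a hypothetical stand-in for `DPtr<float>`, the primal and differential pointers are passed as separate arguments instead of a pair, and the choice that the backward of `set()` reads the incoming gradient and then clears it is an assumption of this sketch, since the proposal does not spell out `set_bwd`.

```python
class Ptr:
    """Hypothetical stand-in for DPtr<float>: a reference into a Python list."""

    def __init__(self, buffer, offset=0):
        self.buffer = buffer
        self.offset = offset

    def get(self):
        return self.buffer[self.offset]

    def set(self, value):
        self.buffer[self.offset] = value


def add(a, b, output):
    output.set(a.get() + b.get())


def add_bwd(a, d_a, b, d_b, output, d_output):
    # Backward of set(): take the gradient stored for the output (assumed to be cleared after).
    d_val = d_output.get()
    d_output.set(0.0)
    # Backward of '+': both operands receive d_val.
    # Backward of get(): accumulate into the user-provided gradient buffers.
    d_a.set(d_a.get() + d_val)
    d_b.set(d_b.get() + d_val)


x, dx = [3.0], [0.0]          # primal and gradient storage for the input
out, dout = [0.0], [1.0]      # primal and gradient storage for the output

# f(a, output) calls add(a, a, output): the same pointer is passed twice, with no split.
add(Ptr(x), Ptr(x), Ptr(out))
add_bwd(Ptr(x), Ptr(dx), Ptr(x), Ptr(dx), Ptr(out), Ptr(dout))

print(out)   # [6.0]
print(dx)    # [2.0]
```

Both contributions land in `dx` even though the pointer is aliased, because accumulation happens in the referenced buffer rather than in the pair itself. This is why pointer types need neither `use()` nor a default initializer.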
