|
| 1 | +## Problem |
| 2 | +Our current `IDifferentiable` system has some flaws. It works fine for value types, since we can assume that every input gets a corresponding output or 'return' value. It works poorly for buffer/pointer types, since we don't 'return' a buffer, but simply want the getters/setters to be differentiable, and the resulting type to have a second buffer/pointer for the differential data. |
| 3 | + |
| 4 | +Here's a demonstrative example with our current codebase when we use value types (like `float`) |
| 5 | +```csharp |
| 6 | +[Differentiable] |
| 7 | +float add(float a, float b) |
| 8 | +{ |
| 9 | + return a + b; |
| 10 | +} |
| 11 | + |
| 12 | +// Synthesized derivative: |
| 13 | +[Differentiable] |
| 14 | +void s_bwd_add(DifferentialPair<float> dpa, DifferentialPair<float> dpb, float.Differential d_out) |
| 15 | +{ |
| 16 | + // A backward derivative method is currently responsible for 'setting' the differential values. |
| 17 | + dpa = DifferentialPair<float>(dpa.p, d_out); |
| 18 | + dpb = DifferentialPair<float>(dpb.p, d_out); |
| 19 | +} |
| 20 | +``` |
| 21 | + |
| 22 | +Unfortunately, this makes little sense if we decide to use buffer or pointer types: |
| 23 | +```csharp |
| 24 | +struct DiffPtr<T> : IDifferentiable |
| 25 | +{ |
| 26 | + StructuredBuffer<T> bufferRef; |
| 27 | + uint64 offset; |
| 28 | + |
| 29 | + [Differentiable] T get() { ... } |
| 30 | + [Differentiable] void set(T t) { ... } |
| 31 | + /* |
| 32 | + Problem 1: |
| 33 | + We use custom derivatives for get() and set() to backprop and |
| 34 | + read gradients. If DiffPtr<T> is differentiable, then get() and |
| 35 | + set() need to operate on the *pair* type and not this struct type. |
| 36 | + There is no proper way to do this currently. |
| 37 | + */ |
| 38 | +}; |
| 39 | + |
| 40 | +[Differentiable] |
| 41 | +void add(DiffPtr<float> a, DiffPtr<float> b, DiffPtr<float> output) |
| 42 | +{ |
| 43 | + output.set(a.get() + b.get()); |
| 44 | +} |
| 45 | + |
| 46 | +// Synthesized derivative: |
| 47 | +[Differentiable] |
| 48 | +void s_bwd_add( |
| 49 | + inout DifferentialPair<DiffPtr<float>> a, |
| 50 | + inout DifferentialPair<DiffPtr<float>> b, |
| 51 | + inout DifferentialPair<DiffPtr<float>> output) |
| 52 | +{ |
| 53 | + /* |
| 54 | + Problem 2: |
| 55 | + |
| 56 | + Current backward mode semantics require that the method assume that the differentials |
| 57 | + a.d and b.d are empty/zero, and it is the backward method's job to populate the result. |
| 58 | + |
| 59 | + It doesn't make sense to 'set' the differential part since it is a buffer ref. |
| 60 | + Rather, we want the user to provide the differential pointer, and use custom derivatives of |
| 61 | + the getters/setters to propagate derivatives. |
| 62 | + |
| 63 | + This also means methods like dzero(), dadd() and dmul() make no sense |
| 64 | + in the context of pointer types. They cannot be initialized within a derivative method. |
| 65 | + */ |
| 66 | +} |
| 67 | + |
| 68 | +``` |
| 69 | + |
| 70 | +## Workarounds |
| 71 | +At the moment the primary workaround is to use a **non-differentiable buffer type** with differentiable methods, and always initialize the object with two pointers for both the primal and differential buffers. This is how our `DiffTensorView<T>` object works. |
| 72 | +Unfortunately, this is a rather hacky workaround with several drawbacks: |
| 73 | +1. `DiffTensorView<T>` does not conform to `IDifferentiable`, but is used for derivatives. This makes our type system less useful as checks for `is_subtype` from applications using reflection need workarounds to account for corner cases like these. |
| 74 | +2. `DiffTensorView<T>` always has two buffer pointers even when used in non-differentiable methods. This is extra data in the struct, and potentially extra tensor allocations (we explicitly handle this case in `slangtorch` by leaving the diff part uninitialized if a primal method is invoked) |
| 75 | +3. Higher-order derivatives don't work well with this workaround. Differentiating a method twice needs a set of 4 pointers, but we need to account for this ahead of time by using new types like `DiffDiffTensorView` that worsens the problem of carrying around extra data where its not required. |
| 76 | + |
| 77 | + |
| 78 | +## Solution |
| 79 | + |
| 80 | +We'll need to make the following 4 additions/changes: |
| 81 | +### 1. `[deriv_method]` function decorator. |
| 82 | +Intended for easy definition of custom derivatives for struct methods. It has the following properties: |
| 83 | +1. Accesses to `this` within `[deriv_method]` are differential pairs. |
| 84 | +2. Methods decorated with `[deriv_method]` cannot be called as regular methods (they can still be explicitly invoked with `bwd_diff(obj.method)`), and do not show up in the auto-complete list. |
| 85 | + |
| 86 | +See the next section for example uses of `[deriv_method]`. |
| 87 | + |
| 88 | +### 2. Split `IDifferentiable` interface: `IDifferentiableValueType` and `IDifferentiablePtrType` |
| 89 | +This approach moves away from "type-driven" derivative semantics and towards more "function-driven" derivative semantics. |
| 90 | +We no longer have a `dadd` , `dzero`, `dmul` etc.. we use default initialization instead of `dzero` and the backward derivative of the `use` method for `dadd` |
| 91 | + |
| 92 | +Further, `IDifferentiablePtrType` types don't have any of these properties. They do not need a way to 'add', and it is especially important that there is no default initializer. We never want the compiler to be able to create a new object of `IDifferentiablePtrType` since we want to get the user-provided pointers. |
| 93 | + |
| 94 | +Additionally, we can use `IDifferentiableValueType` as the current `IDifferentiable` for backwards compatibility (it should just work in 95% of cases, since no one really defines dadd/dzero/dmul explicitly anyway) |
| 95 | + |
| 96 | +Here's the new set of base interfaces: |
| 97 | +```csharp |
| 98 | +interface __IDifferentiableBase { } // Helper type for our implementation. |
| 99 | +interface IDifferentiableValueType : __IDifferentiableBase |
| 100 | +{ |
| 101 | + associatedtype Differential : IDifferentiableValueType & IDefaultInitializable; |
| 102 | + [Differentiable] This use(); // auto-synthesized |
| 103 | +} |
| 104 | + |
| 105 | +interface IDifferentiablePtrType : __IDifferentiableBase |
| 106 | +{ |
| 107 | + associatedtype Differential : IDifferentiablePtrType; |
| 108 | +} |
| 109 | + |
| 110 | +``` |
| 111 | + |
| 112 | +Some extras in stdlib allow us to constrain the diffpair type for things like `IArithmetic` |
| 113 | +```csharp |
| 114 | +// --- STDLIB EXTRAS --- |
| 115 | +
|
| 116 | +interface ISelfDifferentiableValueType : IDifferentiableValueType |
| 117 | +{ |
| 118 | + // Force arithmetic types to be a differential pair of the same two types. |
| 119 | + // Make it simple to define derivatives of arithmetic operations. |
| 120 | + // |
| 121 | + associatedtype Differential : This; |
| 122 | +} |
| 123 | + |
| 124 | +extension IFloat : ISelfDifferentiableValueType |
| 125 | +{ } |
| 126 | + |
| 127 | +extension float |
| 128 | +{ |
| 129 | + // trivial auto-synthesis (maybe we even prevent the user from overriding this) |
| 130 | + float use() { return this; } |
| 131 | + |
| 132 | + // trivial auto-synthesis (maybe we even prevent the user from overriding this). |
| 133 | + [ForwardDerivativeOf(use)] |
| 134 | + [deriv_method] void use_fwd() { return this; } |
| 135 | + |
| 136 | + // auto-synthesized if necessary by invoking the use_bwd for all fields. |
| 137 | + // we need to provide implementation for 'leaf' types. |
| 138 | + [BackwardDerivativeOf(use)] |
| 139 | + [deriv_method] [mutating] void use_bwd(float d) { this.d += d; } |
| 140 | +} |
| 141 | + |
| 142 | +// The new system lets us define differentiable pointers easily. |
| 143 | +// IDifferentiablePtrType'd values are simply treated as references, so they can be freely |
| 144 | +// duplicated without requiring a `use()` for correctness. |
| 145 | +// |
| 146 | +struct DPtr<T : IDifferentiableValueType> : IDifferentiablePtrType |
| 147 | +{ |
| 148 | + typealias Differential = DPtr<T.Differential>; |
| 149 | + |
| 150 | + Buffer<T> buffer; |
| 151 | + uint64 offset; |
| 152 | + |
| 153 | + [BackwardDerivative(get_bwd)] |
| 154 | + [BackwardDerivative(get_fwd)] |
| 155 | + T get() { return this.buffer[offset]; } |
| 156 | + |
| 157 | + [deriv_method] DifferentialPair<T> get_fwd() |
| 158 | + { |
| 159 | + return diffPair(this.p.buffer[offset], this.d().buffer[offset]); |
| 160 | + } |
| 161 | + |
| 162 | + [deriv_method] void get_bwd(Differential d) |
| 163 | + { |
| 164 | + return this.d.InterlockedAdd(offset, d); |
| 165 | + } |
| 166 | + |
| 167 | + DPtr<T> operator+(uint o) { return DPtr<T>{buffer, offset + o}; } |
| 168 | +} |
| 169 | + |
| 170 | +// Or we can define a fancier differentiable pointer that does a hashgrid |
| 171 | +struct DHashGridPtr<T : IDifferentiableValueType, let N: int> : IDifferentiablePtrType |
| 172 | +{ |
| 173 | + typealias Differential = DPtr<T.Differential>; |
| 174 | + |
| 175 | + Buffer<T> buffer; |
| 176 | + uint64 offset; |
| 177 | + |
| 178 | + [BackwardDerivative(get_bwd)] |
| 179 | + [BackwardDerivative(get_fwd)] |
| 180 | + T get() { return this.buffer[offset]; } |
| 181 | + |
| 182 | + [deriv_method] DifferentialPair<T> get_fwd() |
| 183 | + { |
| 184 | + return diffPair(this.p().buffer[offset], this.d().buffer[offset]); |
| 185 | + } |
| 186 | + |
| 187 | + [deriv_method] void get_bwd(Differential d) |
| 188 | + { |
| 189 | + return this.d().InterlockedAdd(offset * N + hash(get_thread_id()), d); |
| 190 | + } |
| 191 | +} |
| 192 | +``` |
| 193 | + |
| 194 | +### 3. Every time we 'reuse' an object that conforms to `IDifferentiableValueType`, we split it with `use()` , and we use `__init__()` where necessary to initialize an accumulator. |
| 195 | +Example: |
| 196 | +```csharp |
| 197 | +float f(float a) |
| 198 | +{ |
| 199 | + add(a, a); |
| 200 | +} |
| 201 | +float add(float a, float b) |
| 202 | +{ |
| 203 | + return a + b; |
| 204 | +} |
| 205 | + |
| 206 | +// Synthesized derivatives |
| 207 | +void add_bwd(inout DiffPair<float> dpa, inout DiffPair<float> dpb, float d_out) |
| 208 | +{ |
| 209 | + dpa = diffPair(dpa.p, d_out); |
| 210 | + dpb = diffPair(dpb.p, d_out); |
| 211 | +} |
| 212 | + |
| 213 | +// Preprocessed-f (before derivative generation) |
| 214 | +float f_with_use_expansion(float a) |
| 215 | +{ |
| 216 | + DiffPair<float> a_extra = a.use(); |
| 217 | + return add(a, a_extra); |
| 218 | +} |
| 219 | + |
| 220 | +// After fwd-mode: |
| 221 | +DiffPair<float> f_fwd(DiffPair<float> dpa) |
| 222 | +{ |
| 223 | + DiffPair<float> dpa_extra = dpa.use_fwd(); |
| 224 | + return add_fwd(a, a_extra_fwd); |
| 225 | +} |
| 226 | + |
| 227 | + |
| 228 | +// bwd-mode: |
| 229 | +void f_bwd(inout DiffPair<float> dpa, float d_out) |
| 230 | +{ |
| 231 | + // fwd-pass |
| 232 | +
|
| 233 | + // split |
| 234 | + DiffPair<float> dpa_extra = dpa.use_fwd(); |
| 235 | + // ------- |
| 236 | + |
| 237 | + // bwd-pass |
| 238 | + dpa_extra_bwd = DiffPair<float>(dpa_extra.p, float.Differential::__init__()); |
| 239 | + add_bwd(dpa, dpa_extra, d_out); |
| 240 | + |
| 241 | + // merge |
| 242 | + dpa.use_bwd(dpa_extra); |
| 243 | +} |
| 244 | +``` |
| 245 | + |
| 246 | +### 4. Objects that conform to `IDifferentiablePtrType` are used without splitting. They are simply not 'transposed' at all, because there is nothing to transpose. The fwd-mode pair is used as is. |
| 247 | +Here's the same example above, but with the `DPtr` type defined above. |
| 248 | + |
| 249 | +```csharp |
| 250 | +void f(DPtr<float> a, DPtr<float> output) |
| 251 | +{ |
| 252 | + add(a, a, output); |
| 253 | +} |
| 254 | + |
| 255 | +void add(DPtr<float> a, DPtr<float> b, DPtr<float> output) |
| 256 | +{ |
| 257 | + output.set(a.get() + b.get()); |
| 258 | +} |
| 259 | + |
| 260 | +// Synthesized derivatives |
| 261 | +// (note: no inout req'd for IDifferentiablePtrType) |
| 262 | +// important difference is that `ptr` types don't get transposed, only |
| 263 | +// methods on the objects are. |
| 264 | +// they DO NOT have a default initializer (the user must supply the differential part) |
| 265 | +void add_bwd( |
| 266 | + DifferentialPair<DPtr<float>> dpa, |
| 267 | + DifferentialPair<DPtr<float>> dpb, |
| 268 | + DifferentialPair<DPtr<float>> output) |
| 269 | +{ |
| 270 | + // forward pass. |
| 271 | + var a_p = dpa.p.get(); |
| 272 | + var b_p = dpb.p.get(); |
| 273 | + // ---- |
| 274 | + |
| 275 | + // backward pass. |
| 276 | + float.Differential d_val = DPtr<float>::set_bwd(output); // set_bwd works on the entire pair. |
| 277 | + DifferentialPair<float> a_get_bwd = diffPair(a_p, float.Differential::__init__()); |
| 278 | + DifferentialPair<float> b_get_bwd = diffPair(b_p, float.Differential::__init__()); |
| 279 | + operator_float_add_bwd(a_get_result_bwd, b_get_result_bwd, d_val); |
| 280 | + DPtr<float>::get_bwd(dpa); |
| 281 | + DPtr<float>::get_bwd(dpb); |
| 282 | +} |
| 283 | +``` |
0 commit comments