diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md index b3a25358c8..9209a903c1 100644 --- a/docs/user-guide/07-autodiff.md +++ b/docs/user-guide/07-autodiff.md @@ -136,8 +136,13 @@ Where the backward propagation function $$\mathbb{B}[f_i]$$ takes as input the p The higher order operator $$\mathbb{F}$$ and $$\mathbb{B}$$ represent the operations that converts an original or primal function $$f$$ to its forward or backward derivative propagation function. Slang's automatic differentiation feature provide built-in support for these operators to automatically generate the derivative propagation functions from a user defined primal function. The remaining documentation will discuss this feature from a programming language perspective. -## Differentiable Types -Slang will only generate differentiation code for values that has a *differentiable* type. A type is differentiable if it conforms to the built-in `IDifferentiable` interface. The definition of the `IDifferentiable` interface is: +## Differentiable Value Types +Slang will only generate differentiation code for values that has a *differentiable* type. +Differentiable types are defining through conformance to one of two built-in interfaces: +1. `IDifferentiable`: For value types (e.g. `float`, structs of value types, etc..) +2. `IDifferentiablePtrType`: For buffer, pointer & reference types that represent locations rather than values. + +The `IDifferentiable` interface requires the following definitions (which can be auto-generated by the compiler for most scenarios) ```csharp interface IDifferentiable { @@ -147,23 +152,38 @@ interface IDifferentiable static Differential dzero(); static Differential dadd(Differential, Differential); - - static Differential dmul(This, Differential); } ``` As defined by the `IDifferentiable` interface, a differentiable type must have a `Differential` associated type that stores the derivative of the value. A further requirement is that the type of the second-order derivative must be the same `Differential` type. In another word, given a type `T`, `T.Differential` can be different from `T`, but `T.Differential.Differential` must equal to `T.Differential`. -In addition, a differentiable type must define the `zero` value of its derivative, and how to add and multiply derivative values. +In addition, a differentiable type must define the `zero` value of its derivative, and how to add two derivative values together. These function are used during reverse-mode auto-diff, to initialize and accumulate derivatives of the given type. -### Builtin Differentiable Types +By contrast, `IDifferentiablePtrType` only requires a `Differential` associated type which also conforms to `IDifferentiablePtrType`. +```csharp +interface IDifferentiablePtrType +{ + associatedtype Differential : IDifferentiablePtrType; + where Differential.Differential == Differential; +} +``` + +Types should not conform to both `IDifferentiablePtrType` and `IDifferentiable`. Such cases will result in a compiler error. + + +### Builtin Differentiable Value Types The following built-in types are differentiable: - Scalars: `float`, `double` and `half`. - Vector/Matrix: `vector` and `matrix` of `float`, `double` and `half` types. - Arrays: `T[n]` is differentiable if `T` is differentiable. +- Tuples: `Tuple` is differentiable if `T` is differentiable. + +### Builtin Differentiable Ptr Types +There are currently no built-in types that conform to `IDifferentiablePtrType` ### User Defined Differentiable Types -The user can make any `struct` types differentiable by implementing the `IDifferentiable` interface on the type. The requirements from `IDifferentiable` interface can be fulfilled automatically or manually. +The user can make any `struct` types differentiable by implementing either `IDifferentiable` & `IDifferentiablePtrType` interface on the type. +The requirements from `IDifferentiable` interface can be fulfilled automatically or manually, though `IDifferentiablePtrType` currently requires the user to provide the `Differential` type. #### Automatic Fulfillment of `IDifferentiable` Requirements Assume the user has defined the following type: @@ -191,7 +211,7 @@ Note that this code does not provide any explicit implementation of the `IDiffer 1. A new type is generated that stores the `Differential` of all differentiable fields. This new type itself will conform to the `IDifferentiable` interface, and it will be used to satisfy the `Differential` associated type requirement. 2. Each differential field will be associated to its corresponding field in the newly synthesized `Differential` type. 3. The `zero` value of the differential type is made from the `zero` value of each field in the differential type. -4. The `dadd` and `dmul` methods simply perform `dadd` and `dmul` operations on each field. +4. The `dadd` method invokes the `dadd` operations for each field whose type conforms to `IDifferentiable`. 5. If the synthesized `Differential` type contains exactly the same fields as the original type, and the type of each field is the same as the original field type, then the original type itself will be used as the `Differential` type instead of creating a new type to satisfy the `Differential` associated type requirement. This means that all the synthesized `Differential` type use itself to meet its own `IDifferentiable` requirements. #### Manual Fulfillment of `IDifferentiable` Requirements @@ -235,15 +255,6 @@ struct MyRay : IDifferentiable result.d_dir = v1.d_dir + v2.d_dir; return result; } - - // Define the multiply operation of a primal value and a derivative value. - static MyRayDifferential dmul(MyRay p, MyRayDifferential d) - { - MyRayDifferential result; - result.d_origin = p.origin * d.d_origin; - result.d_dir = p.dir * d.d_dir; - return result; - } } ``` @@ -284,14 +295,6 @@ struct MyRayDifferential : IDifferentiable result.d_dir = v1.d_dir + v2.d_dir; return result; } - - static MyRayDifferential dmul(MyRayDifferential p, MyRayDifferential d) - { - MyRayDifferential result; - result.d_origin = p.d_origin * d.d_origin; - result.d_dir = p.d_dir * d.d_dir; - return result; - } } ``` In this specific case, the automatically generated `IDifferentiable` implementation will be exactly the same as the manually written code listed above. @@ -303,17 +306,19 @@ Functions in Slang can be marked as forward-differentiable or backward-different A forward derivative propagation function computes the derivative of the result value with regard to a specific set of input parameters. Given an original function, the signature of its forward propagation function is determined using the following rules: -- If the return type `R` is differentiable, the forward propagation function will return `DifferentialPair` that consists of both the computed original result value and the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as `R`. -- If a parameter has type `T` that is differentiable, it will be translated into a `DifferentialPair` parameter in the derivative function, where the differential component of the `DifferentialPair` holds the initial derivatives of each parameter with regard to their upstream parameters. +- If the return type `R` implements `IDifferentiable` the forward propagation function will return a corresponding `DifferentialPair` that consists of both the computed original result value and the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as `R`. +- If a parameter has type `T` that implements `IDifferentiable`, it will be translated into a `DifferentialPair` parameter in the derivative function, where the differential component of the `DifferentialPair` holds the initial derivatives of each parameter with regard to their upstream parameters. +- If a parameter has type `T` that implements `IDifferentiablePtrType`, it will be translated into a `DifferentialPtrPair` parameter where the differential component references the differential location or buffer. - All parameter directions are unchanged. For example, an `out` parameter in the original function will remain an `out` parameter in the derivative function. +- Differentiable methods cannot have a type implementing `IDifferentiablePtrType` as an `out` or `inout` parameter, or a return type. Types implementing `IDifferentiablePtrType` can only be used for input parameters to a differentiable method. Marking such a method as `[Differentiable]` will result in a compile-time diagnostic error. For example, given original function: ```csharp -R original(T0 p0, inout T1 p1, T2 p2); +R original(T0 p0, inout T1 p1, T2 p2, T3 p3); ``` -Where `R`, `T0`, and `T1` is differentiable and `T2` is non-differentiable, the forward derivative function will have the following signature: +Where `R`, `T0`, `T1 : IDifferentiable`, `T2` is non-differentiable, and `T3 : IDifferentiablePtrType`, the forward derivative function will have the following signature: ```csharp -DifferentialPair derivative(DifferentialPair p0, inout DifferentialPair p1, T2 p2); +DifferentialPair derivative(DifferentialPair p0, inout DifferentialPair p1, T2 p2, DifferentialPtrPair p3); ``` This forward propagation function takes the initial primal value of `p0` in `p0.p`, and the partial derivative of `p0` with regard to some upstream parameter in `p0.d`. It takes the initial primal and derivative values of `p1` and updates `p1` to hold the newly computed value and propagated derivative. Since `p2` is not differentiable, it remains unchanged. @@ -327,10 +332,11 @@ struct DifferentialPair : IDifferentiable property T.Differential d {get;} static Differential dzero(); static Differential dadd(Differential a, Differential b); - static Differential dmul(This a, Differential b); } ``` +For ptr-types, there is a corresponding built-in `DifferentialPtrPair` that does not have the `dzero` or `dadd` methods. + ### Automatic Implementation of Forward Derivative Functions A function can be made forward-differentiable with a `[ForwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the forward propagation function. The syntax for using `[ForwardDifferentiable]` is: @@ -392,23 +398,25 @@ Given an original function `f`, the general rule for determining the signature o More specifically, the signature of its backward propagation function is determined using the following rules: - A backward propagation function always returns `void`. -- A differentiable `in` parameter of type `T` will become an `inout DifferentialPair` parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the back-prop function. The original value will not be overwritten by the backward propagation function. The propagated derivative will be written to the derivative part of the differential pair after the backward propagation function returns. The initial derivative value of the pair is ignored as input. -- A differentiable `out` parameter of type `T` will become an `in T.Differential` parameter, carrying the partial derivative of some downstream term with regard to the return value. -- A differentiable `inout` parameter of type `T` will become an `inout DifferentialPair` parameter, where the original value of the argument, along with the downstream partial derivative with regard to the argument is passed as input to the backward propagation function as the original and derivative part of the pair. The propagated derivative with regard to this input parameter will be written back and replace the derivative part of the pair. The primal value part of the parameter will *not* be updated. +- A differentiable `in` parameter of type `T : IDifferentiable` will become an `inout DifferentialPair` parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the back-prop function. The original value will not be overwritten by the backward propagation function. The propagated derivative will be written to the derivative part of the differential pair after the backward propagation function returns. The initial derivative value of the pair is ignored as input. +- A differentiable `out` parameter of type `T : IDifferentiable` will become an `in T.Differential` parameter, carrying the partial derivative of some downstream term with regard to the return value. +- A differentiable `inout` parameter of type `T : IDifferentiable` will become an `inout DifferentialPair` parameter, where the original value of the argument, along with the downstream partial derivative with regard to the argument is passed as input to the backward propagation function as the original and derivative part of the pair. The propagated derivative with regard to this input parameter will be written back and replace the derivative part of the pair. The primal value part of the parameter will *not* be updated. - A differentiable return value of type `R` will become an additional `in R.Differential` parameter at the end of the backward propagation function parameter list, carrying the result derivative of a downstream term with regard to the return value of the original function. - A non-differentiable return value of type `NDR` will be dropped. - A non-differentiable `in` parameter of type `ND` will remain unchanged in the backward propagation function. - A non-differentiable `out` parameter of type `ND` will be removed from the parameter list of the backward propagation function. - A non-differentiable `inout` parameter of type `ND` will become an `in ND` parameter. +- Types implemented `IDifferentiablePtrType` work the same was as the forward-mode case. They can only be used with `in` parameters, and are converted into `DifferentialPtrPair` types. Their directions are not affected. For example consider the following original function: ```csharp struct T : IDifferentiable {...} struct R : IDifferentiable {...} +struct P : IDifferentiablePtrType {...} struct ND {} // Non differentiable [Differentiable] -R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5); +R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5, P p6); ``` The signature of its backward propagation function is: ```csharp @@ -418,6 +426,7 @@ void back_prop( inout DifferentialPair p2, ND p3, ND p5, + DifferentialPtrPair

p6, R.Differential dResult); ``` Note that although `p2` is still `inout` in the backward propagation function, the backward propagation function will only write propagated derivative to `p2.d` and will not modify `p2.p`. @@ -447,6 +456,7 @@ void back_prop( inout DifferentialPair p2, ND p3, ND p5, + DifferentialPtrPair

p6, R.Differential dResult) { ... @@ -468,6 +478,7 @@ void back_prop( inout DifferentialPair p2, ND p3, ND p5, + DifferentialPtrPair

p6, R.Differential dResult) { ...