---
title: "Automatic Differentiation"
author: Paul Schrimpf
date: last-modified
bibliography: 622.bib
execute:
  echo: true
  cache: true
  freeze: auto
format:
  revealjs:
    theme: blood
    smaller: true
    min-scale: 0.1
    max-scale: 3.0
    chalkboard:
      theme: whiteboard
      boardmarker-width: 2
      chalk-width: 2
      chalk-effect: 0.0
engine: julia
---
# Introduction
## Derivatives
- Needed for efficient equation solving and optimization
- Can be computed automatically from the code that evaluates a function
## Finite Differences
```{julia}
f(x) = sin(x)/(1.0+exp(x))

# forward difference with a step size scaled to the magnitude of x
function dxfin(f,x)
    h = sqrt(eps(x))
    if abs(x) > 1
        h = h*abs(x)
    end
    (f(x+h) - f(x))/h
end

dxfin(f, 2.0)
```
## Forward Automatic Differentiation
```{julia}
module Forward

# Dual number: value v and derivative dv propagated together
struct Dual{T}
    v::T
    dv::T
end

Dual(x::T) where {T} = Dual(x, one(x))

import Base: +, sin, exp, *, /

function (+)(a::T, x::Dual{T}) where {T}
    Dual(a+x.v, x.dv)
end

function (*)(y::Dual, x::Dual)
    Dual(y.v*x.v, x.v*y.dv + x.dv*y.v)
end

function (/)(x::Dual, y::Dual)
    Dual(x.v/y.v, x.dv/y.v - x.v*y.dv/y.v^2)
end

exp(x::Dual) = Dual(exp(x.v), exp(x.v)*x.dv)
sin(x::Dual) = Dual(sin(x.v), cos(x.v)*x.dv)

# evaluate f at a Dual seeded with derivative 1 to get (f(x), f'(x))
function fdx(f,x)
    out = f(Dual(x))
    (out.v, out.dv)
end

end

Forward.fdx(f,2.0)
```
## Reverse Automatic Differentiation
- compute $f(x)$ in the usual forward direction, keeping track of each operation and intermediate value
- then compute the derivative "backwards" through the recorded operations
- for $f(x) = g(h(x))$, the chain rule gives $f'(x) = g'(h(x)) h'(x)$; reverse mode accumulates this product from the outside in
- scales better for high dimensional $x$
- implementation is more complicated (a minimal sketch follows)
- a simple-ish example: https://simeonschaub.github.io/ReverseModePluto/notebook.html
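To make the idea concrete, here is a minimal tape-based sketch (not a package you should use; the names `Tracked`, `track`, and `gradient` are invented for this illustration, and only the scalar operations needed for `f` above are defined):

```{julia}
module MiniReverse

# each tracked value stores its value, an adjoint slot, and its parents along
# with the local derivative of this node with respect to each parent
mutable struct Tracked
    v::Float64
    adj::Float64
    parents::Vector{Tuple{Tracked,Float64}}
end
Tracked(v::Real) = Tracked(Float64(v), 0.0, Tuple{Tracked,Float64}[])

const tape = Tracked[]   # the forward pass records every intermediate here

function track(v, parents...)
    t = Tracked(Float64(v), 0.0, collect(parents))
    push!(tape, t)
    t
end

import Base: +, *, /, sin, exp
(+)(a::Real, x::Tracked) = track(a + x.v, (x, 1.0))
(*)(x::Tracked, y::Tracked) = track(x.v*y.v, (x, y.v), (y, x.v))
(/)(x::Tracked, y::Tracked) = track(x.v/y.v, (x, 1/y.v), (y, -x.v/y.v^2))
sin(x::Tracked) = track(sin(x.v), (x, cos(x.v)))
exp(x::Tracked) = track(exp(x.v), (x, exp(x.v)))

# forward pass, then sweep the tape backwards applying the chain rule
function gradient(f, x::Real)
    empty!(tape)
    xt = Tracked(x)
    push!(tape, xt)
    out = f(xt)
    out.adj = 1.0
    for node in reverse(tape)
        for (p, d) in node.parents
            p.adj += node.adj * d
        end
    end
    (out.v, xt.adj)
end

end

MiniReverse.gradient(f, 2.0)
```

This mirrors what tape-based packages such as ReverseDiff and Tracker do, with far more supported operations, types, and performance work.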
# Julia AD Packages
## ForwardDiff
- [`ForwardDiff.jl`](https://github.com/JuliaDiff/ForwardDiff.jl)
- mature and reliable
## ForwardDiff Example
```{julia}
using Distributions

function simulate_logit(observations, β)
    x = randn(observations, length(β))
    y = (x*β + rand(Logistic(), observations)) .>= 0.0
    return((y=y,x=x))
end

function logit_likelihood(β,y,x)
    p = map(xb -> cdf(Logistic(),xb), x*β)
    sum(log.(ifelse.(y, p, 1.0 .- p)))
end

n = 500
k = 3
β0 = ones(k)
(y,x) = simulate_logit(n,β0)

import ForwardDiff
∇L = ForwardDiff.gradient(b->logit_likelihood(b,y,x),β0)
```
## ForwardDiff Notes
- For $f: \mathbb{R}^n \to \mathbb{R}^m$, the cost of forward mode scales with $n$
  - best for moderate $n$
- Code must be generic: it will be called with `ForwardDiff.Dual` numbers in place of `Float64`
  - be careful when allocating arrays
```{julia}
#| eval: false
#| echo: true
function wontwork(x)
    # zeros() allocates a Float64 array, which cannot hold Dual numbers
    y = zeros(size(x))
    for i ∈ eachindex(x)
        y[i] += x[i]*i
    end
    return(sum(y))
end

function willwork(x)
    # zero(x) matches the element type of x, so Duals propagate through y
    y = zero(x)
    for i ∈ eachindex(x)
        y[i] += x[i]*i
    end
    return(sum(y))
end

betterstyle(x) = sum(v*i for (i,v) in enumerate(x))
```
## Zygote
- [Zygote.jl](https://fluxml.ai/Zygote.jl/latest/) is a reverse mode package
- Does not allow mutating arrays (see the sketch below)
- Quite mature, but some bugs likely remain
- Apparently hard to develop and maintain, so its future is unclear
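A small illustration of the no-mutation rule (a sketch; `mutating!` and `nomutation` are made-up names for this example):

```{julia}
#| eval: false
import Zygote

# writes into a preallocated array, which Zygote's reverse pass cannot handle
function mutating!(x)
    y = similar(x)
    for i in eachindex(x)
        y[i] = x[i]^2
    end
    return sum(y)
end

# the same computation without mutation differentiates fine
nomutation(x) = sum(xi^2 for xi in x)

# Zygote.gradient(mutating!, rand(3))  # errors: mutating arrays is not supported
Zygote.gradient(nomutation, rand(3))
```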
## Zygote Example
```{julia}
import Zygote
using LinearAlgebra
@time ∇Lz = Zygote.gradient(b->logit_likelihood(b,y,x),β0)[1]
norm(∇L - ∇Lz)
```
## Enzyme
["Enzyme performs automatic differentiation (AD) of statically analyzable LLVM. It is highly-efficient and its ability to perform AD on optimized code allows Enzyme to meet or exceed the performance of state-of-the-art AD tools."](https://enzymead.github.io/Enzyme.jl/stable/)
```{julia}
import Enzyme
import Enzyme: Active, Duplicated, Const
db = zero(β0)  # the gradient with respect to β0 accumulates into db
@time Enzyme.autodiff(Enzyme.ReverseWithPrimal, logit_likelihood, Active, Duplicated(β0,db), Const(y), Const(x))
db
```
## Enzyme Notes
- Documentation is not suited to beginners
- Does not work on all Julia code, and the cases where it fails are not well documented. Calling `Enzyme.API.runtimeActivity!(true)` works around some errors.
- Cryptic error messages: Enzyme operates on LLVM IR, and errors often reference the point in the IR where the problem occurred; mapping that back to the original Julia code is not easy.
- These issues may be better now than last year when I first wrote this slide
```{julia}
Enzyme.API.runtimeActivity!(false)
f1(a,b) = sum(a.*b)
dima = 30000
a = ones(dima)
b = rand(dima)
da = zeros(dima)
@time Enzyme.autodiff(Enzyme.ReverseWithPrimal, f1, Duplicated(a,da), Const(b))
da

f3(a,b) = sum(a[i]*b[i] for i ∈ eachindex(a))
da = zeros(dima)
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f3, Duplicated(a,da), Const(b))
da

if (false) # will trigger an Enzyme error without runtime activity
    f2(a,b) = sum(a*b for (a,b) ∈ zip(a,b))
    da = zeros(dima)
    @time Enzyme.autodiff(Enzyme.ReverseWithPrimal, f2, Duplicated(a,da), Const(b))
    da
end

Enzyme.API.runtimeActivity!(true)
f2(a,b) = sum(a*b for (a,b) ∈ zip(a,b))
da = zeros(dima)
@time Enzyme.autodiff(Enzyme.ReverseWithPrimal, f2, Duplicated(a,da), Const(b))
da
```
## FiniteDiff
- [FiniteDiff](https://github.com/JuliaDiff/FiniteDiff.jl) computes finite difference gradients
- Always test that whatever automatic or manual derivatives you compute are close to the finite difference versions (as sketched below)
- Use a package for finite differences because it handles step sizes and rounding error well
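For example, a quick check against the ForwardDiff gradient from earlier (a sketch, assuming FiniteDiff is installed and `y`, `x`, `β0`, and `∇L` are still in scope):

```{julia}
#| eval: false
import FiniteDiff
∇Lfd = FiniteDiff.finite_difference_gradient(b->logit_likelihood(b,y,x), β0)
norm(∇L - ∇Lfd)
```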
## ChainRules
- [ChainRules](https://github.com/JuliaDiff/ChainRules.jl)
- Used by many AD packages to define the derivatives of various functions
- Useful if you want to define a custom derivative rule for a function (a sketch follows)
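A minimal custom reverse rule via ChainRulesCore (a sketch; `logistic_cdf` is a made-up function for illustration, and ChainRulesCore may need to be added to the project):

```{julia}
#| eval: false
import ChainRulesCore

logistic_cdf(z) = 1/(1 + exp(-z))

# rrule returns the value and a pullback mapping the output cotangent to
# cotangents of the function itself (NoTangent) and of each argument
function ChainRulesCore.rrule(::typeof(logistic_cdf), z)
    p = logistic_cdf(z)
    logistic_cdf_pullback(Δ) = (ChainRulesCore.NoTangent(), Δ*p*(1 - p))
    return p, logistic_cdf_pullback
end
```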
## DifferentiationInterface
- [DifferentiationInterface](https://github.com/gdalle/DifferentiationInterface.jl) gives a single interface for many differentiation packages
```{julia}
import DifferentiationInterface as DI
DI.gradient(b->logit_likelihood(b,y,x), DI.AutoEnzyme(),β0)
```
- Preparation can improve performance by reusing intermediate variables
```{julia}
backend = DI.AutoEnzyme()
dcache = DI.prepare_gradient(b->logit_likelihood(b,y,x), backend, β0)
grad = zero(β0)
DI.gradient!(b->logit_likelihood(b,y,x), grad, backend, β0, dcache)
```
# Other Packages
## Other Packages
- [https://juliadiff.org/](https://juliadiff.org/)
## ReverseDiff.jl
- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) is a tape based reverse mode package (usage sketch below)
- Long lived and well tested
- [Limitations](https://juliadiff.org/ReverseDiff.jl/limits/): importantly, code must be generic and mutation of arrays is not allowed
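A quick usage sketch, reusing `logit_likelihood`, `y`, `x`, `β0`, and `∇L` from earlier (assuming ReverseDiff is installed):

```{julia}
#| eval: false
import ReverseDiff
∇Lr = ReverseDiff.gradient(b->logit_likelihood(b,y,x), β0)
norm(∇L - ∇Lr)
```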
## Yota.jl
- [Yota.jl](https://github.com/dfdx/Yota.jl) another tape based package
- Compatible with Chainrules.jl
- Somewhat newer and less popular
- [Its documentation has a very nice explanation of how it works.](https://dfdx.github.io/Yota.jl/dev/design/)
## Tracker
[Tracker](https://github.com/FluxML/Tracker.jl) is a tape based reverse mode package. It was the default autodiff package in Flux before being replaced by Zygote. No longer under active development.
## Diffractor
[Diffractor](https://github.com/JuliaDiff/Diffractor.jl) is an automatic differentiation package in development. It was once hoped to be the future of AD in Julia, but has been delayed. It plans to have both forward and reverse mode, but only forward mode is available so far.
# References