Automatic differentiation (AD) is a set of techniques for calculating exact derivatives, numerically, in an automatic way. It is neither symbolic differentiation, nor a numerical approximation such as finite differences.
There are two main methods: forward-mode AD and reverse-mode AD. Each has its strengths and weaknesses. Forward mode is significantly easier to implement.
Let's start by thinking about univariate functions $f: \mathbb{R} \to \mathbb{R}$. We would like to calculate the derivative $f'(a)$ at some point $a \in \mathbb{R}$.
We know various rules about how to calculate such derivatives. For example, if we have already managed to calculate $f'(a)$ and $g'(a)$, we can calculate $(f+g)'(a)$ and $(f.g)'(a)$ as
\begin{align} (f+g)'(a) &= f'(a) + g'(a)\\ (f.g)'(a) &= f'(a) \, g(a) + f(a) \, g'(a) \end{align}We also have the chain rule, which plays a crucial role:
$$(f \circ g)'(a) = f'(g(a)) \, g'(a)$$We see that in general we will need, for each function $f$, both the value $f(a)$ and the derivative $f'(a)$, and this is the only information that we require in order to calculate the first derivative of any combination of functions.
Formally, we can think of the first-order Taylor polynomial of $f$, called the [jet of $f$](https://en.wikipedia.org/wiki/Jet_(mathematics)) at $a$, denoted $J_a(f)$:
$$(J_a(f))(x) := f(a) + (x - a) \, f'(a)$$[This can be thought of as representing the set of all functions with the same data $f(a)$ and $f'(a)$.]
Formally, it is common to think of this as a "dual number", $f + \epsilon f'$, that we can manipulate, following the rule that $\epsilon^2 = 0$. (Cf. complex numbers, which have the same structure, but with $i^2 = -1$.) E.g.
$$(f + \epsilon f') \times (g + \epsilon g') = f \, g + \epsilon (f' g + f g')$$shows how to define the multiplication of two jets.
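For example, with the values used below, multiplying the jets $3 + 4\epsilon$ and $5 + 6\epsilon$ gives
$$(3 + 4\epsilon) \times (5 + 6\epsilon) = 15 + (4 \cdot 5 + 3 \cdot 6)\,\epsilon = 15 + 38\,\epsilon,$$i.e. value $15$ and derivative $38$.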
As usual, we can represent a polynomial just by its degree and its coefficients, so we can define a Julia object as follows. We will leave the evaluation point $a$ implicit, although we could, of course, include it if desired.
immutable Jet{T} <: Real
val::T # value
der::T # derivative
end
import Base: +, *, -, convert, promote_rule
+(f::Jet, g::Jet) = Jet(f.val + g.val, f.der + g.der)
-(f::Jet, g::Jet) = Jet(f.val - g.val, f.der - g.der)
*(f::Jet, g::Jet) = Jet(f.val*g.val, f.der*g.val + f.val*g.der)
* (generic function with 150 methods)
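Division can be defined in the same spirit via the quotient rule; here is a minimal sketch (our own addition, not used in what follows):
import Base: /
/(f::Jet, g::Jet) = Jet(f.val/g.val, (f.der*g.val - f.val*g.der) / g.val^2)  # (f/g)' = (f'g - fg') / g²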
We can now define `Jet`s and manipulate them:
f = Jet(3, 4) # any function f such that f(a) = 3 and f'(a) = 4, or the set of all such functions
g = Jet(5, 6) # any function g such that g(a) = 5 and g'(a) = 6
f + g # calculate the value and derivative of (f + g) for any f and g in these sets
Jet{Int64}(8,10)
f * g
Jet{Int64}(15,38)
f * (g + g)
Jet{Int64}(30,76)
It seems like we must have introduced quite a lot of computational overhead by creating a relatively complex data structure, and associated methods, to manipulate pairs of numbers. Let's see how the performance is:
add(a1, a2, b1, b2) = (a1+b1, a2+b2)
add (generic function with 1 method)
add(1, 2, 3, 4)
@time add(1, 2, 3, 4)
0.000001 seconds (4 allocations: 176 bytes)
(4,6)
a = Jet(1, 2)
b = Jet(3, 4)
add2(j1, j2) = j1 + j2
add2(a, b)
@time add2(a, b)
0.000002 seconds (5 allocations: 192 bytes)
Jet{Int64}(4,6)
@code_native add(1, 2, 3, 4)
	.section	__TEXT,__text,regular,pure_instructions
Filename: In[151]
	pushq	%rbp
	movq	%rsp, %rbp
Source line: 1
	addq	%rcx, %rsi
	addq	%r8, %rdx
	movq	%rsi, (%rdi)
	movq	%rdx, 8(%rdi)
	movq	%rdi, %rax
	popq	%rbp
	retq
	nopw	%cs:(%rax,%rax)
@code_native add2(a, b)
	.section	__TEXT,__text,regular,pure_instructions
Filename: In[158]
	pushq	%rbp
	movq	%rsp, %rbp
Source line: 4
	movq	(%rdx), %rax
	movq	8(%rdx), %rcx
	addq	(%rsi), %rax
	addq	8(%rsi), %rcx
	movq	%rax, (%rdi)
	movq	%rcx, 8(%rdi)
	movq	%rdi, %rax
	popq	%rbp
	retq
	nop
We see that there is only slight overhead, associated with moving the data around. The data structure itself has disappeared: we essentially have a standard Julia tuple.
We can also define functions of these objects using the chain rule. For example, if `f` is a jet representing the function $f$, then we would like `exp(f)` to be a jet representing the function $\exp \circ f$, i.e. with value $\exp(f(a))$ and derivative $(\exp \circ f)'(a) = \exp(f(a)) \, f'(a)$:
import Base: exp
exp(f::Jet) = Jet(exp(f.val), exp(f.val) * f.der)
exp (generic function with 12 methods)
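Other elementary functions follow exactly the same pattern via the chain rule; for example, a sketch for `sin` and `cos` (our own addition, not needed below):
import Base: sin, cos
sin(f::Jet) = Jet(sin(f.val), cos(f.val) * f.der)   # (sin ∘ f)'(a) = cos(f(a)) f'(a)
cos(f::Jet) = Jet(cos(f.val), -sin(f.val) * f.der)  # (cos ∘ f)'(a) = -sin(f(a)) f'(a)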
f
Jet{Int64}(3,4)
exp(f)
Jet{Float64}(20.085536923187668,80.34214769275067)
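Indeed, $\exp(3) \approx 20.0855$ and $\exp(3) \cdot 4 \approx 80.3421$, matching the two components above.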
However, we can't do e.g. the following:
# 3 * f
[Warning: In Julia 0.5, you may need to restart the kernel after doing this for the following to work correctly.]
In order to get this to work, we need to hook into Julia's type promotion and conversion machinery.
First, we specify how to promote a number and a `Jet`:
promote_rule{T<:Real,S}(::Type{Jet{S}}, ::Type{T}) = Jet{S}
promote_rule (generic function with 102 methods)
Second, we specify how to `convert` a (constant) number to a `Jet`. By e.g. $g = f + 3$, we mean the function such that $g(x) = f(x) + 3$ for all $x$, i.e. $g = f + 3 \cdot \mathbb{1}$, where $\mathbb{1}$ is the constant function $\mathbb{1}: x \mapsto 1$.
Thus we think of a constant $c$ as the constant function $c \, \mathbb{1}$, with $c(a) = c$ and $c'(a) = 0$, which we encode as the following conversion:
convert{T<:Union{AbstractFloat, Integer, Rational},S}(::Type{Jet{S}}, x::T) = Jet{S}(x, 0)
convert (generic function with 600 methods)
convert(Jet{Float64}, 3.1)
Jet{Float64}(3.1,0.0)
promote(Jet(1,2), 3.0)
(Jet{Int64}(1,2),Jet{Int64}(3,0))
promote(Jet(1,2), 3.1)
InexactError() in promote(::Jet{Int64}, ::Float64) at promotion.jl:153
The error occurs because our simple rule always promotes to the original `Jet{Int64}` type, and `3.1` cannot be converted exactly to an `Int64`.
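A more careful rule would also promote the type parameters; a sketch in the same style, which would replace the rule defined above:
promote_rule{T<:Real,S}(::Type{Jet{S}}, ::Type{T}) = Jet{promote_type(S, T)}
With this rule, `promote(Jet(1,2), 3.1)` would produce two `Jet{Float64}` values.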
convert(Jet{Float64}, 3.0)
Jet{Float64}(3.0,0.0)
Julia's machinery now enables us to do what we wanted:
Jet(1.1, 2.3) + 3
Jet{Float64}(4.1,2.3)
How can we use this to calculate the derivative of an arbitrary function? For example, we wish to differentiate the function
h(x) = x^2 - 2
h (generic function with 1 method)
at $a = 3$.
We think of this as a function of $x$, which itself we think of as the identity function $\iota: x \mapsto x$, so that $h = \iota^2 - 2 \cdot \mathbb{1}$.
We represent the identity function as follows:
a = 3
x = Jet(a, 1)
Jet{Int64}(3,1)
since $\iota'(a) = 1$ for any $a$.
Now we simply evaluate the function `h` at `x`:
h(x)
Jet{Int64}(7,6)
The first component of the resulting `Jet` is the value $h(a)$, and the second component is the derivative, $h'(a)$.
We can codify this into a function as follows:
derivative(f, x) = f(Jet(x, one(x))).der
derivative (generic function with 1 method)
derivative(x -> 3x^5 + 2, 2)
240
This is capable of differentiating any function that involves only functions whose derivatives we have specified, by defining corresponding rules on `Jet` objects. For example,
y = [1.,2]
k(x) = (y'* [x 2; 3 4] * y)[]
k (generic function with 1 method)
k(3)
29.0
derivative(x->k(x), 10)
1
This works since Julia is constructing the following object:
[Jet(3.0, 1.0) 2; 3 4]
2×2 Array{Jet{Float64},2}:
 Jet{Float64}(3.0,1.0)  Jet{Float64}(2.0,0.0)
 Jet{Float64}(3.0,0.0)  Jet{Float64}(4.0,0.0)
How can we extend this to higher dimensions? For example, we wish to differentiate the following function $f: \mathbb{R}^2 \to \mathbb{R}$:
f1(x, y) = x^2 + x*y
f1 (generic function with 1 method)
As we learn in calculus, the partial derivative $\partial f/\partial x$ is the function obtained by fixing $y$, thinking of the resulting function as a function only of $x$, and then differentiating.
Suppose that we wish to differentiate $f$ at $(a, b)$:
a, b = 3.0, 4.0
f1_x(x) = f1(x, b) # single-variable function
f1_x (generic function with 1 method)
Since we now have a single-variable function, we can differentiate it:
derivative(f1_x, a)
10.0
Under the hood this is doing
f1(Jet(a, one(a)), b)
Jet{Float64}(21.0,10.0)
Similarly, we can differentiate with respect to $y$ by doing
f1(a, Jet(b, one(b)))
Jet{Float64}(21.0,3.0)
Note that we must do two separate calculations to get the two partial derivatives. To calculate a gradient of a function $f:\mathbb{R}^n \to \mathbb{R}$ thus requires $n$ separate calculations.
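We can package the two calculations into a small helper (the name `gradient2` is ours, for illustration):
# gradient of f: ℝ² → ℝ via two forward-mode passes
gradient2(f, a, b) = (f(Jet(a, one(a)), b).der, f(a, Jet(b, one(b))).der)

gradient2(f1, a, b)   # (10.0, 3.0), as found above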
Forward-mode AD is implemented in a clean and efficient way in the [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) package.
To understand what forward-mode AD is doing, and its name, it is useful to think of an expression as a syntax tree; cf. [this notebook](Syntax trees in Julia.ipynb).
If we label the nodes in the tree as $v_i$, then forward differentiation fixes a variable, e.g. $y$, and calculates $\partial v_i / \partial y$ for each $i$. If e.g. $v_1 = v_2 + v_3$, then we have
$$\frac{\partial v_1}{\partial y} = \frac{\partial v_2}{\partial y} + \frac{\partial v_3}{\partial y}.$$Denoting $v_1' := \frac{\partial v_1}{\partial y}$, we have $v_1' = v_2' + v_3'$, so we need to calculate the derivatives of the nodes lower down in the graph first, and propagate the information up. We start with $v_x' = 0$, since $\frac{\partial x}{\partial y} = 0$, and $v_y' = 1$.
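For example, for $f_1(x, y) = x^2 + x y$ at $(a, b) = (3, 4)$, differentiating with respect to $x$ propagates as
$$v_x' = 1, \quad v_y' = 0; \qquad v_1 = x^2, \; v_1' = 2 x \, v_x' = 6; \qquad v_2 = x y, \; v_2' = v_x' \, y + x \, v_y' = 4; \qquad (v_1 + v_2)' = v_1' + v_2' = 10,$$in agreement with the value $10.0$ computed above.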
An alternative method to calculate derivatives is to fix not the variable with respect to which we differentiate, but rather the quantity being differentiated; i.e., we calculate the adjoint, $\bar{v_i} := \frac{\partial f}{\partial v_i}$, for each $i$.
If $f = v_1 + v_2$, with $v_1 = v_3 + v_4$ and $v_2 = v_3 + v_5$, then
$$\frac{\partial f}{\partial v_3} = \frac{\partial f}{\partial v_1} \frac{\partial v_1}{\partial v_3} + \frac{\partial f}{\partial v_2} \frac{\partial v_2}{\partial v_3},$$i.e.
$$\bar{v_3} = \alpha_{13} \, \bar{v_1} + \alpha_{23} \, \bar{v_2},$$where $\alpha_{ij}$ are the coefficients specifying the relationship between the different terms. Thus, the adjoint information propagates down the graph, in reverse order, hence the name "reverse-mode".
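In general, each node accumulates one such contribution from every node that consumes it:
$$\bar{v_i} = \sum_{j \,:\, v_i \to v_j} \bar{v_j} \, \frac{\partial v_j}{\partial v_i}.$$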
For this reason, reverse mode is much harder to implement. However, it has the advantage that all derivatives $\partial f / \partial x_i$ are calculated in a single pass of the tree.
Julia has an efficient implementation of reverse-mode AD in [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl).
Reverse mode is difficult to implement in a general way, but easy to carry out by hand. Consider, for example, the function $f(x, y, z) = x y - 2 \sin(x z)$.
We decompose it into a tree with labelled nodes, corresponding to the following sequence of elementary operations:
ff(x, y, z) = x*y - 2*sin(x*z)
x, y, z = 1, 2, 3
v₁ = x
v₂ = y
v₃ = z
v₄ = v₁ * v₂
v₅ = v₁ * v₃
v₆ = sin(v₅)
v₇ = v₄ - 2v₆ # f
1.7177599838802655
ff(x, y, z)
1.7177599838802655
We have decomposed the forward pass into elementary operations. We now proceed to calculate the adjoints. The difficulty is to find which variables depend on the variable currently in question.
v̄₇ = 1
v̄₆ = -2 # ∂f/∂v₆ = ∂v₇/∂v₆
v̄₅ = v̄₆ * cos(v₅) # ∂v₇/∂v₆ * ∂v₆/∂v₅
v̄₄ = 1 # ∂f/∂v₄ = ∂v₇/∂v₄
v̄₃ = v̄₅ * v₁ # ∂f/∂v₃ = ∂f/∂v₅ . ∂v₅/∂v₃. # This gives ∂f/∂z
v̄₂ = v̄₄ * v₁ # ∂f/∂v₂ = ∂f/∂v₄ . ∂v₄/∂v₂. # This gives ∂f/∂y
v̄₁ = v̄₅*v₃ + v̄₄*v₂ # v₁ feeds into both v₄ and v₅. # This gives ∂f/∂x
7.939954979602673
Thus, in a single pass we have calculated the gradient $\nabla f(1, 2, 3)$:
(v̄₁, v̄₂, v̄₃)
(7.939954979602673,1,1.9799849932008908)
Let's check that it's correct:
using ForwardDiff
ForwardDiff.gradient(x->ff(x...), [x,y,z])
3-element Array{Float64,1}:
 7.93995
 1.0
 1.97998
As an example of the use of AD, consider the following function, from which we will build an objective to optimize:
x = rand(3)
y = rand(3)
distance(W) = W*x - y
distance (generic function with 1 method)
using ForwardDiff
ForwardDiff.jacobian(distance, rand(3,3))
3×9 Array{Float64,2}:
 0.889986  0.0       0.0       0.855784  …  0.659763  0.0       0.0
 0.0       0.889986  0.0       0.0          0.0       0.659763  0.0
 0.0       0.0       0.889986  0.0          0.0       0.0       0.659763
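Since `distance` is linear in $W$, its Jacobian is constant: for the map $W \mapsto W x$, flattening $W$ column-major gives $J = x^\top \otimes I$. A quick check of this (our own addition, using the 0.5-era `eye`):
ForwardDiff.jacobian(distance, rand(3,3)) ≈ kron(x', eye(3))   # true: the Jacobian does not depend on W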
objective(W) = (a = distance(W); dot(a, a))
objective (generic function with 1 method)
W0 = rand(3, 3)
grad = ForwardDiff.gradient(objective, W0)
3×3 Array{Float64,2}:
 2.14718   2.06467    1.59175
 0.023659  0.0227498  0.0175388
 2.13258   2.05063    1.58092
2*(W0*x-y)*x' == grad # LHS is the analytical derivative
true
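With the gradient in hand, we can minimize the objective by, e.g., gradient descent; a minimal sketch (the helper name and step size are our own choices):
function descend(W, nsteps, η)
    for i in 1:nsteps
        W = W - η * ForwardDiff.gradient(objective, W)   # step downhill
    end
    return W
end

W1 = descend(W0, 100, 0.05)
objective(W1) < objective(W0)   # the objective should have decreased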
How can we find roots of a function?
f2(x) = x^2 - 2
WARNING: Method definition f2(Any) in module Main at In[100]:1 overwritten at In[108]:1.
f2 (generic function with 1 method)
An idea is to exclude regions of $\mathbb{R}$ by showing that they cannot contain a zero, by calculating the image (range) of the function over a given domain.
This is, in general, a difficult problem, but interval arithmetic provides a partial solution, by calculating an enclosure of the range, i.e. an interval that is guaranteed to contain the range.
using ValidatedNumerics
X = 3..4
[3, 4]
typeof(X)
ValidatedNumerics.Interval{Float64}
This is a representation of the set $X = [3, 4] := \{x\in \mathbb{R}: 3 \le x \le 4\}$.
We can evaluate a Julia function on an `Interval` object `X`. The result is a new `Interval`, which is guaranteed to contain the true image $\mathrm{range}(f; X) := \{f(x): x \in X \}$. This is achieved by defining arithmetic operations on intervals in the correct way, e.g.
f2(X)
[7, 14]
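Indeed, following the rules of interval arithmetic,
$$X^2 = [3,4]^2 = [9, 16], \qquad X^2 - 2 = [9, 16] - 2 = [7, 14].$$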
Since this result does not contain $0$, we have proved that $f_2$ has no zero in the domain $[3,4]$. We can even use semi-infinite intervals:
X1 = 3..∞ # type \infty<TAB>
[3, ∞]
f2(X1)
[7, ∞]
X2 = -∞.. -3 # space is required
[-∞, -3]
f2(X2)
[7, ∞]
We have thus excluded two semi-infinite regions, and so have proved, with two simple calculations, that any root must lie in $[-3,3]$. However,
f2(-3..3)
[-2, 7]
We cannot conclude anything from this, since the result is, in general, an over-estimate of the true range, which thus may or may not contain zero. We can proceed by bisecting the interval. E.g. after two bisections, we find
f2(-3.. -1.5)
[0.25, 7]
so we have excluded another piece.
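This exclusion-by-bisection process is easy to automate. Here is a minimal recursive sketch (our own helper; we assume the `mid` function and the `lo`/`hi` fields provided by ValidatedNumerics):
function exclude_zero(f, X, depth)
    0 ∉ f(X) && return typeof(X)[]   # 0 is excluded: no root possible in X
    depth == 0 && return [X]         # cannot refine further: keep X as a candidate
    m = mid(X)                       # bisect and recurse on the two halves
    return vcat(exclude_zero(f, Interval(X.lo, m), depth - 1),
                exclude_zero(f, Interval(m, X.hi), depth - 1))
end

exclude_zero(f2, -3..3, 10)   # only small intervals around ±√2 should survive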
To prove that there does exist a root, we need a different approach. The standard method is to evaluate the function at the two end-points of an interval:
f2(1), f2(2)
(-1,2)
Since there is a sign change, there exists at least one root $x^*$ in the interval $[1,2]$, i.e. a point such that $f(x^*) = 0$.
To prove that it is unique, one method is to prove that $f_2$ is monotone in that interval, i.e. that its derivative has a constant sign there. To do so, we would need to evaluate the derivative at every point in the interval, which seems impossible.
Again, however, interval arithmetic easily gives an enclosure of this image. To show this, we need to evaluate the derivative using interval arithmetic.
Thanks to Julia's parametric types, we get composability for free: we can simply substitute an interval into `ForwardDiff` or `Jet`, and it works:
ForwardDiff.derivative(f2, 1..2)
[2, 4]
Again, the reason for this is that Julia creates the object
Jet(x, one(x))
Jet{ValidatedNumerics.Interval{Float64}}([1, 2],[1, 1])
Since an enclosure of the derivative is the interval $[2, 4]$ (in fact, in this case this is the true image, but there is no way to know this other than by an analytical calculation), we have proved that the image of the derivative $f_2'$ over the interval $X = [1,2]$ does not contain zero, and hence that $f_2$ is monotone on $X$. The root in $[1,2]$ is therefore unique.
To actually find the root within this interval, we can use the [Newton interval method](Interval Newton.ipynb). In general, we should not expect to be able to use intervals in standard numerical methods designed for floats; rather, we will need to modify the numerical method to take advantage of intervals.
The Newton interval method can find, in a guaranteed way, all roots of a function in a given interval (or tell you when it is unable to do so, for example if there are double roots). Although finding all the roots of a general function is often thought to be difficult, with these methods it is essentially a solved problem.