Klaus Reygers, 2021
Automatic differentiation (autodiff, AD) is an efficient and numerically stable way to calculate derivatives on a computer. The algorithm requires at most a small constant factor more arithmetic operations than the original program.
Automatic differentiation is distinct from symbolic differentiation and numerical differentiation.
The basic idea is to supplement the standard mathematical functions so that, in addition to the function value, the derivative is calculated as well. The derivative of a composite function (a sequence of primitive operations, each with a known routine for computing its derivative) is then obtained by applying the chain rule repeatedly.
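A minimal sketch of the forward-mode version of this idea, using a hand-rolled dual-number class (illustrative only, not part of the examples below):

class Dual:
    """Pair of (value, derivative); primitive operations apply the chain rule."""
    def __init__(self, val, der):
        self.val = val
        self.der = der
    def __add__(self, other):
        return Dual(self.val + other.val, self.der + other.der)
    def __mul__(self, other):
        # product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

x = Dual(3.0, 1.0)   # seed dx/dx = 1
y = x * x + x        # f(x) = x**2 + x
print(y.val, y.der)  # 12.0 7.0  (f(3) = 12, f'(3) = 2*3 + 1 = 7)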
Training a neural network through backpropagation is a typical application of autodiff (TensorFlow, PyTorch, ...).
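For instance, with JAX (used again further below) the gradient of a simple least-squares loss with respect to a model parameter is obtained with a single call to grad; the toy model and data here are made up for illustration:

import jax.numpy as jnp
from jax import grad

def loss(w, x, y):
    return jnp.mean((w * x - y)**2)  # linear model w*x with mean-squared-error loss

x_data = jnp.array([1., 2., 3.])
y_data = jnp.array([2., 4., 6.])
print(grad(loss)(1.5, x_data, y_data))  # d(loss)/dw at w = 1.5, approx. -4.67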
As a first example, implement Gaussian error propagation with symbolic differentiation in sympy:
from sympy import *
from IPython.display import display, Latex
def gaussian_error_propagation(f, vars):
    """
    f: formula (sympy expression)
    vars: list of independent variables and corresponding uncertainties
          [(x1, sigma_x1), (x2, sigma_x2), ...]
    """
    sum = S(0)  # empty sympy expression
    for (x, sigma) in vars:
        sum += diff(f, x)**2 * sigma**2
    return sqrt(simplify(sum))
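This implements the standard formula for Gaussian error propagation with uncorrelated uncertainties:

$$\sigma_f = \sqrt{\sum_i \left(\frac{\partial f}{\partial x_i}\right)^2 \sigma_{x_i}^2}$$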
Show usage with a simple example: the volume of a cylinder with radius $r$ and height $h$:
r, h, sigma_r, sigma_h = symbols('r, h, sigma_r, sigma_h', positive=True)
V = pi * r**2 * h # volume of a cylinder
sigma_V = gaussian_error_propagation(V, [(r, sigma_r), (h, sigma_h)])
display(Latex(f"$V = {latex(V)}, \, \sigma_V = {latex(sigma_V)}$"))
Plug in some numbers and print the calculated volume with its uncertainty:
r_meas = 3 # cm
sigma_r_meas = 0.1 # cm
h_meas = 5 # cm
sigma_h_meas = 0.1 # cm
central_value = V.subs([(r,r_meas), (h, h_meas)]).evalf()
sigma = sigma_V.subs([(r, r_meas), (sigma_r, sigma_r_meas), (h, h_meas), (sigma_h, sigma_h_meas)]).evalf()
display(Latex(f"$$V = ({central_value:0.1f} \pm {sigma:.1f}) \, \mathrm{{cm}}^3$$"))
The same uncertainty can be obtained numerically with automatic differentiation using JAX:
from jax import grad, jacfwd
import jax.numpy as jnp

def error_prop_jax_gen(f, x, dx):
    """Propagate the uncertainties dx of the inputs x through the function f."""
    jac = jacfwd(f)  # forward-mode automatic differentiation: Jacobian of f
    return jnp.sqrt(jnp.sum(jnp.power(jac(x) * dx, 2)))
# volume of a cylinder with (x[0] = radius, x[1] = height)
def f(x):
    return jnp.pi * x[1] * x[0]**2
x = jnp.array([3., 5.])
dx = jnp.array([0.1, 0.1])
print(f"V = {f(x):0.1f} +/- {error_prop_jax_gen(f, x, dx):0.1f} cm**3")
V = 141.4 +/- 9.8 cm**3
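Since f is scalar-valued here, reverse-mode differentiation with jax.grad works just as well; the following variant (the helper name error_prop_jax_grad is just for illustration) gives the same result:

def error_prop_jax_grad(f, x, dx):
    # reverse-mode alternative: gradient of the scalar-valued f at x
    return jnp.sqrt(jnp.sum(jnp.power(grad(f)(x) * dx, 2)))

print(f"V = {f(x):0.1f} +/- {error_prop_jax_grad(f, x, dx):0.1f} cm**3")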