I want to implement and illustrate the Runge-Kutta method (actually, several variants of it) in the Julia programming language.
The Runge-Kutta methods are a family of iterative numerical algorithms that approximate solutions of ordinary differential equations (ODEs). I will simply implement them; for the mathematical descriptions, I refer the interested reader to the Wikipedia page, or to any good book or course on the numerical integration of ODEs.
versioninfo()
For comparison, let's use the mature and fully featured package `DifferentialEquations`, which provides a `solve` function to numerically integrate ordinary differential equations, and the `Plots` package with the GR backend for plotting:
# If needed:
#using Pkg
#Pkg.add("DifferentialEquations")
#Pkg.add("Plots")
using Plots
gr()
using DifferentialEquations
As a first example, I will use the one included in the scipy (Python) documentation for its `odeint` function: the motion of a damped pendulum, $$\theta''(t) + b \theta'(t) + c \sin(\theta(t)) = 0.$$
If $\omega(t) := \theta'(t)$, this gives $$ \begin{cases} \theta'(t) = \omega(t) \\ \omega'(t) = -b \omega(t) - c \sin(\theta(t)) \end{cases} $$
Vectorially, if $y(t) = [\theta(t), \omega(t)]$, then the equation is $y' = f(y, t)$ where $f(y, t) = [y_2, -b y_2 - c \sin(y_1)]$. We write the arguments in the order $(y, t)$, as `odeint` and the solvers below do.
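As a quick sanity check of the vector form, we can evaluate $f$ at the starting point used below, $y_0 = [\pi - 0.1, 0]$: the first component is $\omega = 0$, and the second is $-c \sin(\pi - 0.1) = -c \sin(0.1)$. A minimal sketch, assuming the values $b = 0.25$ and $c = 5$ chosen below; the name `f` is hypothetical (the notebook's own version is `f_pend`):

```julia
# Assumed pendulum parameters (same values as chosen below)
b = 0.25
c = 5.0

# Hypothetical right-hand side f(y, t) of y' = f(y, t), with y = [θ, ω]
f(y, t) = [y[2], -b * y[2] - c * sin(y[1])]

y0 = [pi - 0.1, 0.0]
dy0 = f(y0, 0.0)
println(dy0)   # ≈ [0.0, -0.4992], i.e. [0, -5 sin(0.1)]
```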
We assume the values of $b$ and $c$ to be known, and the starting point to be also fixed:
b = 0.25
c = 5.0
y0 = [pi - 0.1; 0.0]
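# In-place right-hand side for DifferentialEquations: writes y' into dy
# (recent versions expect the argument order (dy, y, p, t), p being parameters)
function pend(dy, y, p, t)
    dy[1] = y[2]
    dy[2] = (-b * y[2]) - (c * sin(y[1]))
    return nothing
end
# Out-of-place version f(y, t) returning y', used by our own solvers
function f_pend(y, t)
    return [y[2], (-b * y[2]) - (c * sin(y[1]))]
end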
The `solve` function from `DifferentialEquations` will be used to solve this ODE on the interval $t \in [0, 10]$.
tspan = (0.0, 10.0)
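Here is how it is used. We wrap it in `odeint_1`, which returns the time points chosen by `solve` and a matrix with one row per time point. Our own implementations will follow a similar convention: they take `(f, y0, t)`, where `t` is a vector of time points, and return such a matrix.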
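# Solve the ODE given by the DifferentialEquations-style function f on tspan;
# returns the time points and a matrix whose row i is the state at time t[i]
function odeint_1(f, y0, tspan)
    prob = ODEProblem(f, y0, tspan)
    sol = solve(prob)
    return sol.t, hcat(sol.u...)'
end
# Same, but returns only the solution matrix
function odeint(f, y0, tspan)
    t, sol = odeint_1(f, y0, tspan)
    return sol
end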
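`solve` stores the states in `sol.u`, a vector of state vectors (one per time step). The last line of `odeint_1` stacks them as columns with `hcat(sol.u...)` and transposes the result, so that row `i` holds the state at time `sol.t[i]`. A small self-contained illustration, where `u` is a hypothetical stand-in for `sol.u` with made-up values:

```julia
# Hypothetical stand-in for sol.u: three time steps, two state components each
u = [[1.0, 0.0], [0.9, -0.5], [0.7, -0.8]]

# Stack the vectors as columns (2×3), then transpose (3×2):
# row i is the state at the i-th time step
M = hcat(u...)'

println(size(M))   # (3, 2)
println(M[:, 1])   # first component (θ) at every time step
```

This row-per-time-step layout is what makes `sol[:, 1]` and `sol[:, 2]` in the plotting calls select $\theta$ and $\omega$.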
t, sol = odeint_1(pend, y0, tspan)
plot(t, sol[:, 1], xaxis="Time t", title="Solution to the pendulum ODE", label="\\theta (t)")
plot!(t, sol[:, 2], label="\\omega (t)")
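The simplest Runge-Kutta method, of order 1, is the explicit Euler method. Given time points $t_1 < t_2 < \dots$, the approximation is computed using this update: $$y_{n+1} = y_n + (t_{n+1} - t_n) f(y_n, t_n).$$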
The reasoning behind this formula is the following: if $g$ is a solution to the ODE and the approximation is correct so far, $y_n \simeq g(t_n)$, then for a small step $h = t_{n+1} - t_n$, a first-order Taylor expansion gives $g(t_n + h) \simeq g(t_n) + h g'(t_n) = g(t_n) + h f(g(t_n), t_n) \simeq y_n + h f(y_n, t_n)$.
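To see the size of the error made by a single step, take the scalar test equation $y' = y$, $y(0) = 1$, whose exact solution is $e^t$ (this test problem is an assumption for illustration, not part of the pendulum example). One Euler step of size $h = 0.1$ gives $y_1 = 1 + 0.1 \cdot 1 = 1.1$, while $e^{0.1} \approx 1.10517$: the local error is close to $h^2/2 = 0.005$, the first Taylor term that the update neglects. A minimal sketch:

```julia
# Assumed test problem y' = y, y(0) = 1, exact solution exp(t)
f(y, t) = y

h = 0.1
y0 = 1.0
t0 = 0.0

# One explicit Euler step: y1 = y0 + h f(y0, t0)
y1 = y0 + h * f(y0, t0)

err = abs(exp(h) - y1)
println(y1)    # 1.1
println(err)   # close to h^2 / 2 = 0.005
```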
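# Explicit Euler method (Runge-Kutta of order 1)
# f(y, t) returns y'; t is a vector of time points; row i of the result approximates y(t[i])
function rungekutta1(f, y0, t)
    n = length(t)
    y = zeros((n, length(y0)))
    y[1,:] = y0
    for i in 1:n-1
        h = t[i+1] - t[i]
        # Step along the slope at the current point
        y[i+1,:] = y[i,:] + h * f(y[i,:], t[i])
    end
    return y
end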
t = range(0, stop=10, length=101);
sol = rungekutta1(f_pend, y0, t);
plot(t, sol[:, 1], xaxis="Time t", title="Solution to the pendulum ODE with Runge-Kutta 1", label="\\theta (t)")
plot!(t, sol[:, 2], label="\\omega (t)")
With 101 points, the Euler method (i.e. the Runge-Kutta method of order 1) is less precise than the reference `solve` method. With more points, it gives a satisfactory approximation of the solution:
t2 = range(0, stop=10, length=1001);
sol2 = rungekutta1(f_pend, y0, t2);
plot(t2, sol2[:, 1], xaxis="Time t", title="Solution to the pendulum ODE with Runge-Kutta 1", label="\\theta (t)")
plot!(t2, sol2[:, 2], label="\\omega (t)")
t3 = range(0, stop=10, length=2001);
sol3 = rungekutta1(f_pend, y0, t3);
plot(t3, sol3[:, 1], xaxis="Time t", title="Solution to the pendulum ODE with Runge-Kutta 1", label="\\theta (t)")
plot!(t3, sol3[:, 2], label="\\omega (t)")
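Since the pendulum has no simple closed-form solution, one way to quantify "more points is better" is to run the same `rungekutta1` on a test problem whose solution is known. Here we assume $y' = -y$, $y(0) = 1$ on $[0, 10]$ (exact solution $e^{-t}$) and measure the largest error on the grid for the same numbers of points as above. For an order-1 method, multiplying the number of steps by 10 should divide the error by roughly 10. A minimal sketch; the helpers `f_decay` and `max_error_rk1` are hypothetical names:

```julia
# Same explicit Euler solver as above (row i of y approximates the state at t[i])
function rungekutta1(f, y0, t)
    n = length(t)
    y = zeros((n, length(y0)))
    y[1, :] = y0
    for i in 1:n-1
        h = t[i+1] - t[i]
        y[i+1, :] = y[i, :] + h * f(y[i, :], t[i])
    end
    return y
end

# Assumed test problem y' = -y, y(0) = 1, exact solution exp(-t)
f_decay(y, t) = -y

# Largest absolute error over n equally spaced points on [0, 10]
function max_error_rk1(n)
    t = range(0, stop=10, length=n)
    y = rungekutta1(f_decay, [1.0], t)
    return maximum(abs.(y[:, 1] .- exp.(-t)))
end

errors = [max_error_rk1(n) for n in (101, 1001, 2001)]
println(errors)
println(errors[1] / errors[2])   # roughly 10 for a first-order method
```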
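The order 2 Runge-Kutta method used here (the midpoint method) uses this update: $$ y_{n+1} = y_n + h f\left(y_n + \frac{h}{2} f(y_n, t_n), t_n + \frac{h}{2}\right),$$ where $h = t_{n+1} - t_n$. It first takes a half Euler step to estimate the solution at the middle of the interval, then uses the slope there for the full step.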
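On the assumed test problem $y' = y$, $y(0) = 1$ (exact solution $e^t$), one midpoint step of size $h = 0.1$ gives $y_1 = 1 + 0.1\,(1 + 0.05) = 1.105$, much closer to $e^{0.1} \approx 1.1051709$ than the Euler value $1.1$: the local error drops from about $h^2/2$ to about $h^3/6$. A minimal sketch comparing the two single steps:

```julia
# Assumed test problem y' = y, y(0) = 1, exact solution exp(t)
f(y, t) = y

h = 0.1
y0 = 1.0
t0 = 0.0

# One Euler step (order 1)
y_euler = y0 + h * f(y0, t0)

# One midpoint step (order 2): evaluate f at the Euler half-step
y_mid = y0 + h * f(y0 + h / 2 * f(y0, t0), t0 + h / 2)

err_euler = abs(exp(h) - y_euler)   # about h^2 / 2
err_mid = abs(exp(h) - y_mid)       # about h^3 / 6
println((err_euler, err_mid))
```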
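# Midpoint method (Runge-Kutta of order 2), same signature as rungekutta1
function rungekutta2(f, y0, t)
    n = length(t)
    y = zeros((n, length(y0)))
    y[1,:] = y0
    for i in 1:n-1
        h = t[i+1] - t[i]
        # Half Euler step to the midpoint, then a full step with the midpoint slope
        y[i+1,:] = y[i,:] + h * f(y[i,:] + f(y[i,:], t[i]) * h/2, t[i] + h/2)
    end
    return y
end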
For our simple ODE example, this method is already quite efficient, even with only 21 points:
t3 = range(0, stop=10, length=21);
sol3 = rungekutta2(f_pend, y0, t3);
plot(t3, sol3[:, 1], xaxis="Time t", title="Solution to the pendulum ODE with Runge-Kutta 2 (21 points)", label="\\theta (t)")
plot!(t3, sol3[:, 2], label="\\omega (t)")
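To put numbers on this, the same kind of test as for Euler can be repeated: on the assumed test problem $y' = -y$, $y(0) = 1$ over $[0, 10]$ (exact solution $e^{-t}$), compare the largest grid error of `rungekutta1` and `rungekutta2` with 21 and 41 points. Doubling the number of steps should divide the error by about 2 for order 1 and about 4 for order 2, at least as the step size shrinks. A minimal sketch with both solvers copied from above; `f_decay` and `max_error` are hypothetical helper names:

```julia
# Explicit Euler (order 1), as above
function rungekutta1(f, y0, t)
    n = length(t)
    y = zeros((n, length(y0)))
    y[1, :] = y0
    for i in 1:n-1
        h = t[i+1] - t[i]
        y[i+1, :] = y[i, :] + h * f(y[i, :], t[i])
    end
    return y
end

# Midpoint Runge-Kutta method (order 2), as above
function rungekutta2(f, y0, t)
    n = length(t)
    y = zeros((n, length(y0)))
    y[1, :] = y0
    for i in 1:n-1
        h = t[i+1] - t[i]
        y[i+1, :] = y[i, :] + h * f(y[i, :] + f(y[i, :], t[i]) * h/2, t[i] + h/2)
    end
    return y
end

# Assumed test problem y' = -y, y(0) = 1, exact solution exp(-t)
f_decay(y, t) = -y

# Largest absolute error of `method` on n equally spaced points on [0, 10]
function max_error(method, n)
    t = range(0, stop=10, length=n)
    y = method(f_decay, [1.0], t)
    return maximum(abs.(y[:, 1] .- exp.(-t)))
end

# Ratios should approach 2 (order 1) and 4 (order 2) as the step size shrinks
for method in (rungekutta1, rungekutta2)
    e21, e41 = max_error(method, 21), max_error(method, 41)
    println(method, ": ", e21, " (21 points), ", e41, " (41 points), ratio ", e21 / e41)
end
```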