This is one of the 100 recipes of the IPython Cookbook, the definitive guide to high-performance scientific computing and data science in Python.
Stride tricks can be useful for local computations on arrays, when the value computed at a given position depends on the neighboring values. Examples include dynamical systems, filters, cellular automata, and so on. In this example, we will implement an efficient rolling average (a particular type of convolution-based linear filter) with NumPy stride tricks.
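For a signal x_0, ..., x_{n-1} and a window of length k, the rolling average computed in this recipe can be written as follows. Indices start at 0 to match NumPy.

```latex
\bar{x}_i = \frac{1}{k} \sum_{j=0}^{k-1} x_{i+j}, \qquad i = 0, 1, \ldots, n - k
```

This gives n - k + 1 averaged values. Row j of the shifted array built below holds x_j, ..., x_{j+n-k}, so averaging over the rows (axis 0) evaluates this sum for every i at once.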
The idea is to start from a 1D vector and make a "virtual" 2D array where each row is a shifted version of the previous row. With stride tricks, this process does not copy any data, so it is efficient.
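Here is a minimal sketch of the idea on a small array. It assumes a contiguous float64 input, so the stride between elements is 8 bytes, and the name `view` is only illustrative. Using the same stride on both axes means element (i, j) of the view is a[i + j]. Writing into `a` afterwards shows that the view reads the original buffer rather than a copy.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(1.0, 6.0)  # [1., 2., 3., 4., 5.] as float64
k = 2                    # number of shifted rows
s = a.strides[0]         # bytes from one element to the next (8 here)

# Equal strides on both axes: element (i, j) sits i*s + j*s bytes
# past the start of a, so view[i, j] is a[i + j].
view = as_strided(a, shape=(k, a.size - k + 1), strides=(s, s))
print(view.strides)  # (8, 8)

# Writing through a shows up in every row that overlaps that element,
# because the view reads the same buffer instead of a copy.
a[2] = 30.0
print(view)
```

After the assignment, the first row reads 1, 2, 30, 4 and the second row reads 2, 30, 4, 5.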
import numpy as np
from numpy.lib.stride_tricks import as_strided
%precision 0
def id(x):
    # This function returns the memory
    # block address of an array.
    # Note: this shadows Python's built-in id().
    return x.__array_interface__['data'][0]
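The `'data'` entry of `__array_interface__` is a `(pointer, read_only)` tuple, so the pointer is its first item. The sketch below uses a hypothetical helper, `data_address`, named so that it does not shadow the built-in `id`. It shows that a sliced view starts exactly `itemsize` bytes into the parent buffer, while a copy gets a separate buffer.

```python
import numpy as np

def data_address(x):
    # Hypothetical helper: __array_interface__['data'] is a
    # (pointer, read_only) tuple; keep only the pointer.
    return x.__array_interface__['data'][0]

a = np.linspace(1, 5, 5)  # float64, contiguous
tail = a[1:]              # basic slicing returns a view

# The view starts one element (a.itemsize bytes) into a's buffer.
offset = data_address(tail) - data_address(a)
print(offset, a.itemsize)  # 8 8

# A copy gets its own, separate buffer.
copied = a.copy()
print(data_address(copied) == data_address(a))  # False
```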
n = 5; k = 2
a = np.linspace(1, n, n); aid = id(a)
Let's change the strides of the array a to add shifted rows.
as_strided(a, (k, n), (a.itemsize, a.itemsize))
id(a), id(as_strided(a, (k, n)))
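Comparing start addresses only tells you whether two arrays begin at the same byte. `np.shares_memory` instead checks whether their memory regions overlap at all, which also catches views such as `a[1:]` that start at a different offset. This is a small sketch using the same toy array.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.linspace(1, 5, 5)
k = 2
s = a.strides[0]
view = as_strided(a, (k, a.size - k + 1), (s, s))
stacked = np.vstack([a[:-1], a[1:]])  # same values, new buffer

# np.shares_memory reports overlap of the underlying memory regions.
print(np.shares_memory(a, view))     # True
print(np.shares_memory(a, stacked))  # False
print(np.array_equal(view, stacked))  # True

# A view with a different start address still overlaps a.
print(np.shares_memory(a, a[1:]))  # True
```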
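The last value in the second row of the strided array computed above indicates an out-of-bounds problem: that row ends one element past the end of a, so its last value is whatever happens to be in memory there. Stride tricks can be dangerous because memory access is not checked. Here, we should take edge effects into account by limiting the shape of the array.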
as_strided(a, (k, n - k + 1), (a.itemsize,)*2)
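The shape (k, n - k + 1) is safe because the largest flat index the view touches is (k - 1) + (n - k) = n - 1, the last element of a. If NumPy 1.20 or later is available, `numpy.lib.stride_tricks.sliding_window_view` builds the same kind of zero-copy view while validating the window size. It returns one window per row, with shape (n - k + 1, k), so this sketch transposes it to match the layout used in this recipe.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

a = np.linspace(1, 5, 5)
k = 2

# Each row of the result is one window of length k: shape (n - k + 1, k).
windows = sliding_window_view(a, k)

# Transposing gives the (k, n - k + 1) layout used in this recipe,
# still without copying data.
shifted = windows.T
print(shifted.shape)                 # (2, 4)
print(np.shares_memory(a, shifted))  # True

# A window longer than the array is rejected instead of reading stray memory.
try:
    sliding_window_view(a, a.size + 1)
except ValueError as err:
    print("rejected:", err)
```

By default the returned view is read-only, which also guards against accidental writes through overlapping windows.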
Let's apply this technique to calculate the rolling average of a noisy, increasing signal.
First version using array copies.
def shift1(x, k):
    # Stack k overlapping slices of x into a new (k, n - k + 1) array.
    n = x.size
    return np.vstack([x[i:n-k+i+1] for i in range(k)])
Second version using stride tricks.
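def shift2(x, k):
    # Zero-copy (k, n - k + 1) view of x; x.strides[0] is the byte step between elements.
    s = x.strides[0]
    return as_strided(x, (k, x.size - k + 1), (s, s))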
b = shift1(a, k); b, id(b) == aid
c = shift2(a, k); c, id(c) == aid
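Both versions hold the same values. One way to see which one actually allocated memory is the `owndata` flag. `nbytes` is less useful here because it reports the logical size of an array, even for a view. This sketch inlines both constructions for n = 100 and k = 10, assuming float64 data. The copy pays for 10 × 91 × 8 = 7280 bytes, while the view reuses the 800-byte buffer of x.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

n, k = 100, 10
x = np.random.randn(n)  # float64

# Copy-based: k slices stacked into a new (k, n - k + 1) array.
b = np.vstack([x[i:n - k + i + 1] for i in range(k)])

# Stride-based: a view on x with the same shape and values.
s = x.strides[0]
c = as_strided(x, (k, n - k + 1), (s, s))

print(np.array_equal(b, c))              # True: identical contents
print(b.nbytes, c.nbytes)                # 7280 7280: logical size only
print(b.flags.owndata, c.flags.owndata)  # True False
```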
Let's generate a signal.
n, k = 100, 10
t = np.linspace(0., 1., n)
x = t + .1 * np.random.randn(n)
We compute the rolling average of the signal by creating the shifted versions of the signal and averaging along the vertical axis (axis 0).
y = shift2(x, k)
x_avg = y.mean(axis=0)
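Since a rolling average is a convolution with a flat kernel of length k, `np.convolve` with `mode='valid'` should produce the same n - k + 1 values. The kernel is symmetric, so the kernel flip that convolution performs does not matter. This sketch uses the strided mean inline and compares the two results.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

n, k = 100, 10
t = np.linspace(0., 1., n)
x = t + .1 * np.random.randn(n)

# Strided rolling average: mean over the k shifted rows.
s = x.strides[0]
x_avg = as_strided(x, (k, n - k + 1), (s, s)).mean(axis=0)

# Same filter as a convolution; 'valid' keeps only the full windows.
x_conv = np.convolve(x, np.ones(k) / k, mode='valid')

print(x_avg.shape, x_conv.shape)   # (91,) (91,)
print(np.allclose(x_avg, x_conv))  # True
```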
Let's plot the signal and its averaged version.
import matplotlib.pyplot as plt
f = plt.figure()
plt.plot(x[:-k+1], '-k');
plt.plot(x_avg, '-r');
Let's benchmark the first version (creation of the shifted array, and computation of the mean), which involves an array copy.
%timeit shift1(x, k)
%%timeit
y = shift1(x, k)
z = y.mean(axis=0)
And the second version, using stride tricks.
%timeit shift2(x, k)
%%timeit
y = shift2(x, k)
z = y.mean(axis=0)
In the first version, most of the time is spent copying the array, whereas in the stride-trick version, most of the time is spent computing the average.
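Outside IPython, the %timeit and %%timeit cells above can be approximated with the standard-library `timeit` module. In this sketch, the hypothetical names `shift_copy` and `shift_view` stand in for shift1 and shift2. It reports per-call times in microseconds. Absolute numbers depend on the machine and the NumPy version.

```python
import timeit
import numpy as np
from numpy.lib.stride_tricks import as_strided

n, k = 100, 10
x = np.linspace(0., 1., n) + .1 * np.random.randn(n)

def shift_copy(x, k):
    # Copies k overlapping slices into a new array.
    m = x.size - k + 1
    return np.vstack([x[i:i + m] for i in range(k)])

def shift_view(x, k):
    # Builds a (k, n - k + 1) view on x; no data is copied.
    s = x.strides[0]
    return as_strided(x, (k, x.size - k + 1), (s, s))

number = 2000
for name, fn in [("copy", shift_copy), ("view", shift_view)]:
    build = timeit.timeit(lambda: fn(x, k), number=number)
    full = timeit.timeit(lambda: fn(x, k).mean(axis=0), number=number)
    print(f"{name}: build {build / number * 1e6:.1f} us, "
          f"build+mean {full / number * 1e6:.1f} us")
```

Because `shift_view` only creates new array metadata, its build time should stay roughly flat as k grows. The build time of `shift_copy` grows with the k × (n - k + 1) elements it copies.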