Exercise for the course Python for MATLAB users, by Olivier Verdier

In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

Consider the same code as in the lecture, which computes a Mandelbrot fractal:

In [2]:
def get_grid(nh,nv):
    x = np.linspace(-2,.8,nh)
    y = np.linspace(-1.4,1.4,nv)
    return np.meshgrid(x,y,indexing='ij')
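With indexing='ij', np.meshgrid returns arrays of shape (nh, nv). The first index runs along the real axis and the second along the imaginary axis. This determines which loop bound goes with which index in mandel_py. The snippet below repeats get_grid so that it runs on its own, and prints the shapes for a small grid:

```python
import numpy as np

# Same helper as in the cell above, repeated so this snippet is self-contained.
def get_grid(nh, nv):
    x = np.linspace(-2, .8, nh)
    y = np.linspace(-1.4, 1.4, nv)
    return np.meshgrid(x, y, indexing='ij')

x, y = get_grid(4, 3)
print(x.shape, y.shape)  # (4, 3) (4, 3): axis 0 has length nh, axis 1 has length nv
print(x[:, 0])           # the real parts: x varies along axis 0
print(y[0, :])           # the imaginary parts: y varies along axis 1
```

So c[i, j] is valid for i < nh and j < nv. Passing the result to plt.imshow therefore needs a transpose, so that the real axis ends up horizontal.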
In [3]:
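def mandel_py(w,h,maxit=20):
    # prepare initial points; x, y and c all have shape (w, h)
    x, y = get_grid(w,h)
    c = x+y*1j
    # where to store output: points that never escape keep the value maxit
    output = np.zeros(c.shape, dtype=int) + maxit
    for i in range(w): # loop 1: first axis (real part), length w
        for j in range(h): # loop 2: second axis (imaginary part), length h
            z = 0.
            c0 = c[i,j]
            for k in range(maxit): # loop 3!!
                z = z**2 + c0
                if z.real**2 + z.imag**2 > 4.0: # |z|**2 > 4, i.e. |z| > 2
                    output[i, j] = k
                    break
    # transpose to shape (h, w) so that imshow puts the real axis horizontally
    return output.T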
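Loop 3 computes an escape time. Starting from z_0 = 0 and iterating z_{k+1} = z_k**2 + c, the value stored is the first k for which |z_{k+1}|**2 > 4 (that is, |z_{k+1}| > 2). If no such k < maxit exists, the value is maxit. The helper escape_time below is hypothetical and not part of the course code. It is a minimal sketch that isolates this rule for a single point. It uses z.real**2 + z.imag**2 for |z|**2, which avoids comparing a complex value with a float:

```python
def escape_time(c, maxit=20):
    """Escape-time count for a single point c, mirroring loop 3 of mandel_py.

    Returns the first iteration k at which |z|**2 > 4, or maxit if the
    orbit stays within |z| <= 2 for maxit iterations.
    """
    z = 0j
    for k in range(maxit):
        z = z * z + c
        # squared modulus from real and imaginary parts
        if z.real ** 2 + z.imag ** 2 > 4.0:
            return k
    return maxit

print(escape_time(0))   # 20: the orbit stays at 0
print(escape_time(1))   # 2: z = 1, 2, 5 and |5|**2 > 4
print(escape_time(2))   # 1: z = 2, 6
print(escape_time(-2))  # 20: z = -2, 2, 2, ... reaches |z| = 2 but never exceeds it
```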
In [5]:
plt.imshow(mandel_py(100,100))
In [15]:
%timeit mandel_py(100,100)
1 loops, best of 3: 242 ms per loop
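%timeit is an IPython magic. In a plain Python script, the standard-library timeit module gives a comparable best-of-N measurement. The sketch below uses it to compare an element-by-element Python loop with the equivalent NumPy call. The function python_sum and the array size are illustrative choices only:

```python
import timeit
import numpy as np

a = np.random.rand(10_000)

def python_sum(arr):
    # element-by-element Python loop, in the style of loops 1-3 in mandel_py
    total = 0.0
    for v in arr:
        total += v
    return total

# 3 repetitions of 10 calls each; keep the best repetition, as %timeit reports
best_loop = min(timeit.repeat(lambda: python_sum(a), repeat=3, number=10)) / 10
best_np = min(timeit.repeat(lambda: np.sum(a), repeat=3, number=10)) / 10
print(f"python loop: {best_loop * 1e3:.3f} ms per call")
print(f"np.sum:      {best_np * 1e3:.3f} ms per call")
```

The absolute numbers depend on the machine. The gap between the two lines is the overhead that vectorization is meant to remove.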

Task: optimize the code above using NumPy tricks only (vectorization), i.e. replace loops 1 and 2 over the grid points with whole-array operations.

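Hint: use smart indexing (for example boolean masks), so that each iteration only updates the points that have not escaped yet.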
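As a reminder of how boolean ("smart") indexing works, here is a small example on a 1-D array. A boolean array selects entries, and an assignment through that mask updates only those entries. np.flatnonzero converts the mask to integer indices, which is an alternative way to address the same subset:

```python
import numpy as np

a = np.array([0.5, 3.0, -1.0, 2.5])
counts = np.zeros(a.shape, dtype=int)

mask = np.abs(a) < 2.0     # boolean array: True where the condition holds
a[mask] = a[mask] ** 2     # update only the selected entries
counts[mask] += 1          # bookkeeping on the same subset
idx = np.flatnonzero(mask) # integer indices of the True entries

print(mask)    # [ True False  True False]
print(counts)  # [1 0 1 0]
print(idx)     # [0 2]
```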

In [16]:
def mandel_np(w,h,maxit=20):
    pass # implement this!
In [17]:
%timeit mandel_np(100, 100) # make sure it is faster, and that it returns the same array as mandel_py
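The slowest run took 12.86 times longer than the fastest. This could mean that an intermediate result is being cached
10000000 loops, best of 3: 146 ns per loop

(This timing was recorded with the unimplemented stub above, which returns None immediately, so it is not a meaningful benchmark.)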
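One possible vectorized version follows, offered as a sketch (try the exercise first). It keeps a single Python loop over the iterations k and replaces loops 1 and 2 with whole-array operations. A boolean mask active marks the points that have not escaped yet. Escaped points are frozen, so their z values stop growing and cannot overflow. This is one approach among several. Compressing the working set with integer index arrays is another common variant.

```python
import numpy as np

# Same helper as in the lecture code, repeated so this snippet is self-contained.
def get_grid(nh, nv):
    x = np.linspace(-2, .8, nh)
    y = np.linspace(-1.4, 1.4, nv)
    return np.meshgrid(x, y, indexing='ij')

def mandel_np(w, h, maxit=20):
    """Vectorized escape-time counts, returned with shape (h, w) like mandel_py."""
    x, y = get_grid(w, h)
    c = x + y * 1j
    # points that never escape keep the value maxit
    output = np.full(c.shape, maxit, dtype=int)
    z = np.zeros(c.shape, dtype=complex)
    # points still iterating; initially every grid point
    active = np.ones(c.shape, dtype=bool)
    for k in range(maxit):
        # advance only the points that have not escaped yet
        z[active] = z[active] ** 2 + c[active]
        # points escaping at this step (|z|**2 > 4)
        escaped = active & (z.real ** 2 + z.imag ** 2 > 4.0)
        output[escaped] = k
        active &= ~escaped
        if not active.any():
            break
    return output.T
```

On the same inputs it is expected to return the same array as the loop version. You can check this with np.array_equal(mandel_np(100, 100), mandel_py(100, 100)) before timing it with %timeit.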