Exercise for the course Python for MATLAB users, by Olivier Verdier

In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

Consider the same code as in the lecture, which creates a Mandelbrot fractal:

In [2]:
def get_grid(nh,nv):
    x = np.linspace(-2,.8,nh)
    y = np.linspace(-1.4,1.4,nv)
    return np.meshgrid(x,y,indexing='ij')
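
To see what `indexing='ij'` does in `get_grid`, a small shape check may help. The grid sizes 4 and 3 below are arbitrary values chosen for illustration only.

```python
import numpy as np

x = np.linspace(-2, .8, 4)     # 4 points along the real axis
y = np.linspace(-1.4, 1.4, 3)  # 3 points along the imaginary axis

# With indexing='ij' the first axis follows x and the second follows y
X, Y = np.meshgrid(x, y, indexing='ij')
print(X.shape, Y.shape)  # (4, 3) (4, 3)

# X changes along axis 0 only, Y changes along axis 1 only
print(np.all(X[:, 0] == x), np.all(Y[0, :] == y))  # True True
```

Since the first axis follows x, an array built on this grid has to be transposed before plotting with `plt.imshow`, which draws the first axis vertically.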
In [3]:
def mandel_py(w,h,maxit=20):
    # prepare initial points
    x, y = get_grid(w,h)
    c = x+y*1j
    # where to store output
    output = np.zeros(c.shape, dtype=int) + maxit
    for i in range(w): # loop 1 (first axis of c has length w)
        for j in range(h): # loop 2 (second axis of c has length h)
            z = 0.
            c0 = c[i,j]
            for k in range(maxit): # loop 3!!
                z = z**2 + c0
                if (z*z.conjugate()).real > 4.0:
                    output[i, j] = k
                    break
    return output.T
In [4]:
plt.imshow(mandel_py(100,100))
Out[4]:
<matplotlib.image.AxesImage at 0x106a1af28>
In [5]:
%timeit mandel_py(100,100)
10 loops, best of 3: 179 ms per loop

Task: optimize the code above using NumPy tricks only (vectorization).

Hint: use smart indexing, for instance indexing an array with a boolean mask.
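
As a reminder of how indexing with a boolean mask works, here is a minimal sketch on a toy array. The names `a` and `count` and the threshold 2.0 are made up for this illustration.

```python
import numpy as np

a = np.array([0.5, 3.0, 1.5, 7.0])
count = np.zeros(a.shape, dtype=int)

mask = a > 2.0          # boolean array, True where the condition holds
print(mask)             # [False  True False  True]

count[mask] = 1         # assign only where the mask is True
print(count)            # [0 1 0 1]

a[mask] = a[mask] ** 2  # update only the selected entries
# a now holds 0.5, 9.0, 1.5, 49.0
```

The same pattern replaces an `if` inside a loop: the condition is evaluated for all entries at once, and the assignment touches only the entries that satisfy it.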

In [6]:
def mandel_np(w,h,maxit=20):
    # prepare initial points
    x, y = get_grid(w,h)
    c = x+y*1j
    # where to store output
    output = np.zeros_like(c, dtype=int) + maxit
    z = np.zeros_like(c)
    for k in range(maxit):
        z = z**2 + c
        mask = ((z*z.conjugate()).real > 4.0) & (output == maxit) # points escaping for the first time
        output[mask] = k
    return output.T
In [7]:
%timeit mandel_np(100, 100) # make sure it is faster
100 loops, best of 3: 2.81 ms per loop
/Users/olivier/anaconda/envs/python3/lib/python3.4/site-packages/ipykernel/__main__.py:10: RuntimeWarning: overflow encountered in multiply
/Users/olivier/anaconda/envs/python3/lib/python3.4/site-packages/ipykernel/__main__.py:10: RuntimeWarning: invalid value encountered in multiply
/Users/olivier/anaconda/envs/python3/lib/python3.4/site-packages/ipykernel/__main__.py:9: RuntimeWarning: overflow encountered in square
/Users/olivier/anaconda/envs/python3/lib/python3.4/site-packages/ipykernel/__main__.py:9: RuntimeWarning: invalid value encountered in square
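
The warnings above appear because points that have already escaped keep being squared at every iteration until they overflow to `inf` and then `nan`. One possible way to avoid this, sketched below, is to iterate only the points that are still active. The name `mandel_masked` and the grid size 60 by 40 are assumptions made for this illustration, and the escape rule (first iteration with squared modulus above 4) is meant to follow `mandel_py`.

```python
import numpy as np

def get_grid(nh, nv):
    x = np.linspace(-2, .8, nh)
    y = np.linspace(-1.4, 1.4, nv)
    return np.meshgrid(x, y, indexing='ij')

def mandel_masked(w, h, maxit=20):
    # hypothetical variant: only points that have not escaped are updated
    x, y = get_grid(w, h)
    c = x + y*1j
    output = np.zeros(c.shape, dtype=int) + maxit
    z = np.zeros_like(c)
    active = np.ones(c.shape, dtype=bool)      # points still being iterated
    for k in range(maxit):
        z[active] = z[active]**2 + c[active]   # escaped points stay frozen
        escaped = active & (z.real**2 + z.imag**2 > 4.0)
        output[escaped] = k                    # record the first escape iteration
        active &= ~escaped                     # stop iterating escaped points
    return output.T

# Turn overflow and invalid-value warnings into errors to check that none occur
with np.errstate(over='raise', invalid='raise'):
    result = mandel_masked(60, 40)

print(result.shape)                # (40, 60)
print(result.min(), result.max())  # 0 20
```

The indexing with `active` adds some overhead per iteration, so whether this variant is faster than `mandel_np` depends on the grid size and on `maxit`; it is worth timing both with `%timeit`.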