import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
Consider the same code as in the lecture, to create a Mandelbrot fractal:
def get_grid(nh, nv, xlim=(-2, .8), ylim=(-1.4, 1.4)):
    """Build coordinate grids sampling a rectangle of the plane.

    Parameters
    ----------
    nh : int
        Number of samples along the horizontal (x) direction.
    nv : int
        Number of samples along the vertical (y) direction.
    xlim : (float, float), optional
        Inclusive (start, stop) bounds of the x samples.
        Defaults to (-2, 0.8).
    ylim : (float, float), optional
        Inclusive (start, stop) bounds of the y samples.
        Defaults to (-1.4, 1.4).

    Returns
    -------
    The pair of arrays produced by ``np.meshgrid(x, y, indexing='ij')``:
    the x grid followed by the y grid, each of shape ``(nh, nv)``.
    With 'ij' indexing, x varies along axis 0 and y along axis 1.
    """
    x = np.linspace(xlim[0], xlim[1], nh)
    y = np.linspace(ylim[0], ylim[1], nv)
    return np.meshgrid(x, y, indexing='ij')
def mandel_py(w, h, maxit=20):
    """Compute Mandelbrot escape times with plain Python loops.

    This is the deliberately un-vectorized reference implementation
    that the timing comparison is made against.

    Parameters
    ----------
    w : int
        Number of grid samples along x (passed to ``get_grid`` as ``nh``).
    h : int
        Number of grid samples along y (passed to ``get_grid`` as ``nv``).
    maxit : int, optional
        Maximum number of iterations of ``z = z**2 + c`` per point.

    Returns
    -------
    numpy.ndarray of int, shape ``(h, w)``
        For each point, the first iteration index ``k`` at which
        ``|z|**2 > 4``, or ``maxit`` if that never happens within
        ``maxit`` iterations.
    """
    # prepare initial points; c has shape (w, h) because get_grid
    # uses 'ij' indexing
    x, y = get_grid(w, h)
    c = x + y * 1j
    # where to store output (maxit marks "did not escape")
    output = np.zeros(c.shape, dtype=int) + maxit
    # Axis 0 of c has length w and axis 1 has length h, so the outer
    # loop must run over w and the inner one over h; the reverse only
    # works when w == h.
    for i in range(w):  # loop 1
        for j in range(h):  # loop 2
            z = 0.
            c0 = c[i, j]
            for k in range(maxit):  # loop 3!!
                z = z**2 + c0
                # Squared modulus computed on real numbers: same value
                # as the real part of z * conj(z), without having to
                # order-compare a complex number with a float.
                if z.real**2 + z.imag**2 > 4.0:
                    output[i, j] = k
                    break
    # transpose so that rows correspond to y and columns to x
    return output.T
# Show the array returned by mandel_py for a 100x100 grid as an image.
plt.imshow(mandel_py(100,100))
<matplotlib.image.AxesImage at 0x106a1af28>
%timeit mandel_py(100,100)
10 loops, best of 3: 179 ms per loop
Task: optimize the code above using NumPy tricks only (vectorization).
Hint: use smart indexing
def mandel_np(w, h, maxit=20):
    """Compute Mandelbrot escape times using vectorized NumPy operations.

    Uses the same escape-time rule as the loop-based ``mandel_py``: each
    point records the first iteration index ``k`` at which
    ``|z|**2 > 4``, and keeps ``maxit`` if that never happens.

    Parameters
    ----------
    w : int
        Number of grid samples along x (passed to ``get_grid`` as ``nh``).
    h : int
        Number of grid samples along y (passed to ``get_grid`` as ``nv``).
    maxit : int, optional
        Maximum number of iterations of ``z = z**2 + c``.

    Returns
    -------
    numpy.ndarray of int, shape ``(h, w)``
        Escape iteration per grid point (``maxit`` where no escape).
    """
    # prepare initial points
    x, y = get_grid(w, h)
    c = x + y * 1j
    # where to store output (maxit marks "did not escape")
    output = np.full(c.shape, maxit, dtype=int)
    z = np.zeros_like(c)
    # Points that have not escaped yet. Only these are iterated, which
    # mirrors the `break` of the scalar version and keeps escaped
    # values from growing until they overflow to inf/nan.
    active = np.ones(c.shape, dtype=bool)
    for k in range(maxit):
        z[active] = z[active]**2 + c[active]
        # Squared modulus on real arrays; escaped points keep their
        # frozen z, so `active &` restricts the test to live points.
        escaped = active & (z.real**2 + z.imag**2 > 4.0)
        # Record the FIRST iteration at which each point escapes.
        output[escaped] = k
        active &= ~escaped
        if not active.any():
            # every point has escaped; nothing left to iterate
            break
    # transpose so that rows correspond to y and columns to x
    return output.T
%timeit mandel_np(100, 100) # make sure it is faster
100 loops, best of 3: 2.81 ms per loop
/Users/olivier/anaconda/envs/python3/lib/python3.4/site-packages/ipykernel/__main__.py:10: RuntimeWarning: overflow encountered in multiply
/Users/olivier/anaconda/envs/python3/lib/python3.4/site-packages/ipykernel/__main__.py:10: RuntimeWarning: invalid value encountered in multiply
/Users/olivier/anaconda/envs/python3/lib/python3.4/site-packages/ipykernel/__main__.py:9: RuntimeWarning: overflow encountered in square
/Users/olivier/anaconda/envs/python3/lib/python3.4/site-packages/ipykernel/__main__.py:9: RuntimeWarning: invalid value encountered in square