Faster data processing in Python

by S Anand (@sanand0)

Working with data in Python requires making a number of choices.

  • What format should I save data in? CSV? JSON? Pickle?
  • What library should I read data with?
  • How should I parse dates?
  • How to make my program restartable?
  • How do I run a function in parallel?

Let's discuss how to make these choices with the aim of running code faster.

Code at https://github.com/sanand0/ipython-notebooks

Visualising the elections

This is the story of the back-end optimisations we made for the CNN-IBN - Microsoft election visualisations in May 2014.

In [1]:
from IPython.display import YouTubeVideo
In [2]:
YouTubeVideo('bMicxDcxefs', width=853, height=480, start=30)
Out[2]:

Let's start with the raw data

The first step in data processing is to load the data. We'd scraped the results of assembly elections to see which constituency had the largest number of candidates, votes, etc. This is what the raw data file looks like:

In [3]:
assembly_results = 'D:/site/gramener.com/viz/election/assembly.csv'
print open(assembly_results).read(350)
ST_NAME,YEAR,AC_NO,#,AC_NAME,AC_TYPE,NAME,SEX,AGE,CATEGORY,PARTY,VOTES
Andhra Pradesh,1955,1,1,ICHCHAPURAM,GEN,UPPADA RANGABABU,M,,,KLP,14565
Andhra Pradesh,1955,1,2,ICHCHAPURAM,GEN,HARIHARA PATNAIK,M,,,IND,7408
Andhra Pradesh,1955,1,3,ICHCHAPURAM,GEN,PUDI LOKANADHAM,M,,,IND,6508
Andhra Pradesh,1955,1,4,ICHCHAPURAM,GEN,KALLA BALARAMA SWAMY,M,,,IND,

Who got the most votes?

An "entrance exam" question we have at Gramener is to ask candidates to load a CSV file and sort it in descending order of a given column. For example, when told "Write a program to find out who got the most votes in Bangalore South, ever", here's the most common response we get:

In [4]:
data = []                                       # Store results
row = 0                                         # Row number in data file
for line in open(assembly_results):             # Loop through each line
    row += 1                                    # Increment row counter
    if row > 1:                                 # Ignore the first line
        csv_row = line.split(',')               #   Split by commas
        if csv_row[4] == 'BANGALORE SOUTH':     #   AC_NAME is in 5th column
            name = csv_row[6]                   #   name is the 7th column
            votes = int(csv_row[11].strip())    #   votes (12th) may have trailing \n
            data.append([name, votes])          # 6 = NAME, 11 = VOTES

import operator
sorted(data, key=operator.itemgetter(1), reverse=True)    # Sort in descending order of votes
Out[4]:
[['M. KRISHNAPPA', 102207],
 ['R. PRABHAKARA REDDY', 72045],
 ['M KRISHNAPPA', 71114],
 ['DR.TEJASWINI GOWDA', 63849],
 ['SADANANDA M', 36979],
 ['C MANJUNATH', 33529],
 ['H P RAJAGOPALA REDDY', 17726],
 ['D. MUNICHINNAPPA', 17441],
 ['A. V. NARASIMHA REDDY', 13702],
 ['R. RANGAPPA REDDY', 13452],
 ['B. BASAVALINGAPPA (SC)', 12365],
 ['B. BASAVALINGAPPA', 11540],
 ['B. T. KEMPA RAJ (SC)', 5449],
 ['C. BASAVIAH', 5333],
 ['JAGADISH REDDY', 3936],
 ['SHIVARUDRAPPA', 2610],
 ['K. CHIKKANNA', 2561],
 ['KRISHNAMA RAJU', 2509],
 ['R MANJUNATH', 2300],
 ['P. S. CHINNAPPA', 2229],
 ['MURALI MOHAN', 1975],
 ['A SOMASHEKAR', 1755],
 ['M. NARAYANAPPA', 1494],
 ['N S RAVICHANDRA', 1494],
 ['VASANTH', 1363],
 ['G.A. GREGORY', 1347],
 ['E KRISHNAPPA', 1040],
 ['M KRISHNAPPA', 990],
 ['LOKESH M R', 771],
 ['A. CHOWARAPPA', 691],
 ['T.N KAMAL', 639],
 ['N.S. RAVICHANDRA', 509],
 ['JAYARAMA Y', 410],
 ['K GURURAJ', 407],
 ['ASHISH KAPUR', 394],
 ['SUNDARAPPA', 335]]

Before you optimise, time it

But let's see how we can make the program faster. The first step towards that is to see how long it takes.

In [5]:
import operator

def most_votes(ac_name='BANGALORE SOUTH'):
    data = []
    row = 0
    for line in open(assembly_results):
        row += 1
        if row > 1:
            csv_row = line.split(',')
            if csv_row[4] == ac_name:
                name = csv_row[6]
                votes = int(csv_row[11].strip())
                data.append([name, votes])

    return sorted(data, key=operator.itemgetter(1), reverse=True)

%timeit most_votes('BANGALORE SOUTH')               # %timeit is IPython's timing magic
1 loops, best of 3: 387 ms per loop

Is it fast enough?

That's a fairly slow function. But what's taking up so much time? Can we narrow down to the line that takes up the most time?

At this time, I'd like to quote Calvin.

Do I even care?

Premature optimisation is the root of all evil. Make sure you get the functionality right first. Then ask the question, "Do you even care that it's slow?" Only if the answer is yes do we proceed.

At this point, find and optimise the slowest part. Only the slowest part.

Find the slowest part

For this, we'll use the line_profiler module.

pip install line_profiler

You'll need some more setup to make this an extension. See http://pynash.org/2013/03/06/timing-and-profiling.html for details.

Once set up, you can load it like this:

In [6]:
%load_ext line_profiler

%lprun -f most_votes most_votes()

%lprun displays its results in a separate pane of the IPython Notebook, so let's write a small wrapper that prints the profile as regular cell output.

In [7]:
import line_profiler

def lprun(func, *args, **kwargs):
    profile = line_profiler.LineProfiler(func)
    profile.runcall(func, *args, **kwargs)
    profile.print_stats()
In [8]:
lprun(most_votes)
Timer unit: 4.66512e-07 s

Total time: 1.98855 s
File: <ipython-input-5-299e00e75fb3>
Function: most_votes at line 3

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           def most_votes(ac_name='BANGALORE SOUTH'):
     4         1            6      6.0      0.0      data = []
     5         1            2      2.0      0.0      row = 0
     6    398704       866654      2.2     20.3      for line in open(assembly_results):
     7    398703       694258      1.7     16.3          row += 1
     8    398703       623570      1.6     14.6          if row > 1:
     9    398702      1317660      3.3     30.9              csv_row = line.split(',')
    10    398702       760068      1.9     17.8              if csv_row[4] == ac_name:
    11        36           54      1.5      0.0                  name = csv_row[6]
    12        36          161      4.5      0.0                  votes = int(csv_row[11].strip())
    13        36           94      2.6      0.0                  data.append([name, votes])
    14                                           
    15         1           56     56.0      0.0      return sorted(data, key=operator.itemgetter(1), reverse=True)

Pick your battle

Now that we know what's taking time, there are two approaches to pick what to optimise:

  1. Eliminate what's obviously redundant
  2. Optimise what takes the most time
    1. Reduce the number of Hits
    2. Reduce the time per hit

Eliminate the obviously redundant

Let's look at whether anything is obviously redundant. Consider the line:

if row > 1

That's being checked 398,703 times. But in reality, it needs to be checked only once. We just want to ignore the first row. So let's refactor the function:

In [9]:
def most_votes_skip_first(ac_name='BANGALORE SOUTH'):
    data = []
    handle = open(assembly_results)
    handle.next()  # Skip the header
    for line in handle:
        csv_row = line.split(',')
        if csv_row[4] == ac_name:
            name = csv_row[6]
            votes = int(csv_row[11].strip())
            data.append([name, votes])

    return sorted(data, key=operator.itemgetter(1), reverse=True)

time1 = %timeit -o most_votes()
time2 = %timeit -o most_votes_skip_first()
ms = lambda time: 1000. * sum(time.all_runs) / len(time.all_runs) / time.loops
print '{:.1%} faster'.format(ms(time1) / ms(time2) - 1)
1 loops, best of 3: 416 ms per loop
1 loops, best of 3: 408 ms per loop
0.9% faster

Optimise what takes most time

Now, let's see what takes the most time:

In [10]:
lprun(most_votes_skip_first)
Timer unit: 4.66512e-07 s

Total time: 1.33709 s
File: <ipython-input-9-82bed0413cd3>
Function: most_votes_skip_first at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def most_votes_skip_first(ac_name='BANGALORE SOUTH'):
     2         1            7      7.0      0.0      data = []
     3         1          310    310.0      0.0      handle = open(assembly_results)
     4         1          147    147.0      0.0      handle.next()  # Skip the header
     5    398703       860141      2.2     30.0      for line in handle:
     6    398702      1302605      3.3     45.4          csv_row = line.split(',')
     7    398702       702552      1.8     24.5          if csv_row[4] == ac_name:
     8        36           56      1.6      0.0              name = csv_row[6]
     9        36          172      4.8      0.0              votes = int(csv_row[11].strip())
    10        36          101      2.8      0.0              data.append([name, votes])
    11                                           
    12         1           57     57.0      0.0      return sorted(data, key=operator.itemgetter(1), reverse=True)

Reduce the number of hits

The bulk of the time is going into splitting each line on ','. This is called 398,702 times, once for each row.

However, we are only interested in those rows where the constituency name matches. So let's check for the name first, without bothering to split.

In [11]:
def most_votes_check_first(ac_name='BANGALORE SOUTH'):
    data = []
    handle = open(assembly_results)
    handle.next()
    for line in handle:
        # Check for a match first, before split
        if line.find(ac_name) >= 0:
            csv_row = line.split(',')
            name = csv_row[6]
            votes = int(csv_row[11].strip())
            data.append([name, votes])

    return sorted(data, key=operator.itemgetter(1), reverse=True)

time3 = %timeit -o most_votes_check_first()
print '{:.1%} faster'.format(ms(time2) / ms(time3) - 1)
1 loops, best of 3: 232 ms per loop
61.5% faster

So now, the bulk of the time is going into just checking if the ac_name is in the string.

In [12]:
lprun(most_votes_check_first)
Timer unit: 4.66512e-07 s

Total time: 0.837494 s
File: <ipython-input-11-6ceaa1572cb1>
Function: most_votes_check_first at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def most_votes_check_first(ac_name='BANGALORE SOUTH'):
     2         1            6      6.0      0.0      data = []
     3         1          312    312.0      0.0      handle = open(assembly_results)
     4         1          157    157.0      0.0      handle.next()
     5    398703       842836      2.1     46.9      for line in handle:
     6                                                   # Check for a match first, before split
     7    398702       951415      2.4     53.0          if line.find(ac_name) >= 0:
     8        36          132      3.7      0.0              csv_row = line.split(',')
     9        36           57      1.6      0.0              name = csv_row[6]
    10        36          152      4.2      0.0              votes = int(csv_row[11].strip())
    11        36           98      2.7      0.0              data.append([name, votes])
    12                                           
    13         1           60     60.0      0.0      return sorted(data, key=operator.itemgetter(1), reverse=True)

Reduce the time per hit

So how exactly can we make something as simple as line.find(ac_name) >= 0 faster?

Let's look at the bytecode.

In [13]:
def check(line, ac_name):
    return line.find(ac_name) >= 0

import dis
dis.disassemble(check.func_code)
  2           0 LOAD_FAST                0 (line)
              3 LOAD_ATTR                0 (find)
              6 LOAD_FAST                1 (ac_name)
              9 CALL_FUNCTION            1
             12 LOAD_CONST               1 (0)
             15 COMPARE_OP               5 (>=)
             18 RETURN_VALUE        

There are a number of steps involved here:

  1. Load line
  2. Load find
  3. Load ac_name
  4. Then call the line.find function with ac_name as a parameter
  5. Load the constant 0
  6. Compare the result with that constant

It's generally recognised that function and method calls in Python are slow. Let's replace the .find() call with the in operator and see how well that works.

In [14]:
def check(line, ac_name):
    return ac_name in line

dis.disassemble(check.func_code)
  2           0 LOAD_FAST                1 (ac_name)
              3 LOAD_FAST                0 (line)
              6 COMPARE_OP               6 (in)
              9 RETURN_VALUE        

This appears to be doing a lot less than before. Let's try this out.

In [15]:
def most_votes_using_in(ac_name='BANGALORE SOUTH'):
    data = []
    handle = open(assembly_results)
    handle.next()
    for line in handle:
        # Use `in` instead of `.find()`
        if ac_name in line:
            csv_row = line.split(',')
            name = csv_row[6]
            votes = int(csv_row[11].strip())
            data.append([name, votes])

    return sorted(data, key=operator.itemgetter(1), reverse=True)

time4 = %timeit -o most_votes_using_in()
print '{:.1%} faster'.format(ms(time3) / ms(time4) - 1)
10 loops, best of 3: 148 ms per loop
76.7% faster

So here's our overall improvement:

In [16]:
print '{: 7,.0f} ms: original version'.format(ms(time1))
print '{: 7,.1%} faster: remove redundant if row > 1'.format(ms(time1) / ms(time2) - 1)
print '{: 7,.1%} faster: reduce # times .split() is called'.format(ms(time2) / ms(time3) - 1)
print '{: 7,.1%} faster: use operator instead of function'.format(ms(time3) / ms(time4) - 1)
print '{: 7,.1%} faster: total'.format(ms(time1) / ms(time4) - 1)
    445 ms: original version
   0.9% faster: remove redundant if row > 1
  61.5% faster: reduce # times .split() is called
  76.7% faster: use operator instead of function
 188.1% faster: total

Reject dogmas

By now, you may be wondering, "Why not use NumPy or Pandas? That's got to be faster." And it often is.

So let's try it. We'll use Pandas to achieve the same result:

In [17]:
import pandas as pd

def most_votes_pandas(ac_name='BANGALORE SOUTH'):
    data = pd.read_csv(assembly_results, low_memory=False)
    return data[data.AC_NAME.str.contains(ac_name)].sort('VOTES', ascending=False)

time5 = %timeit -o most_votes_pandas()
print '{:.1%} faster'.format(ms(time4) / ms(time5) - 1)
1 loops, best of 3: 1.03 s per loop
-87.1% faster

Shocking as it is, using Pandas is slower in this particular scenario.

In fact, we can evaluate how long various storage formats take to load. Below, we create a set of dummy data files in a variety of formats (CSV, JSON, pickle, HDF5) and load them to see how long each takes.

Which is the fastest data format?

In [18]:
%run sample.data.py
In [19]:
from timeit import timeit
print '{:,.3f}s: csv.DictReader'.format(timeit("list(csv.DictReader(open('sample.data.csv')))", setup="import csv", number=1))
print '{:,.3f}s: pickle.load'.format(timeit("pickle.load(open('sample.data.pickle', 'rb'))", setup="import cPickle as pickle", number=1))
print '{:,.3f}s: json.load (array of dict)'.format(timeit("json.load(open('sample.data.json'))", setup="import json", number=1))
print '{:,.3f}s: json.load (array of arrays)'.format(timeit("json.load(open('sample.data-array.json'))", setup="import json", number=1))
print '{:,.3f}s: csv.reader'.format(timeit("list(csv.reader(open('sample.data.csv')))", setup="import csv", number=1))
print '{:,.3f}s: pandas.read_csv'.format(timeit("pd.read_csv('sample.data.csv')", setup="import pandas as pd", number=1))
print '{:,.3f}s: pandas.read_pickle'.format(timeit("pd.read_pickle('sample.data.pandas')", setup="import pandas as pd", number=1))
print '{:,.3f}s: HDF5 table'.format(timeit("pd.read_hdf('sample.data.h5', 'table')", setup="import pandas as pd", number=1))
print '{:,.3f}s: HDF5 stored'.format(timeit("pd.read_hdf('sample.data.h5', 'stored')", setup="import pandas as pd", number=1))
7.403s: csv.DictReader
5.663s: pickle.load
7.016s: json.load (array of dict)
2.226s: json.load (array of arrays)
1.217s: csv.reader
0.788s: pandas.read_csv
3.293s: pandas.read_pickle
0.833s: HDF5 table
0.292s: HDF5 stored

Only the HDF5 format comes close to the speed of our hand-tuned code, and even then it does not quite match it.

So remember: the best optimiser is your head. But all else being equal, prefer HDF5 as a format.
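
For reference, here's a minimal sketch of how such an HDF5 file might be written with pandas. The actual sample.data.py script isn't shown here, so the key names and formats below are assumptions based on the timing code above; to_hdf also requires the PyTables package.

import pandas as pd

df = pd.read_csv('sample.data.csv')                       # The dummy data created above
df.to_hdf('sample.data.h5', 'stored')                     # Fixed-format store: fastest to load, not queryable
df.to_hdf('sample.data.h5', 'table', format='table')      # Table-format store: a little slower, but supports on-disk queries

The fixed format loads fastest; the table format trades some speed for the ability to query the store on disk.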

Next topic: optimising scraping

In order to get the assembly results, we had to scrape the results from the ECI results page.

Here, the problem is not that the computations are slow.

It's not even that the network is slow.

It's that the network is unreliable at this scale.

Make your programs restartable

If you're running a large computation, two things become critical:

  1. Break it into pieces and process them separately (effectively, map-reduce-ability)
  2. Cache the results so that re-computations are avoided

The latter, caching, is the key to the "restartability" of a program -- where it can pick up and run from where it stopped the last time. Here's a fairly typical scraping sequence:

for url in list_of_urls:
    tree = parse(url)
    # ... do more with the tree

The slowest step in this is not the computation, but fetching the URL. Can we cache it?

Cache slow operations transparently

One way is to define a function that downloads the URL only if it has not already been fetched, and reads the local copy otherwise. Here's one simple possibility:

In [20]:
import os
from urllib import urlretrieve

def get(url):
    '''Like open(url), but cached'''
    filename = 'sample.file.name'         # We need a unique filename per URL
    if not os.path.exists(filename):
        urlretrieve(url, filename)
    return open(filename)

!rm -f sample.file.name
eci_url = 'http://eci.nic.in/eci_main1/ElectionStatistics.aspx'
%timeit -n1 -r1 get(eci_url)
%timeit -n1 -r1 get(eci_url)
1 loops, best of 1: 9.24 s per loop
1 loops, best of 1: 174 µs per loop

The file gets saved the first time. The second time, it's loaded from the disk, which is thousands of times faster.

Cache each URL as a unique key

But in our last example, the filename was not unique. We need a way of getting a unique filename for each URL. There are several options for this.

1. Use the URL as the filename. But not all URL characters are valid for files.

2. Remove special characters from the URL. But some special characters have meaning. ?x=1 is different from /x/1

3. Use a hashing function. Which raises the question: which one? And since this is a session on speed, let's time them.

Hashing efficiently

My first impulse is to use the built-in hashlib library. Here's how its various algorithms perform.

In [21]:
import hashlib
for algo in hashlib.algorithms:
    duration = timeit('hashlib.%s("%s").hexdigest()' % (algo, eci_url), setup='import hashlib')
    print '{:,.3f}s: {:s}'.format(duration, algo)
1.572s: md5
1.669s: sha1
1.840s: sha224
1.800s: sha256
2.760s: sha384
3.321s: sha512

These are the results for 1 million operations (timeit's default), so the time taken per hash is a few microseconds. We really should not be optimising this at all -- just pick any algorithm.

The built-in hash function

However, let me just remind you of the built-in hash() function.

In [22]:
hash(eci_url)
Out[22]:
1159908681

It converts any hashable object into a signed integer. We can just add 2 ** 32 to make it non-negative, and the result (as a string) is a perfectly valid filename in almost every OS. This is the same hash Python uses internally for dictionary keys, so it's likely to be fast. In fact, let's measure it:

In [23]:
duration = timeit('hash("%s") + 2**32' % eci_url)
print '{:,.3f}s: hash'.format(duration)
0.196s: hash

But remember -- the hash is not guaranteed to be identical across different runs. It is quite likely to change between Python versions, so if you're storing the cache on different machines, or running multiple Python versions, you're better off with MD5.

(This, incidentally, is yet another case of premature optimisation!)
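
Putting the pieces together, here's a sketch of what the cached get() might look like with an MD5-based filename. The cache. prefix is just an illustration, not necessarily what we used.

import os
import hashlib
from urllib import urlretrieve

def get(url):
    '''Like open(url), but cached on disk under an MD5 key'''
    filename = 'cache.' + hashlib.md5(url).hexdigest()    # Stable, unique filename per URL
    if not os.path.exists(filename):
        urlretrieve(url, filename)                        # Fetch only on the first call
    return open(filename)

Every URL now gets its own cache file, so re-running the scraper skips everything that has already been downloaded.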

Next topic: Parsing dates

We were exploring views to show every assembly election in India. Since we have the dates of every election, and the people that were Chief Ministers after each election, we decided to put these together as a set of visualisations.

The slowest step when processing this data was parsing the dates. For example, let's create a dummy data file that saves dates. To make it realistic, we'll use multiple formats.

In [24]:
import random
import datetime

formats = ['%d-%b-%Y', '%d %m %Y', '%b %d %Y', '%Y-%m-%d']
today = datetime.date.today()
with open('sample.dates.csv', 'w') as out:
    for x in range(100000):
        date = today - datetime.timedelta(random.randint(0, 100))
        out.write(date.strftime(random.choice(formats)) + '\n')

!head sample.dates.csv
2015-04-11
13 03 2015
10 05 2015
27 03 2015
25-Apr-2015
13 03 2015
25-Mar-2015
2015-03-16
2015-05-20
2015-06-05

Since the dates are in multiple formats, we need a good library to parse them. dateutil is the most popular flexible date parser. Let's see how long that takes.

In [25]:
from dateutil.parser import parse

# Read the data first into a list
lines = [line.strip() for line in open('sample.dates.csv')]

# Convert it into dates
def convert(lines):
    result = []
    for line in lines:
        date = parse(line, dayfirst=True)
        result.append(date)
    return result

%timeit convert(lines)
1 loops, best of 3: 8.17 s per loop

Where do you think the problem is?

Virtually all of the time (99.5% of it) goes into the date parsing call.

In [26]:
lprun(convert, lines)
Timer unit: 4.66512e-07 s

Total time: 27.5383 s
File: <ipython-input-25-10962859084c>
Function: convert at line 7

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     7                                           def convert(lines):
     8         1            6      6.0      0.0      result = []
     9    100001        97698      1.0      0.2      for line in lines:
    10    100000     58725467    587.3     99.5          date = parse(line, dayfirst=True)
    11    100000       207084      2.1      0.4          result.append(date)

Evaluate alternatives

Let's see whether we have options to make date parsing faster. Suppose we take a single date string, repeated 100,000 times, and see how long that takes to parse.

Let's start with Pandas' to_datetime method (which internally uses dateutil.parser.parse).

In [27]:
s = pd.Series(['01-31-2012']*100000)

%timeit pd.to_datetime(s)
1 loops, best of 3: 7.57 s per loop

That wasn't fast. Let's see if calling dateutil.parser.parse directly is any faster.

In [28]:
%timeit [parse(v) for v in s]
1 loops, best of 3: 6.8 s per loop

But wait... in this case, we already know the string format. Can't we just use datetime.strptime()?

In [29]:
%timeit [datetime.datetime.strptime(v, '%m-%d-%Y') for v in s]
1 loops, best of 3: 1.64 s per loop

Looking much better. But what if the overhead of datetime.strptime() itself is high? Let's try parsing the string by hand.

In [30]:
%timeit [datetime.datetime(int(v[6:10]), int(v[0:2]), int(v[3:5])) for v in s]
1 loops, best of 3: 220 ms per loop

So we've made it over 20 times faster, but at the expense of flexibility: we can only parse one date format, and only when its fields are plain integers.

Cache anything that's slow

But let's take a different approach. So far, we were trying to make the function faster. What if we called it fewer times instead? Often, the number of distinct dates to be parsed is small, so there's no reason not to cache the results.

In [31]:
def lookup(s):
    # Parse only the unique dates
    dates = {date: parse(date, dayfirst=True) for date in set(s)}
    # Look up the parsed dates
    return [dates[v] for v in s]

%timeit lookup(s)
100 loops, best of 3: 9.75 ms per loop

Now, this is a substantial improvement, and comes without any loss of flexibility.

Next topic: Loops and functions

You'll notice that I've switched from loops to list comprehensions. That's because loops and function calls both have an overhead in Python -- so where possible, it's best to avoid them. Let me show you how this works.

In [32]:
data = [random.randint(0, 100) for x in range(100000)]

def square(value):
    return value * value

def square_all(data):
    result = []
    for value in data:
        squared = square(value)
        result.append(squared)
    return result

time1 = %timeit -o square_all(data)
10 loops, best of 3: 25.1 ms per loop

Inline functions are faster

Now, let's take the same function, but avoid the overhead of calling a function. We'll inline the square function.

In [33]:
def square_all_2(data):
    result = []
    for value in data:
        squared = value * value
        result.append(squared)
    return result

time2 = %timeit -o square_all_2(data)
print '{:.1%} faster'.format(ms(time1) / ms(time2) - 1)
100 loops, best of 3: 15.3 ms per loop
62.1% faster

We're storing the squared value in a temporary variable that's not being re-used. Let's see if removing that makes a difference.

In [34]:
def square_all_3(data):
    result = []
    for value in data:
        result.append(value * value)
    return result

time3 = %timeit -o square_all_3(data)
print '{:.1%} faster'.format(ms(time2) / ms(time3) - 1)
100 loops, best of 3: 13.8 ms per loop
9.5% faster

List comprehensions are faster than loops

Now, let's remove the for loop and replace it with a list comprehension.

In [35]:
def square_all_4(data):
    return [value * value for value in data]

time4 = %timeit -o square_all_4(data)
print '{:.1%} faster'.format(ms(time3) / ms(time4) - 1)
100 loops, best of 3: 6.79 ms per loop
105.9% faster

In other words, where possible

  1. Use list comprehensions instead of loops
  2. Inline the function calls

Think in vectors

Let's try another approach to squaring these random numbers. Let's use pandas, which converts these into an array with a fixed datatype, and executes loops in C.

In [36]:
data = pd.Series(data)

time5 = %timeit -o data * data
print '{:.1%} faster'.format(ms(time4) / ms(time5) - 1)
1000 loops, best of 3: 424 µs per loop
1523.1% faster

That might be termed blazing fast. Most vector computations have the dual advantages of:

  • static typing (so no type checking or conversions are required)
  • C loops (so no overhead of handling exceptions, etc)

Here's another example where vector computations give a much faster result. If you want to scale a dataset into the range [0, 1], this piece of code is quite fast:

In [37]:
lo, hi = data.min(), data.max()
%timeit (data - lo) / (hi - lo)
1000 loops, best of 3: 1.09 ms per loop

Having read this post on Ruby being slow, I thought I'd check the same with Python. I got it running fairly fast, but there was one piece that was taking a fair bit of time: counting numbers in a range. Here's the slow version:

In [38]:
values = pd.np.random.rand(1000000)
def count(values, a, b):
    count = 0
    for x in values:
        if a <= x <= b:
            count += 1
    return count

time1 = %timeit -o count(values, .25, .75)
1 loops, best of 3: 210 ms per loop

Vector calculations are much faster than looping

Let's apply vector computations to this.

In [39]:
time2 = %timeit -o ((.25 <= values) & (values <= .75)).sum()
print '{:.1%} faster'.format(ms(time1) / ms(time2) - 1)
100 loops, best of 3: 3.05 ms per loop
6676.4% faster

But you can make it even faster

That was pretty fast, but we can go even faster when we realise that the search would be much quicker if the values are sorted. Rather than check each value, we can apply binary search on it.

Fortunately, numpy.searchsorted provides a built-in and fast binary search. When you apply that...

In [40]:
values.sort()
time3 = %timeit -o pd.np.searchsorted(values, .75) - pd.np.searchsorted(values, .25)
print '{:.1%} faster'.format(ms(time2) / ms(time3) - 1)
The slowest run took 9.04 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 3.46 µs per loop
90794.1% faster

So remember, again: the best optimiser is your head.

Next topic: writing C in Python — Cython

Cython is a project that lets you write C extensions to Python using Python -- that is, you don't need to know much C. There are many ways to install it, but a distribution like Anaconda makes it much easier.

Once installed, you can use it in IPython like this:

In [41]:
%load_ext Cython

From this point, any cell beginning with %%cython will be compiled with Cython rather than interpreted as Python.

Cython code is just like Python

Here's the code to add up all the numbers below n, written in both Cython and Python.

In [42]:
%%cython
def total_cython(n):
    '''Calculate the sum of all numbers up to n'''
    cdef int a = 0        # Declare the type of a as integer
    cdef int i            # Declare the type of i as integer
    for i in xrange(n):   # Now loop through all the numbers
        a += i            # ... and add them
    return a
In [43]:
def total_python(n):
    a = 0
    for i in xrange(n):
        a += i
    return a

Cython can be much faster than Python

In [44]:
%timeit total_python(100000)
%timeit total_cython(100000)
100 loops, best of 3: 6.1 ms per loop
10000 loops, best of 3: 62.1 µs per loop

In this case, Cython is almost 100 times faster than Python. A fair bit of Python's overhead lies in the flexible type system and loops.

But Cython is not always blazing fast

Let's count the number of values between a and b like before.

In [45]:
%%cython

def count_cython(values, a, b):
    cdef int count = 0
    cdef float val
    for val in values:
        if a <= val <= b:
            count += 1
    return count
In [46]:
%timeit count(values, .25, .75)
%timeit count_cython(values, .25, .75)
1 loops, best of 3: 214 ms per loop
10 loops, best of 3: 84.6 ms per loop

The speed difference between Python and Cython is still considerable, but much smaller than the almost 100x improvement we got earlier.

Numba is simpler than Cython

Numba dynamically compiles Python code and makes it faster. Its speed rivals Cython's, and it's a lot easier to use. For example, to compile the total function, use this:

In [47]:
from numba.decorators import jit

@jit
def total_numba(n):
    a = 0
    for i in range(n):
        a += i
    return a
In [48]:
%timeit total_python(100000)
%timeit total_cython(100000)
%timeit total_numba(100000)
100 loops, best of 3: 6.03 ms per loop
10000 loops, best of 3: 62.6 µs per loop
The slowest run took 277890.41 times longer than the fastest. This could mean that an intermediate result is being cached 
1000000 loops, best of 3: 265 ns per loop

Notice the units. Milliseconds, microseconds, nanoseconds.

Numba is often faster than Cython

In [49]:
@jit
def count_numba(values, a, b):
    count = 0
    for i in range(len(values)):
        if a <= values[i] and values[i] <= b:
            count += 1
    return count

%timeit count(values, .25, .75)
%timeit count_cython(values, .25, .75)
%timeit count_numba(values, .25, .75)
1 loops, best of 3: 205 ms per loop
10 loops, best of 3: 83.9 ms per loop
The slowest run took 100.25 times longer than the fastest. This could mean that an intermediate result is being cached 
1000 loops, best of 3: 764 µs per loop

Summary

  1. If it's fast enough, don't optimise it.
  2. Find the slowest step first.
  3. Reduce the number of hits. Perform each operation as rarely as possible.
    • Cache the result if speed is more important than memory
    • Move slower operations inside if conditions
  4. Make the slowest operation faster
    • Python functions have an overhead. Inline if possible
    • List comprehensions are faster than for or if
    • Use Numba or Cython if static typing, etc can help
  5. Change the algorithm. This has the second-biggest impact on speed
  6. The largest impact comes from eliminating code.

Functionality is an asset. Code is a liability.

Appendix

I'm sometimes asked what I use for parallel processing. My answer is xargs. It is devilishly powerful.

Why xargs? Because

  1. It's ridiculously simple to use.
  2. It lets me distribute load across CPUs / cores in a machine, as well as across machines
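
For example, a typical invocation might look something like this. Both urls.txt (a file with one URL per line) and scrape.py (a script that scrapes whatever URLs it is given as arguments) are hypothetical, and -P needs GNU or BSD xargs.

xargs -n 20 -P 4 python scrape.py < urls.txt    # 20 URLs per invocation, up to 4 invocations in parallel

Combined with the caching above, the whole pipeline can be killed and restarted cheaply.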

Why not threading? Threads are lightweight, but when processing data, you're not processing lightweight stuff. Threading helps share data. But you don't want to share data -- you want to process each chunk of data exactly once, and get rid of it. Threads do not have a big benefit over processes in the data world.

Why not multiprocessing? The key benefit here (when compared to xargs) is that I can pass Python objects. But the disadvantage is that I cannot run the same code across multiple servers.

This is not to say that these modules are not useful. Just that they're not too relevant when processing large-scale data.