Memoization in Python

Although we always want our code to be as fast as lightning, some tasks, implemented as functions, are inevitably computationally expensive. Welcome to the real world. If the function needs to run only once for a specific input, we can just run it and wait, hopefully not too long. However, if it is run with the same input many times, it pays to store its output instead of wasting clock cycles recomputing it.

One way to avoid running the same function with the same input multiple times is to save the output to a file. On subsequent runs, the output file is read and the result is returned without running the function again. Although this approach may be good enough for some cases, we will implement a more general and elegant approach using memoization and Python function decorators.
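The file-based approach could look roughly like the sketch below. The names slow_square and cached_call, the one-file-per-argument layout, and the use of pickle are all assumptions made for illustration, not part of the original post.

```python
import os
import pickle
import tempfile

def slow_square(n):
    # stand-in for an expensive computation (hypothetical example function)
    return n * n

def cached_call(func, arg, cache_dir):
    # one cache file per (function name, argument) pair
    path = os.path.join(cache_dir, "%s_%r.pkl" % (func.__name__, arg))
    if os.path.exists(path):
        # the result was stored by an earlier call: just load it
        with open(path, "rb") as f:
            return pickle.load(f)
    result = func(arg)
    with open(path, "wb") as f:
        pickle.dump(result, f)
    return result

cache_dir = tempfile.mkdtemp()
first = cached_call(slow_square, 12, cache_dir)   # computed and written to disk
second = cached_call(slow_square, 12, cache_dir)  # read back from disk
```

The cache survives between program runs, but every function needs its own file handling, which is the kind of repetition the decorator-based approach avoids.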

Memoization is just a fancy word for storing the result of an expensive function call and returning the stored result when the function is called again with the same input. We assume that the function always evaluates to the same result for a given input. In other words, its output must depend only on its arguments. For example, while the function len(x) is pure (i.e. it returns the same value given the same list), the function random() is not (i.e. it returns a different value every time it is called). Let's look at another example of an impure function.

In [1]:
x = 1

def add_x(y):
    return x + y

# let's run add_x with input 1
print add_x(1)

# let's change the value for x and run add_x(1) again
x = 2
print add_x(1)

As you see, the function add_x is called with the same argument, yet returns different values, because its result also depends on the global variable x.

Although there are some purely functional programming languages (e.g. Haskell), Python is not one of them and allows the user to define impure functions. For our memoization implementation, we assume that the function is pure: it has no side effects and its result depends only on its arguments.
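For comparison, here is a minimal sketch of a pure counterpart to add_x. The name add is an assumption; the only change is that x becomes a parameter instead of a global variable.

```python
def add(x, y):
    # both operands are parameters, so the result depends only on the arguments
    return x + y

first = add(1, 1)
second = add(1, 1)  # same arguments, same result, no matter what happens elsewhere
```

A function written this way is safe to memoize, because a stored result can never go stale.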

Fibonacci numbers

Fibonacci numbers are the sequence of numbers 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, ... They are defined recursively as

$$ F(0) = 0\\ F(1) = 1\\ F(n) = F(n-1) + F(n-2), \quad n \geq 2 $$

Although we can compute the Fibonacci number for a given $n$ efficiently, either iteratively or using a closed-form expression, we will use the computationally expensive recursive version. After all, the whole point is to make it faster using memoization.
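For reference, the two efficient alternatives mentioned above might be sketched as follows. The names fib_iter and fib_closed are assumptions, and the closed form (Binet's formula) relies on floating point, so it is only trustworthy for moderate $n$.

```python
import math

def fib_iter(n):
    # walk up from F(0) and F(1), keeping only the last two values
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a

def fib_closed(n):
    # Binet's formula: F(n) is phi**n / sqrt(5) rounded to the nearest integer
    sqrt5 = math.sqrt(5)
    phi = (1 + sqrt5) / 2
    return int(round(phi ** n / sqrt5))

values = [fib_iter(n) for n in range(11)]
# values == [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55]
```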

In [2]:
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

This function looks OK at first, but it performs terribly for large values of n. Why? Consider the case $n=5$. fib(5) will call fib(4) and fib(3). Each of these calls fib recursively with smaller values of $n$: fib(4) will call fib(3) and fib(2), and fib(3) will call fib(2) and fib(1). Can you see the problem now? The function fib(3) is called twice and fib(2) three times! This gets much worse for larger values of $n$, since the number of calls grows exponentially. Let's modify the function with print statements and call it for $n=5$.

In [3]:
def fib(n, msg=False):
    if msg:
        print "Calling fib(%d)" % n
    if n < 2:
        return n
    return fib(n-1, msg) + fib(n-2, msg)

fib(5, msg=True)
Calling fib(5)
Calling fib(4)
Calling fib(3)
Calling fib(2)
Calling fib(1)
Calling fib(0)
Calling fib(1)
Calling fib(2)
Calling fib(1)
Calling fib(0)
Calling fib(3)
Calling fib(2)
Calling fib(1)
Calling fib(0)
Calling fib(1)

As you see, the function fib was called with the same arguments again and again. Since fib is a pure function, it returns the same value whenever it is called with the same argument. That makes it perfect for caching: store each result once and reuse it whenever it is needed!
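To get a feeling for how quickly the repeated work piles up, the number of calls can be counted with a small helper. This is only a sketch; count_calls is a hypothetical name and mirrors the recursion of the naive fib.

```python
def count_calls(n):
    # number of times the naive fib is invoked to compute fib(n):
    # one call for n itself plus the calls made by the two recursive branches
    if n < 2:
        return 1
    return 1 + count_calls(n - 1) + count_calls(n - 2)

calls_5 = count_calls(5)    # 15, one per "Calling fib" line printed for n=5
calls_20 = count_calls(20)  # 21891
```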

Function decorators

In Python, functions are first-class citizens, which means that they can be passed as arguments to other functions and returned by them.
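Both directions can be seen in a short sketch. The names square, apply_twice and make_multiplier are assumptions chosen for illustration.

```python
def square(n):
    return n * n

def apply_twice(func, value):
    # func is an ordinary argument that happens to be a function
    return func(func(value))

def make_multiplier(factor):
    # builds and returns a new function
    def multiply(n):
        return factor * n
    return multiply

triple = make_multiplier(3)
result_a = apply_twice(square, 3)  # square(square(3)) == 81
result_b = triple(7)               # 21
```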

Suppose we have a function called factorial that returns $n!$ for a given $n$.

In [4]:
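def factorial(n):
    # the initial value 1 makes factorial(0) return 1 instead of raising TypeError
    return reduce(lambda x, y: x*y, xrange(1, n+1), 1)

# 5! = 120
print factorial(5)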

Suppose we want the function to print a message whenever it is called. We can have something like this.

In [5]:
def bad_factorial(n):
    print "Calling factorial(%d)" % n
    return reduce(lambda x, y: x*y, xrange(1, n+1), 1)

# 5! = 120
print bad_factorial(5)
Calling factorial(5)

That was easy, right? However, what if we have another function that should also print a message when it is called? Just modify that function to have a print statement. Another one? You see, this is getting boring. There must be some neat way to do this.

As we said earlier, functions can be passed as arguments and returned by other functions. Let's define a function called logger which takes a function func as its argument and returns a modified version that prints a message whenever it is called.

In [6]:
def logger(func):
    def func_with_msg(*args, **kwargs):
        print "Calling %s(%s,%s)" % (func.__name__, args, kwargs)
        return func(*args, **kwargs)
    return func_with_msg

good_factorial = logger(factorial)

# Call modified factorial function
print good_factorial(5)
Calling factorial((5,),{})

*args and **kwargs enable us to define a function that accepts an arbitrary number of positional and keyword arguments. You can read more here. Now we can call logger with any of our functions that needs to print a message when called.
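A quick way to see what ends up in args and kwargs is to return them unchanged. The function show_arguments below is a hypothetical helper used only for this illustration.

```python
def show_arguments(*args, **kwargs):
    # args is a tuple of the positional arguments,
    # kwargs is a dict of the keyword arguments
    return args, kwargs

positional, keyword = show_arguments(5, "a", verbose=True)
# positional == (5, 'a') and keyword == {'verbose': True}
```

This is why the message printed by logger shows a tuple such as (5,) followed by a dictionary such as {}.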

The shortcut for applying a decorator function like logger to a function is to put the symbol @ followed by the decorator's name on the line just before the function definition.

In [7]:
@logger
def better_factorial(n):
    return reduce(lambda x, y: x*y, xrange(1, n+1), 1)
print better_factorial(5)
Calling better_factorial((5,),{})

which is equivalent to

In [8]:
factorial = logger(factorial)
print factorial(5)
Calling factorial((5,),{})
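One side effect of decorating is that the decorated name now refers to the inner wrapper, so its own __name__ and docstring are those of the wrapper. The standard library's functools.wraps copies this metadata over. The sketch below is a variant of logger under that assumption, with double as a hypothetical example function.

```python
import functools

def logger(func):
    @functools.wraps(func)  # copy __name__, __doc__, etc. from func to the wrapper
    def func_with_msg(*args, **kwargs):
        print("Calling %s(%s,%s)" % (func.__name__, args, kwargs))
        return func(*args, **kwargs)
    return func_with_msg

@logger
def double(n):
    """Return twice n."""
    return 2 * n

result = double(4)  # prints the "Calling ..." message, then returns 8
```

Without functools.wraps, double.__name__ would be 'func_with_msg' rather than 'double'.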


A closure in Python is a function that retains access to the variables of its enclosing scope.

In [9]:
def makeInc(x):
    def inc(y):
        # inc has access to the x which is defined in makeInc (the enclosing
        # scope)
        return x + y
    return inc

inc5 = makeInc(5)
inc10 = makeInc(10)

print inc5(3)
print inc10(3)

And the interesting part is that the function keeps access to its enclosing scope even after the parent function has been deleted.

In [10]:
makeInc
<function __main__.makeInc>
In [11]:
del makeInc
In [12]:
print inc5(2)
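Where does the value of x live once makeInc is gone? A small sketch, assuming CPython's function attributes __closure__ and __code__, shows that the inner function carries the captured value with it.

```python
def makeInc(x):
    def inc(y):
        return x + y
    return inc

inc5 = makeInc(5)
del makeInc  # the name is gone, but inc5 still holds its own reference to x

# each captured variable is stored in a "cell" attached to the function
captured = inc5.__closure__[0].cell_contents  # 5
names = inc5.__code__.co_freevars             # ('x',)
```

The memo dictionary in a memoizing decorator is kept alive in exactly the same way.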

Implementing memoization

Now we can use function decorators and closures to build memoization for any given function. As an example, let's use the Fibonacci function again.

In [13]:
def not_working_memoize(func):
    # Make memoization available for a given function func
    memo = {}
    def wrapper(*args, **kwargs):
        if (args, kwargs) not in memo:
           memo[(args, kwargs)] = func(*args, **kwargs)
        return memo[(args)]
    return wrapper
In [14]:
@not_working_memoize
def memoized_fibonacci(n):
    if n < 2:
        return n
    return memoized_fibonacci(n-1) + memoized_fibonacci(n-2)

memoized_fibonacci(10)

TypeError                                 Traceback (most recent call last)
<ipython-input-14-e6008084baab> in <module>()
      5     return memoized_fibonacci(n-1) + memoized_fibonacci(n-2)
----> 7 memoized_fibonacci(10)

<ipython-input-13-d89f8e213b23> in wrapper(*args, **kwargs)
      3     memo = {}
      4     def wrapper(*args, **kwargs):
----> 5         if (args, kwargs) not in memo:
      6            memo[(args, kwargs)] = func(*args, **kwargs)
      7         return memo[(args)]

TypeError: unhashable type: 'dict'

This will not work. The reason is that we are trying to use the tuple (args, kwargs) as the key for the memo dictionary, where args is a tuple and kwargs is a dictionary. Dictionary keys must be hashable, and we get the TypeError: unhashable type: 'dict' error because kwargs is a dictionary. To make it work, we combine args and the (key, value) pairs of kwargs into a single tuple, which is hashable as long as the argument values themselves are.
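The hashability problem can be checked in isolation before building the full decorator. In the sketch below the sample arguments are made up, and sorting the keyword items is an extra assumption so that the key does not depend on keyword order.

```python
args = (8,)
kwargs = {"msg": True}

try:
    hash((args, kwargs))  # the dict inside the tuple makes this fail
    hashable = True
except TypeError:
    hashable = False

# a flat tuple of positional values and sorted (key, value) pairs is hashable
key = tuple(args) + tuple(sorted(kwargs.items()))
memo = {key: "cached value"}
```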

In [15]:
debug = True
def hopefully_working_memoize(func):
    memo = {}
    def wrapper(*args, **kwargs):
        # build a hashable key; sorting makes it independent of keyword order
        k = tuple(args) + tuple(sorted(kwargs.items()))
        if debug:
            print "key", k,
            if k not in memo:
                print "not",
            print "in memory"
        if k not in memo:
            memo[k] = func(*args, **kwargs)
        return memo[k]
    return wrapper
In [16]:
@hopefully_working_memoize
def memoized_fibonacci(n):
    if n < 2:
        return n
    return memoized_fibonacci(n-1) + memoized_fibonacci(n-2)

memoized_fibonacci(8)

key (8,) not in memory
key (7,) not in memory
key (6,) not in memory
key (5,) not in memory
key (4,) not in memory
key (3,) not in memory
key (2,) not in memory
key (1,) not in memory
key (0,) not in memory
key (1,) in memory
key (2,) in memory
key (3,) in memory
key (4,) in memory
key (5,) in memory
key (6,) in memory

As you see above, the initial function calls are actually executed since their results have not been stored in memory yet. As the recursion unwinds, each remaining call asks for a Fibonacci number that has already been computed, so the result is looked up instead of being computed again.

And if we ever call memoized_fibonacci(n) again for the same n, we immediately get the result.

In [17]:
memoized_fibonacci(8)
key (8,) in memory

Let's run a few tests. The first call is for the inefficient implementation of the Fibonacci numbers. timeit runs the statement one million times by default and returns the total time in seconds, so even for small values of n this takes quite some time.

In [18]:
import timeit
timeit.timeit('fib(9)', setup="from __main__ import fib")

Let's call the memoized function with the same value of n.

In [19]:
debug = False
timeit.timeit('memoized_fibonacci(9)', setup="from __main__ import memoized_fibonacci")
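A self-contained variant of this comparison is sketched below. It makes a few assumptions: memoize is a simplified single-argument decorator rather than the one defined earlier, n is 15, and number=200 replaces timeit's default of one million repetitions to keep the run short.

```python
import timeit

def fib(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

def memoize(func):
    # simplified decorator for functions of one hashable argument
    memo = {}
    def wrapper(n):
        if n not in memo:
            memo[n] = func(n)
        return memo[n]
    return wrapper

@memoize
def memoized_fibonacci(n):
    if n < 2:
        return n
    return memoized_fibonacci(n - 1) + memoized_fibonacci(n - 2)

# timeit also accepts a callable instead of a statement string
naive_seconds = timeit.timeit(lambda: fib(15), number=200)
memo_seconds = timeit.timeit(lambda: memoized_fibonacci(15), number=200)
```

Keep in mind what is being measured: after the first repetition, every call to memoized_fibonacci(15) is a single dictionary lookup, so the comparison shows the benefit of repeated calls with the same input.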

The nice thing about memoization as described here is that it is general: the same decorator can be applied to any computationally expensive pure function whose arguments are hashable!
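If you are on Python 3.2 or later (unlike the Python 2 code in this post), the standard library already ships a memoizing decorator, functools.lru_cache. A minimal sketch, assuming an unbounded cache is acceptable:

```python
from functools import lru_cache  # Python 3.2+; not available in Python 2

@lru_cache(maxsize=None)  # maxsize=None means the cache never evicts entries
def fib(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

value = fib(50)          # 12586269025
info = fib.cache_info()  # hits, misses, maxsize and current size of the cache
```

Like the hand-written decorator, it requires the arguments to be hashable.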

About this post

This post is powered by IPython Notebook. To use IPython Notebook with Emacs, try the Emacs IPython Notebook package. The original notebook file is available here.