Memoization

Memoization is a general technique for speeding up recursive algorithms. The idea is this: for a function f that is called on some input p, we maintain a lookup table on the side, indexed by the possible inputs. If f(p) has already been computed in the past, we return the stored value from the lookup table. Otherwise, we compute f(p) from scratch, insert it into the lookup table, and then return it. We illustrate this technique below using examples.

We will show how we can use memoization to make significant savings in running time.
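As a generic illustration of the lookup-table idea, here is a minimal sketch of a memoizing wrapper. The names `memoize`, `square` and `body_runs` are hypothetical and used only for illustration. It assumes the function's arguments are hashable, because they serve as dictionary keys, and it counts how often the real function body runs:

```python
from typing import Any, Callable, Dict, Tuple


def memoize(f: Callable[..., Any]) -> Callable[..., Any]:
    """Return a version of f that keeps a lookup table of past results.

    Hypothetical helper for illustration. The arguments of f must be
    hashable, since they are used as keys of the table.
    """
    table: Dict[Tuple[Any, ...], Any] = {}

    def wrapper(*args: Any) -> Any:
        if args not in table:        # f(args) was never computed: compute it from scratch
            table[args] = f(*args)   # ...and insert it into the lookup table
        return table[args]           # either way, answer from the table

    return wrapper


body_runs = 0  # hypothetical counter: how often the real body of square executes


@memoize
def square(p: int) -> int:
    global body_runs
    body_runs += 1
    return p * p


square(7)
square(7)   # answered from the table, the body does not run again
square(3)
print(body_runs)  # 2
```

The second call `square(7)` never reaches the body, which is exactly the saving memoization buys when a recursion asks for the same input many times.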

Example 1: Packing items into boxes

You have a box that can hold $W$ kilos (if you try to put more, it breaks). You are given a list $L$ of items $[ [w_0,v_0],\ldots, [w_{n-1},v_{n-1}]]$ where the $i^{th}$ item weighs $w_i$ kilos and is worth $v_i$ birr if you sell it in the market. What is the maximum amount of value you can fit in your box without the box breaking? There is only one of each item (you can't put multiple copies of the same item in your box).
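For very small lists the question can be answered by brute force: try every subset of the items and keep the most valuable one that fits. The sketch below (the name `mostValueBruteForce` is hypothetical) is exponential in the number of items, but it gives a simple reference to check the recursive solutions against; its answers should agree with the values of mostValue listed further down:

```python
from itertools import combinations
from typing import Sequence, Tuple


def mostValueBruteForce(W: int, L: Sequence[Tuple[int, int]]) -> int:
    """Try every subset of the items and keep the best one that fits in W kilos.

    Hypothetical reference implementation. There are 2**len(L) subsets,
    so this is only usable for small lists.
    """
    best = 0
    for size in range(len(L) + 1):
        for subset in combinations(L, size):   # picks positions, so each item is used at most once
            weight = sum(w for w, _ in subset)
            if weight <= W:                    # the box does not break
                best = max(best, sum(v for _, v in subset))
    return best


print(mostValueBruteForce(4, [(1, 6), (2, 5), (3, 6), (2, 10)]))  # 16
```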

We start by solving this with a simple recursive procedure. One way to think about this question is as follows:

  • If there is only one item, with weight $w$ and value $v$, then our decision is simple: if $W \geq w$ then we can fit the item and get value $v$, and otherwise we get zero value (no items).

  • If there are two items with weight/value pairs $(w_0,v_0)$ and $(w_1,v_1)$, then we may have a choice to make. In such cases it often helps to split a complicated choice into a sequence of simple choices. So let's start with the choice of whether or not to put item $0$ in the box. If we take it, we get value $v_0$ and are left with a box that can fit only $W-w_0$ more kilos. If we don't, we get no value from item $0$ and still have all $W$ kilos available for the remaining items.

  • Therefore, if $\mathtt{mostValue}(W,((w_0,v_0),\ldots,(w_{n-1},v_{n-1})))$ denotes the maximum value we can pack given a capacity of $W$ kilos and items with weights/values $((w_0,v_0),\ldots,(w_{n-1},v_{n-1}))$, then we have the following equation, where the second option is available only if $w_0 \leq W$:

$\mathtt{mostValue}(W,((w_0,v_0),\ldots,(w_{n-1},v_{n-1}))) = \max \{ 0 + \mathtt{mostValue}(W,((w_1,v_1),\ldots,(w_{n-1},v_{n-1}))) , v_0 + \mathtt{mostValue}(W-w_0,((w_1,v_1),\ldots,(w_{n-1},v_{n-1}))) \}$.

Can you see why?
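To see the equation at work, here it is unrolled by hand on one small instance: capacity $W=2$ with items $(2,5)$ and $(2,10)$. When an item is heavier than the remaining capacity, only the "skip" option is available:

$$
\begin{aligned}
\mathtt{mostValue}(2,((2,5),(2,10))) &= \max\{\,0 + \mathtt{mostValue}(2,((2,10),)),\ 5 + \mathtt{mostValue}(0,((2,10),))\,\}\\
\mathtt{mostValue}(2,((2,10),)) &= \max\{\,0 + \mathtt{mostValue}(2,()),\ 10 + \mathtt{mostValue}(0,())\,\} = \max\{0,10\} = 10\\
\mathtt{mostValue}(0,((2,10),)) &= 0 + \mathtt{mostValue}(0,()) = 0 \qquad (\text{since } 0 < 2)\\
\mathtt{mostValue}(2,((2,5),(2,10))) &= \max\{10,\ 5+0\} = 10
\end{aligned}
$$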

This now suggests a simple recursive algorithm for the mostValue function:

In [4]:
%run  'boaz_utils.ipynb'
In [3]:
def mostValue(W,L):
    """returns the maximum amount of value we can have with W kilos packing items listed in L, 
    where L is a tuple/list containing pairs (w,v) weight/value for every item."""
    
    step_pc() # ignore for now
    if not L:
        # if L is empty then there is nothing left to pack
        return 0
    value_if_skip_first_item = mostValue(W,L[1:]) # the amount of value we can have if we skip first item 
    if W<L[0][0]: # if we can't afford to pack the first item then we have no choices to make
        return value_if_skip_first_item
    # otherwise we will check both options and see what's the maximum value we can get
    return max(value_if_skip_first_item, L[0][1]+mostValue(W-L[0][0],L[1:]))

Here are some examples:

  • $\mathtt{mostValue}(1,((2,10),))=0$: We can't fit the only item, so we get no value :(
  • $\mathtt{mostValue}(2,((2,10),))=10$
  • $\mathtt{mostValue}(2,((2,5),(2,10)))=10$: Either item fits, but not both, and the second one gives us more value
  • $\mathtt{mostValue}(4,((1,6),(2,5),(3,6),(2,10)))=16$: The best combination is to pack the first and fourth items
In [5]:
mostValue(1,((2,10),))
Out[5]:
0
In [6]:
mostValue(2,((2,10),))
Out[6]:
10
In [7]:
mostValue(2,((2,5),(2,10)))
Out[7]:
10
In [8]:
mostValue(4,((1,6),(2,5),(3,6),(2,10)))
Out[8]:
16

So far so good, but what happens if we try a larger example?

In [9]:
n = 23
%timeit -n1 -r1 print("mostValue is .."),mostValue(n*n,[(i,i) for i in range(n)])
mostValue is ..
6.98 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

This is starting to get slow. More generally, let's see how the running time grows as a function of n:

In [11]:
inputs = [((n*n),[(i,i) for i in range(n)]) for n in range(1,24)];
In [12]:
(outputs,curve,lengths,times,steps) = timer(mostValue,inputs,inlen = lambda i: len(i[1]) , doplot_logscale=False);
.......................plot_steps: True
0.438 micro-seconds per step
Curve (steps): $2\exp(1.01n)$

We can see that each time we increase the input length by one, the running time roughly doubles. Extrapolating from these data, if we tried this with 40 items (hardly a large input), the total time it would take to compute the answer would be about

In [13]:
t = curve(40)
"{0:,g} seconds".format(t)
Out[13]:
'1.26973e+12 seconds'

Which equals

In [14]:
"{:,} years".format(int(t/(365*24*60*60)))
Out[14]:
'40,263 years'

Ouch!

But do we really need to spend so much time? For example, when the input length was

In [27]:
lengths[-1]
Out[27]:
23

we made

In [15]:
print("%d calls to the function" % steps[-1])
16777215 calls to the function

But if we think about it, if our original input was a list $L$ of $n$ items, and a weight $W$, every recursive call to the function mostValue was made with a pair $(W',L')$ such that $0 \leq W' \leq W$ and $L'$ is a suffix of $L$ (i.e., the last $i$ elements of $L$ for some $i\leq n$).

Since there are $W+1$ integers from $0$ to $W$ (assuming, as in our examples, that all weights are integers) and at most $n+1$ suffixes of a list of length $n$, the total number of distinct inputs the recursion could ever use is at most $(W+1)\times (n+1)$. In our case this is only

In [16]:
"{:,} possible inputs".format((inputs[-1][0]+1)*(len(inputs[-1][1])+1))
Out[16]:
'12,720 possible inputs'
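The gap between total calls and distinct inputs can be reproduced without boaz_utils. The sketch below is a hypothetical instrumented copy of mostValue (the names `mostValueCounted`, `calls` and `seen` are illustrative) that uses a plain counter in place of step_pc() and records every distinct pair $(W', \text{len}(L'))$. On the inputs used above every item always fits, so each call with a nonempty list makes exactly two recursive calls, giving $2^{n+1}-1$ calls in total. For $n=23$ this is $2^{24}-1 = 16777215$, the count reported above; the sketch uses $n=15$ to stay fast:

```python
from typing import List, Set, Tuple

calls = 0                              # total number of calls made
seen: Set[Tuple[int, int]] = set()     # distinct (W', number of remaining items) pairs


def mostValueCounted(W: int, L: List[Tuple[int, int]]) -> int:
    """Same recursion as mostValue, with a call counter instead of step_pc()."""
    global calls
    calls += 1
    seen.add((W, len(L)))   # L is always a suffix, so its length identifies it
    if not L:
        return 0
    skip = mostValueCounted(W, L[1:])
    if W < L[0][0]:
        return skip
    return max(skip, L[0][1] + mostValueCounted(W - L[0][0], L[1:]))


n = 15
W, L = n * n, [(i, i) for i in range(n)]
print(mostValueCounted(W, L))   # 105: everything fits, since 0 + 1 + ... + 14 = 105 <= 225
print(calls)                    # 65535 = 2**16 - 1
print(len(seen), "distinct inputs, bound", (W + 1) * (n + 1))
```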

So we made a huge number of redundant calls: the same inputs were computed over and over again. If we simply remember the results of these calls instead of recomputing them, we can potentially save a lot of time. This is what we do next. We use a two-dimensional list mem such that mem[W'][i] records the result of the function on capacity $W'$ and the last $i$ elements of $L$, with $-1$ meaning "not computed yet".

In [19]:
def mostValueMemo(W, L, mem):
    """returns the maximum amount of value we can pack fitting in W kilos using only items listed in L,
    where L is a tuple/list containing pairs (w,v) of weight/value for every item.
    mem[W][len(L)] caches the answer for this input (-1 means not computed yet)."""
    step_pc() 
    
    # base case
    if len(L)==0:
        return 0
    
    # check the memory
    elif mem[W][len(L)]!=-1:
        return mem[W][len(L)]
    
    # recursive case
    A = mostValueMemo(W,L[1:],mem) # the amount of value we can have if we skip the first item
    if W<L[0][0]: # if the first item doesn't fit then we have no choices to make
        mem[W][len(L)] = A
        return A
    # otherwise we will check both options and see what's the maximum value we can get
    mem[W][len(L)] =  max(A, L[0][1]+mostValueMemo(W-L[0][0],L[1:],mem))
    return mem[W][len(L)]

def mostValueFast(W,L):
    # one row per capacity 0..W, one column per suffix length 0..len(L); -1 = not computed yet
    mem = [[-1]*(len(L)+1) for _ in range(W+1)]
    return mostValueMemo(W, L, mem)

Let's check that we get the same answers:

In [20]:
mostValueFast(1,((2,10),))
Out[20]:
0
In [21]:
mostValueFast(2,((2,10),))
Out[21]:
10
In [22]:
mostValueFast(2,((2,5),(2,10)))
Out[22]:
10
In [23]:
mostValueFast(4,((1,6),(2,5),(3,6),(2,10)))
Out[23]:
16

Now let's see whether we indeed get a speedup:

In [26]:
inputs = [((n*n),[(i,i) for i in range(n)]) for n in range(1,100)];
In [27]:
(outputs_mem,_1,_2,_3,_4) = timer(mostValueFast,inputs,inlen = lambda i: len(i[1]) ,doplot_logscale = False);
...................................................................................................plot_steps: True
0.791 micro-seconds per step
Curve (steps): $0.4n^3$
In [28]:
outputs == outputs_mem[:len(outputs)] # sanity check
Out[28]:
True

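In particular, for $n=40$ this took us about 35 milliseconds instead of 40,000 years. Not bad! The cubic growth matches our analysis: with $W=n^2$ the table has $(n^2+1)(n+1)$ entries, each of which is computed from scratch at most once, so only on the order of $n^3$ calls can ever do real work.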

Note that we avoided a global lookup table (we typically try to avoid global variables): instead, mostValueFast creates mem and passes it to the recursion as an extra parameter.
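As an alternative to managing the table by hand, Python's standard library can keep the lookup table for us: `functools.lru_cache(maxsize=None)` memoizes a function on its (hashable) arguments. The sketch below (the names `mostValueCached` and `best` are hypothetical) keys the recursion on the remaining capacity and the index of the first item still under consideration, rather than on a slice of the list:

```python
from functools import lru_cache
from typing import Sequence, Tuple


def mostValueCached(W: int, L: Sequence[Tuple[int, int]]) -> int:
    """Memoized knapsack where functools.lru_cache keeps the lookup table.

    Hypothetical variant: the recursion is keyed on (remaining capacity,
    index i of the first item still under consideration).
    """
    items = tuple(L)

    @lru_cache(maxsize=None)          # unbounded cache: every (Wp, i) result is stored
    def best(Wp: int, i: int) -> int:
        if i == len(items):           # no items left
            return 0
        w, v = items[i]
        skip = best(Wp, i + 1)        # value if we skip item i
        if Wp < w:                    # item i does not fit
            return skip
        return max(skip, v + best(Wp - w, i + 1))

    return best(W, 0)


print(mostValueCached(4, [(1, 6), (2, 5), (3, 6), (2, 10)]))  # 16
```

Because the cache lives inside mostValueCached, each top-level call starts with a fresh table, just as mostValueFast creates a fresh mem.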

Example 2: Making change

Here in Ethiopia we have a few different kinds of coins: 1, 5, 10, 25, 50, and 100 cents. Suppose someone buys a piece of gum for 88 cents and gives the shop owner one birr. Then they have to get change: 12 cents. How many ways are there for the shop owner to give this change? There are four ways:

  • 1+1+1+1+1+1+1+1+1+1+1+1
  • 1+1+1+1+1+1+1+5
  • 1+1+5+5
  • 1+1+10

Not every country has the same coin system, though. For example, in the European Union (EU) the coins are worth 1, 2, 5, 10, 20, 50, 100, and 200 cents. This changes the number of ways to give change in some cases. For example, to make change for two cents, there's only one way in Ethiopia: 1+1. But in the EU, there are two ways: 1+1, and also just giving a single 2-cent coin.
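Before counting the ways, it can help to list them explicitly. The sketch below (the name `listWays` is hypothetical) makes one simple decision at a time, in the same spirit as the packing example: either never use the first coin type, or use it at least once and make change for the rest with the same coin types. It reproduces the four Ethiopian ways for 12 cents and the two EU ways for 2 cents:

```python
from typing import List, Sequence


def listWays(L: Sequence[int], n: int) -> List[List[int]]:
    """Return every way to make n cents from coin types L, each as a list of coins.

    Hypothetical helper for illustration; the number of ways can grow very
    quickly, so this is only meant for small amounts.
    """
    if not L:
        return [[]] if n == 0 else []          # no coin types left: only 0 cents is possible
    ways = listWays(L[1:], n)                  # ways that never use coin L[0]
    if n >= L[0]:
        # ways that use coin L[0] at least once
        ways += [[L[0]] + rest for rest in listWays(L, n - L[0])]
    return ways


print(len(listWays([1, 5, 10, 25, 50, 100], 12)))        # 4
print(listWays([1, 2, 5, 10, 20, 50, 100, 200], 2))      # [[2], [1, 1]]
```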

Our task now is to make a function countWays so that countWays(L, n) is the number of ways to make change for n cents using the list of coins L. The first example above corresponds to countWays([1,5,10,25,50,100], 12), which should return 4.

In [39]:
def countWays(L, n):
    # base case: there are no coins to use
    if L==[]:
        if n == 0:
            return 1
        else:
            # if there are no coins you can use, it's impossible to make change
            return 0
    
    # recursive case
    else:
        # one option we have is not using coin L[0] at all
        result = countWays(L[1:], n)
        if n>=L[0]:
            # if n>=L[0], another option is using coin L[0] at least once
            result += countWays(L, n-L[0])
        return result
In [40]:
countWays([1,5,10,25,50,100],12)
Out[40]:
4

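Unfortunately this code can be quite slow. The answer is built up by adding together base-case calls that return 1, one for each way of making change, so the number of recursive calls is at least the answer itself, which can be huge: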

In [41]:
countWays([1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],10)
Out[41]:
8436285
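A self-contained way to check this claim is to count the calls. The sketch below is a hypothetical instrumented copy of countWays (the names `countWaysCounted` and `change_calls` are illustrative) run on a smaller instance: ten 1-cent coin types and 5 cents, which has $\binom{14}{5} = 2002$ ways:

```python
from typing import Sequence

change_calls = 0  # hypothetical counter of how many calls are made


def countWaysCounted(L: Sequence[int], n: int) -> int:
    """The countWays recursion, with a counter of how many calls are made."""
    global change_calls
    change_calls += 1
    if len(L) == 0:
        return 1 if n == 0 else 0
    result = countWaysCounted(L[1:], n)          # never use coin L[0]
    if n >= L[0]:
        result += countWaysCounted(L, n - L[0])  # use coin L[0] at least once
    return result


ways = countWaysCounted([1] * 10, 5)
print(ways)                   # 2002
print(change_calls >= ways)   # True: every counted way needs its own base-case call
```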

We can speed up the computation by using memoization.

In [42]:
# how many ways are there to make change for val cents using only the coins in L?
# mem[len(L)][val] caches the answer (-1 means not computed yet); this works because
# L is always a suffix of the original list, so its length identifies it
def countWaysMemo(L, val, mem):
    # base case: there are no coins to use
    if len(L) == 0:
        if val == 0:
            return 1
        else:
            # if there are no coins you can use, it's impossible to make change
            return 0
    elif mem[len(L)][val]!=-1:
        return mem[len(L)][val]
    else:
        # one option we have is not using coin L[0] at all
        mem[len(L)][val] = countWaysMemo(L[1:], val, mem)
        if val>=L[0]:
            # if val>=L[0], another option is using coin L[0] at least once
            mem[len(L)][val] += countWaysMemo(L, val-L[0], mem)
        return mem[len(L)][val]

def countWaysFast(L, n):
    # one row per suffix length 0..len(L), one column per amount 0..n; -1 = not computed yet
    mem = [[-1]*(n+1) for _ in range(len(L)+1)]
    return countWaysMemo(L, n, mem)
In [43]:
countWaysFast([1,5,10,25,50,100],12)
Out[43]:
4
In [44]:
countWaysFast([1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],10)
Out[44]:
92561040

How fast is the memoized version?

Let $m$ denote the length of the original list $L$. As we go through the recursion, the length of $L$ is always between $0$ and $m$, and $val$ is always between $0$ and $n$. Thus there are at most $(m+1)(n+1)$, or about $nm$, possible inputs to the memoized function, and each one is computed from scratch at most once. Apart from its (at most two) recursive calls, each such computation does $O(m)$ work, since creating the list L[1:] takes $O(m)$ steps. So in total the running time is $O(nm^2)$.

Bonus: think about how to make the running time $O(nm)$ instead of $O(nm^2)$. Hint: avoid creating the list L[1:]. What if we instead started with a recursive function countWays(i, L, n) that computes the number of ways to make $n$ cents using only the coins in L[i:], and then memoized that?
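One possible answer to the bonus, sketched under the assumptions of the hint: the hypothetical names `countWaysIndexMemo` and `countWaysIndexFast` follow the suggested countWays(i, L, n) signature, plus a mem parameter in the style of countWaysMemo. Since the list is never copied, each call does $O(1)$ work apart from its recursive calls, and there are at most $(m+1)(n+1)$ table entries:

```python
from typing import List, Sequence


def countWaysIndexMemo(i: int, L: Sequence[int], n: int, mem: List[List[int]]) -> int:
    """Number of ways to make n cents using only the coins in L[i:].

    Hypothetical sketch; mem[i][n] caches the answer, -1 meaning "not computed yet".
    """
    if i == len(L):                                        # no coins left to use
        return 1 if n == 0 else 0
    if mem[i][n] != -1:
        return mem[i][n]
    result = countWaysIndexMemo(i + 1, L, n, mem)          # never use coin L[i]
    if n >= L[i]:
        result += countWaysIndexMemo(i, L, n - L[i], mem)  # use coin L[i] at least once
    mem[i][n] = result
    return result


def countWaysIndexFast(L: Sequence[int], n: int) -> int:
    # one row per starting index 0..len(L), one column per amount 0..n
    mem = [[-1] * (n + 1) for _ in range(len(L) + 1)]
    return countWaysIndexMemo(0, L, n, mem)


print(countWaysIndexFast([1, 5, 10, 25, 50, 100], 12))  # 4
```

One caveat of this recursive formulation: the recursion depth can reach about $m + n/\min(L)$, so for large amounts it may exceed Python's default recursion limit.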