numpy Testing

Miki Tebeka .:. 353solutions .:. Highly effective Python, Scientific Python and Go workshops

We'll explore a few pitfalls you can hit when testing numpy code.

TL;DR

Use np.allclose when comparing numpy arrays of floats. Beware of nan, which never compares equal to itself.

In [1]:
import numpy as np

The Naive Approach

In [2]:
def test_mul():
    arr = np.array([0.0, 1.0, 1.1])
    v, expected = 1.1, np.array([0.0, 1.1, 1.21])
    assert arr * v == expected, 'bad multiplication'
    
test_mul()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-b3638dd4eeba> in <module>()
      4     assert arr * v == expected, 'bad multiplication'
      5 
----> 6 test_mul()

<ipython-input-2-b3638dd4eeba> in test_mul()
      2     arr = np.array([0.0, 1.0, 1.1])
      3     v, expected = 1.1, np.array([0.0, 1.1, 1.21])
----> 4     assert arr * v == expected, 'bad multiplication'
      5 
      6 test_mul()

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

This happens because comparing two numpy arrays with == returns an array of boolean values, one per pair of elements.

In [3]:
np.array([1,2,3]) == np.array([1, 1, 3])
Out[3]:
array([ True, False,  True], dtype=bool)

And the truth value of an array (as the error says) is ambiguous.

In [4]:
bool(np.array([1, 2, 3]))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-09e830fdd7d0> in <module>()
----> 1 bool(np.array([1, 2, 3]))

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

We need to use np.all to check that all elements are equal.

In [5]:
np.all([True, True, True])
Out[5]:
True
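
As a side note, the same reduction can be written as a method call on the boolean array, and np.array_equal combines the comparison and the reduction in one call. The arrays below are made up for illustration; this is a small sketch, not the only way to write the check.

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([1, 1, 3])

# Element-wise comparison, then reduce to a single boolean
all_equal = bool(np.all(a == b))
same_as_method = bool((a == b).all())

# np.array_equal checks shape and exact element equality in one call
exact_copy_equal = bool(np.array_equal(a, a.copy()))
```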

Using np.all

In [6]:
def test_mul():
    arr = np.array([0.0, 1.0, 1.1])
    v, expected = 1.1, np.array([0.0, 1.1, 1.21])
    assert np.all(arr * v == expected), 'bad multiplication'
    
test_mul()
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-6-2d3997bf9b21> in <module>()
      4     assert np.all(arr * v == expected), 'bad multiplication'
      5 
----> 6 test_mul()

<ipython-input-6-2d3997bf9b21> in test_mul()
      2     arr = np.array([0.0, 1.0, 1.1])
      3     v, expected = 1.1, np.array([0.0, 1.1, 1.21])
----> 4     assert np.all(arr * v == expected), 'bad multiplication'
      5 
      6 test_mul()

AssertionError: bad multiplication

This happens because floating point arithmetic is not exact.

In [7]:
1.1 * 1.1
Out[7]:
1.2100000000000002

This is not a bug in Python; it's how floating point numbers are implemented. You'll get the same result in C, Java, Go ... To overcome this we're going to use np.allclose.
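
For plain Python floats (no arrays), the same idea of "close enough" can be sketched with math.isclose from the standard library. The variable names here are illustrative only.

```python
import math

product = 1.1 * 1.1  # 1.2100000000000002, as shown above

# Exact comparison fails because of the rounding error
exactly_equal = product == 1.21

# Comparison within a relative tolerance (default rel_tol=1e-09) succeeds
close_enough = math.isclose(product, 1.21)
```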

BTW: If you're really interested in floating point, read this article.

Using np.allclose

In [8]:
def test_mul():
    arr = np.array([0.0, 1.0, 1.1])
    v, expected = 1.1, np.array([0.0, 1.1, 1.21])
    assert np.allclose(arr * v, expected), 'bad multiplication'
    
test_mul()
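
To see why this passes, it may help to spell out the check by hand. According to the numpy documentation, np.allclose treats a and b as close when |a - b| <= atol + rtol * |b|, with defaults rtol=1e-05 and atol=1e-08. The sketch below recomputes that by hand and compares it with np.isclose, the element-wise version; the variable names are illustrative.

```python
import numpy as np

result = np.array([0.0, 1.0, 1.1]) * 1.1
expected = np.array([0.0, 1.1, 1.21])

rtol, atol = 1e-05, 1e-08  # documented defaults of np.allclose

# Manual version of the documented tolerance check
manual = np.abs(result - expected) <= atol + rtol * np.abs(expected)

# np.isclose gives the element-wise answer; np.allclose reduces it with all()
elementwise = np.isclose(result, expected)

# With both tolerances set to zero the check is exact again and fails
exact_check = np.allclose(result, expected, rtol=0, atol=0)
```

Note that the tolerance is relative to the second argument, so the check is not symmetric in general. If the defaults are too loose for your data, pass rtol and atol explicitly.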

Oh nan, Let Me Count the Ways ...

In [9]:
def test_div():
    arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])
    expected = np.array([0.5, np.nan, 1.0])
    assert np.allclose(arr1 / arr2, expected), 'bad nan'
    
test_div()
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-9-48ea1c630b01> in <module>()
      4     assert np.allclose(arr1 / arr2, expected), 'bad nan'
      5 
----> 6 test_div()

<ipython-input-9-48ea1c630b01> in test_div()
      2     arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])
      3     expected = np.array([0.5, np.nan, 1.0])
----> 4     assert np.allclose(arr1 / arr2, expected), 'bad nan'
      5 
      6 test_div()

AssertionError: bad nan

This happens because nan does not equal itself.

In [10]:
np.nan == np.nan
Out[10]:
False

To check if a number is nan we need to use np.isnan.

In [11]:
np.isnan(np.inf/np.inf)
Out[11]:
True
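
np.isnan also works element-wise on arrays, which is what makes it useful as a mask. A small sketch using the arrays from test_div; the np.errstate context manager is only there to silence the RuntimeWarning that numpy emits for inf / inf.

```python
import numpy as np

arr1 = np.array([1.0, np.inf, 2.0])
arr2 = np.array([2.0, np.inf, 2.0])

with np.errstate(invalid='ignore'):  # inf / inf is an "invalid" operation
    result = arr1 / arr2

# Boolean mask marking the nan positions
mask = np.isnan(result)

# Comparing the array with itself is False exactly where there is a nan
self_equal = result == result
```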

We have two options to solve this:

  1. Convert all nan to numbers
  2. Use the equal_nan argument to np.allclose

Option 1: Convert nan to Numbers

In [12]:
def test_div():
    arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])
    expected = np.array([0.5, np.nan, 1.0])
    result = arr1 / arr2
    
    result[np.isnan(result)] = 0.0
    expected[np.isnan(expected)] = 0.0
    assert np.allclose(result, expected), 'bad nan'
    
test_div()
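
Two things to keep in mind with this option: the in-place assignment modifies the arrays, and once nan becomes 0.0, a genuine 0.0 in the result will match an expected nan. The sketch below uses a hypothetical helper, nan_to_zero (not part of numpy), that returns a copy instead of mutating, and shows the second caveat with a made-up "wrong" result.

```python
import numpy as np

def nan_to_zero(arr):
    """Hypothetical helper: return a copy of arr with every nan replaced by 0.0."""
    return np.where(np.isnan(arr), 0.0, arr)

expected = np.array([0.5, np.nan, 1.0])
wrong = np.array([0.5, 0.0, 1.0])  # made-up result: 0.0 where nan was expected

cleaned = nan_to_zero(expected)  # expected itself is left untouched

# The comparison passes even though wrong has no nan at index 1
hides_mismatch = bool(np.allclose(nan_to_zero(wrong), cleaned))
```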

Option 2: Use equal_nan in np.allclose

In [13]:
def test_div():
    arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])
    expected = np.array([0.5, np.nan, 1.0])
    assert np.allclose(arr1 / arr2, expected, equal_nan=True), 'bad nan'
    
test_div()
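
As a possible alternative to a bare assert, numpy ships np.testing.assert_allclose, which raises an AssertionError with a report of the mismatching elements. Its defaults differ from np.allclose (rtol=1e-07, atol=0, and equal_nan=True), so treat the following as a sketch rather than a drop-in replacement. The np.errstate block only silences the inf / inf warning.

```python
import numpy as np

def test_div():
    arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])
    expected = np.array([0.5, np.nan, 1.0])
    with np.errstate(invalid='ignore'):  # inf / inf emits a RuntimeWarning
        result = arr1 / arr2
    # Raises AssertionError with a detailed message when the arrays differ
    np.testing.assert_allclose(result, expected, equal_nan=True, err_msg='bad nan')

test_div()
```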