We'll explore certain caveats while testing numpy code.
Use np.allclose when comparing numpy arrays. Beware of nan.
import numpy as np
def test_mul():
    arr = np.array([0.0, 1.0, 1.1])
    v, expected = 1.1, np.array([0.0, 1.1, 1.21])
    assert arr * v == expected, 'bad multiplication'

test_mul()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-b3638dd4eeba> in <module>()
      4     assert arr * v == expected, 'bad multiplication'
      5
----> 6 test_mul()

<ipython-input-2-b3638dd4eeba> in test_mul()
      2     arr = np.array([0.0, 1.0, 1.1])
      3     v, expected = 1.1, np.array([0.0, 1.1, 1.21])
----> 4     assert arr * v == expected, 'bad multiplication'
      5
      6 test_mul()

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
This happens because comparing two numpy arrays with == does not return a single True or False. It returns an array of boolean values, one for each pair of elements.
np.array([1,2,3]) == np.array([1, 1, 3])
array([ True, False, True], dtype=bool)
And the truth value of an array (as the error says) is ambiguous.
bool(np.array([1, 2, 3]))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-09e830fdd7d0> in <module>()
----> 1 bool(np.array([1, 2, 3]))

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
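The error message points at two ways to collapse a boolean array into a single value. Here is a minimal sketch of what each one returns for the comparison shown earlier (the variable names are only for illustration):

```python
import numpy as np

# Element-wise comparison gives one boolean per position
mask = np.array([1, 2, 3]) == np.array([1, 1, 3])

# any(): True if at least one pair of elements matches
any_equal = bool(mask.any())

# all(): True only if every pair of elements matches
all_equal = bool(mask.all())

print(any_equal, all_equal)
```

For an equality test, all() is the one that matches the intent, since a single differing element should fail the test.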
We need to use np.all to check that all elements are equal.
np.all([True, True, True])
True
def test_mul():
    arr = np.array([0.0, 1.0, 1.1])
    v, expected = 1.1, np.array([0.0, 1.1, 1.21])
    assert np.all(arr * v == expected), 'bad multiplication'

test_mul()
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-6-2d3997bf9b21> in <module>()
      4     assert np.all(arr * v == expected), 'bad multiplication'
      5
----> 6 test_mul()

<ipython-input-6-2d3997bf9b21> in test_mul()
      2     arr = np.array([0.0, 1.0, 1.1])
      3     v, expected = 1.1, np.array([0.0, 1.1, 1.21])
----> 4     assert np.all(arr * v == expected), 'bad multiplication'
      5
      6 test_mul()

AssertionError: bad multiplication
This happens because floating-point arithmetic is not exact.
1.1 * 1.1
1.2100000000000002
This is not a bug in Python but a consequence of how floating-point numbers are implemented. You'll get the same result in C, Java, Go ... To overcome this we're going to use np.allclose.
BTW: If you're really interested in floating point, read this article.
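To get a feel for how small the error is, here is a minimal sketch using only the standard library. math.isclose is used purely as an illustration of comparing with a tolerance; the variable names are made up for this example.

```python
import math

product = 1.1 * 1.1

# Exact equality fails because of a rounding error in the last bits
exactly_equal = product == 1.21

# The actual difference is tiny compared to the values themselves
difference = abs(product - 1.21)

# Comparing with a tolerance treats the two values as equal
close_enough = math.isclose(product, 1.21)

print(exactly_equal, difference, close_enough)
```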
def test_mul():
    arr = np.array([0.0, 1.0, 1.1])
    v, expected = 1.1, np.array([0.0, 1.1, 1.21])
    assert np.allclose(arr * v, expected), 'bad multiplication'

test_mul()
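np.allclose treats two values a and b as close when |a - b| <= atol + rtol * |b|, with defaults rtol=1e-05 and atol=1e-08. The sketch below spells out that check by hand. manual_allclose is a hypothetical helper written for this example, and it ignores the special handling of nan and inf that the real function has.

```python
import numpy as np

def manual_allclose(a, b, rtol=1e-05, atol=1e-08):
    """Hand-written version of the tolerance check, for illustration only."""
    return bool(np.all(np.abs(a - b) <= atol + rtol * np.abs(b)))

result = np.array([0.0, 1.0, 1.1]) * 1.1
expected = np.array([0.0, 1.1, 1.21])

# The hand-written check agrees with np.allclose on this data
same_as_numpy = manual_allclose(result, expected) == bool(np.allclose(result, expected))

# The default tolerance accepts a relative error of 1e-06 ...
loose = bool(np.allclose(np.array([1.000001]), np.array([1.0])))

# ... while a tighter, explicitly chosen tolerance rejects it
strict = bool(np.allclose(np.array([1.000001]), np.array([1.0]), rtol=1e-09, atol=0.0))
```

If the defaults are too forgiving for the quantities under test, passing rtol and atol explicitly is probably the safer choice.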
def test_div():
    arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])
    expected = np.array([0.5, np.nan, 1.0])
    assert np.allclose(arr1 / arr2, expected), 'bad nan'

test_div()
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-9-48ea1c630b01> in <module>()
      4     assert np.allclose(arr1 / arr2, expected), 'bad nan'
      5
----> 6 test_div()

<ipython-input-9-48ea1c630b01> in test_div()
      2     arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])
      3     expected = np.array([0.5, np.nan, 1.0])
----> 4     assert np.allclose(arr1 / arr2, expected), 'bad nan'
      5
      6 test_div()

AssertionError: bad nan
The division itself produces the nan we expect (inf / inf is nan). The test fails because nan does not equal itself.
np.nan == np.nan
False
To check whether a number is nan, we need to use np.isnan.
np.isnan(np.inf/np.inf)
True
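np.isnan also works element by element on arrays, which is what makes it useful in tests. A small sketch with made-up values:

```python
import numpy as np

values = np.array([0.5, np.nan, 1.0])

# == never matches nan, not even against nan itself
eq_mask = values == np.nan

# np.isnan flags nan entries element by element
nan_mask = np.isnan(values)

print(eq_mask.tolist(), nan_mask.tolist())
```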
We have two options to solve this:
1. Convert nan to numbers
2. Use the equal_nan argument to np.allclose

Convert nan to Numbers

def test_div():
    arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])
    expected = np.array([0.5, np.nan, 1.0])
    result = arr1 / arr2
    # Replace nan with 0.0 in both arrays so the nan entries compare as equal
    result[np.isnan(result)] = 0.0
    expected[np.isnan(expected)] = 0.0
    assert np.allclose(result, expected), 'bad nan'

test_div()
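One thing to watch for with this approach: once nan is replaced by 0.0, a nan in the result can no longer be told apart from a genuine 0.0 in the expected values. Below is a hedged sketch of a stricter variant, assuming you want the nan positions to match as well and that both arrays have the same shape. The helper name allclose_with_nan is made up for this example.

```python
import numpy as np

def allclose_with_nan(result, expected):
    """Compare nan positions first, then the remaining values (illustrative helper)."""
    result_nan, expected_nan = np.isnan(result), np.isnan(expected)
    if not np.array_equal(result_nan, expected_nan):
        return False
    return bool(np.allclose(result[~result_nan], expected[~expected_nan]))

result = np.array([0.5, np.nan, 1.0])

# nan in the same position: passes
matching = allclose_with_nan(result, np.array([0.5, np.nan, 1.0]))

# nan where 0.0 was expected: fails, although replacing nan with 0.0 would pass
mismatched = allclose_with_nan(result, np.array([0.5, 0.0, 1.0]))
```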
Use equal_nan in np.allclose

def test_div():
    arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])
    expected = np.array([0.5, np.nan, 1.0])
    # equal_nan=True treats nan entries in the same position as equal
    assert np.allclose(arr1 / arr2, expected, equal_nan=True), 'bad nan'

test_div()
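When a test like this fails, np.allclose only reports a single False. Here is a small sketch, using made-up data, of how np.isclose (the element-wise counterpart, which accepts the same equal_nan argument) can show which positions disagree:

```python
import numpy as np

result = np.array([0.5, np.nan, 1.0])
expected = np.array([0.5, np.nan, 2.0])

# Element-wise version of np.allclose: one boolean per position
mask = np.isclose(result, expected, equal_nan=True)

# Indices where the two arrays disagree
bad_indices = np.flatnonzero(~mask).tolist()

print(mask.tolist(), bad_indices)
```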