In [1]:
import numpy as np

断言函数

单元测试通常使用断言函数作为测试的组成部分。在进行数值计算时,我们经常遇到比较两个近似相等的浮点数这样的基本问题,由于计算机对浮点数的表示本身就不精确,所以浮点数的比较并不是那么简单。

numpy.testing包中有很多实用的工具函数考虑了浮点数的比较,可以测试前提是否成立。

1. assert_almost_equal断言近似相等

assert_almost_equal函数的作用是,如果两个数字的近似程度没有达到指定精度,就抛出异常。

In [2]:
#使用assert_almost_equal函数来检查它们是否近似相等

#指定精度,小数点后7位
print "Decimal 6", np.testing.assert_almost_equal(0.123456789, 0.123456780, decimal=7)
Decimal 6 None
In [3]:
#指定精度,小数点后8位
#抛出异常
print "Decimal 7", np.testing.assert_almost_equal(0.123456789, 0.123456780, decimal=8)
Decimal 7
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-3-4f17955884db> in <module>()
      1 #指定精度,小数点后8位
      2 #抛出异常
----> 3 print "Decimal 7", np.testing.assert_almost_equal(0.123456789, 0.123456780, decimal=8)

d:\Python27_32\lib\site-packages\numpy\testing\utils.pyc in assert_almost_equal(actual, desired, decimal, err_msg, verbose)
    488         pass
    489     if round(abs(desired - actual), decimal) != 0 :
--> 490         raise AssertionError(_build_err_msg())
    491 
    492 

AssertionError: 
Arrays are not almost equal to 8 decimals
 ACTUAL: 0.123456789
 DESIRED: 0.12345678

2. assert_approx_equal断言近似相等

如果两个数字的近似程度没有达到指定的有效数字要求,assert_approx_equal函数就抛出异常。

触发条件为:abs(actual-expected) >= 10**(significant-1)

In [4]:
# 指定8位有效数字
print "Significance 8", np.testing.assert_approx_equal(0.123456789,0.123456780, significant=8)
 Significance 8 None
In [5]:
# 指定9位有效数字
# 抛出异常
print "Significance 9", np.testing.assert_approx_equal(0.123456789,0.123456780, significant=9)
Significance 9
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-5-ad87d6f32eb2> in <module>()
      1 # 指定9位有效数字
      2 # 抛出异常
----> 3 print "Significance 9", np.testing.assert_approx_equal(0.123456789,0.123456780, significant=9)

d:\Python27_32\lib\site-packages\numpy\testing\utils.pyc in assert_approx_equal(actual, desired, significant, err_msg, verbose)
    585         pass
    586     if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)) :
--> 587         raise AssertionError(msg)
    588 
    589 def assert_array_compare(comparison, x, y, err_msg='', verbose=True,

AssertionError: 
Items are not equal to 9 significant digits:
 ACTUAL: 0.123456789
 DESIRED: 0.12345678

3. assert_array_almost_equal断言数组近似相等

如果两个数组中元素的近似程度没有达到指定的精度要求,assert_array_almost_equal函数将抛出异常。

触发条件为: |expected - actual| < 0.5 * 10^(-decimal)

In [6]:
print "Decimal 8", np.testing.assert_array_almost_equal([0, 0.123456789], [0, 0.123456780], decimal=8)
 Decimal 8 None
In [7]:
print "Decimal 9", np.testing.assert_array_almost_equal([0, 0.123456789], [0, 0.123456780], decimal=9)
Decimal 9
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-7-a30317f83231> in <module>()
----> 1 print "Decimal 9", np.testing.assert_array_almost_equal([0, 0.123456789], [0, 0.123456780], decimal=9)

d:\Python27_32\lib\site-packages\numpy\testing\utils.pyc in assert_array_almost_equal(x, y, decimal, err_msg, verbose)
    840     assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
    841              header=('Arrays are not almost equal to %d decimals' % decimal),
--> 842              precision=decimal)
    843 
    844 

d:\Python27_32\lib\site-packages\numpy\testing\utils.pyc in assert_array_compare(comparison, x, y, err_msg, verbose, header, precision)
    663                                 names=('x', 'y'), precision=precision)
    664             if not cond :
--> 665                 raise AssertionError(msg)
    666     except ValueError as e:
    667         import traceback

AssertionError: 
Arrays are not almost equal to 9 decimals

(mismatch 50.0%)
 x: array([ 0.         ,  0.123456789])
 y: array([ 0.        ,  0.12345678])

4. assert_array_equal断言数组相等

如果两个数组对象不相同,assert_array_equal函数将抛出异常。两个数组相等必须形状一致且元素也严格相等,允许数组存在NaN元素。

比较数组也可以使用assert_allclose函数。该函数有参数atol(absolute tolerance,绝对容差限)和rtol(relative tolerance,相对容差限)。对于两个数组a和b,将测试是否满足以下条件:| a - b | <= (atol+rtol*b)

In [8]:
print "Pass", np.testing.assert_allclose([0,0.123456789,np.nan],[0,0.123456780,np.nan], rtol=1e-7,atol=0)
 Pass None
In [9]:
print "Fail", np.testing.assert_array_equal([0,0.123456789,np.nan],[0,0.123456780,np.nan])
Fail
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-9-c3f815d83f4e> in <module>()
----> 1 print "Fail", np.testing.assert_array_equal([0,0.123456789,np.nan],[0,0.123456780,np.nan])

d:\Python27_32\lib\site-packages\numpy\testing\utils.pyc in assert_array_equal(x, y, err_msg, verbose)
    737     """
    738     assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
--> 739                          verbose=verbose, header='Arrays are not equal')
    740 
    741 def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):

d:\Python27_32\lib\site-packages\numpy\testing\utils.pyc in assert_array_compare(comparison, x, y, err_msg, verbose, header, precision)
    663                                 names=('x', 'y'), precision=precision)
    664             if not cond :
--> 665                 raise AssertionError(msg)
    666     except ValueError as e:
    667         import traceback

AssertionError: 
Arrays are not equal

(mismatch 50.0%)
 x: array([ 0.      ,  0.123457,       nan])
 y: array([ 0.      ,  0.123457,       nan])

5. 数组排序

两个数组必须形状一致并且第一个数组元素严格小于第二个数组元素,否则assert_array_less函数将抛出异常。

In [10]:
# assert_array_less函数比较两个有严格顺序的数组
print "Pass", np.testing.assert_array_less([0,0.1,np.nan], [1,0.2,np.nan])
 Pass None

6. assert_equal比较对象

这里的对象不一定是NumPy数组对象,也可以是Python中的列表、元组或字典。

7. assert_string_equal字符串比较

assert_string_equal函数断言两个字符串变量完全相同。如果测试不通过,将会抛出异常并显示两个字符串之间的差别。该函数区分大小写,

In [11]:
print "Fail", np.testing.assert_string_equal("NumPy", "Numpy")
Fail
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-11-c4c5684c99d0> in <module>()
----> 1 print "Fail", np.testing.assert_string_equal("NumPy", "Numpy")

d:\Python27_32\lib\site-packages\numpy\testing\utils.pyc in assert_string_equal(actual, desired)
    981     msg = 'Differences in strings:\n%s' % (''.join(diff_list)).rstrip()
    982     if actual != desired :
--> 983         raise AssertionError(msg)
    984 
    985 

AssertionError: Differences in strings:
- NumPy?    ^
+ Numpy?    ^

8. 浮点数比较

浮点数在计算机中是以不精确的方式表示的,这给浮点数的比较带来了问题。NumPy中的assert_array_almost_equal_nulp和assert_array_max_ulp函数可以提供可靠的浮点数比较功能。ULP是Unit of Least Precision的缩写,即浮点数的最小精度单位。根据IEEE 754标准,四则运算的误差必须保持在半个ULP之内。

机器精度(machine epsilon)是指浮点运算中的相对舍入误差上界。机器精度等于ULP相对于1的值。NumPy的finfo函数可以获取机器精度。

In [12]:
# 使用finfo函数确定机器精度
eps = np.finfo(float).eps
print "Eps", eps
 Eps 2.22044604925e-16

使用assert_array_almost_equal_nulp函数比较两个近似相等的浮点数1.0和1.0+eps

In [13]:
print "one eps", np.testing.assert_array_almost_equal_nulp(1.0, 1.0+eps)
one eps None
In [14]:
print "two eps", np.testing.assert_array_almost_equal_nulp(1.0, 1.0+eps*2)
two eps
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-14-c3de8768b01e> in <module>()
----> 1 print "two eps", np.testing.assert_array_almost_equal_nulp(1.0, 1.0+eps*2)

d:\Python27_32\lib\site-packages\numpy\testing\utils.pyc in assert_array_almost_equal_nulp(x, y, nulp)
   1356             max_nulp = np.max(nulp_diff(x, y))
   1357             msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp)
-> 1358         raise AssertionError(msg)
   1359 
   1360 def assert_array_max_ulp(a, b, maxulp=1, dtype=None):

AssertionError: X and Y are not equal to 1 ULP (max is 2)

多ULP浮点数的比较

assert_array_max_ulp函数可以指定ULP的数量作为允许的误差上界。参数maxulp接受整数作为ULP数量的上限,默认值为1。

In [15]:
print "one eps", np.testing.assert_array_max_ulp(1.0, 1.0+eps)
 one eps 1.0
In [17]:
print "two eps", np.testing.assert_array_max_ulp(1.0, 1.0+eps*2,maxulp=2)
two eps 2.0