Author: Stephan Hoyer (shoyer@google.com)
Date: June 10, 2018
See the NEP for full context and details.
## `__array_function__` machinery

Our goals here include keeping the dispatch overhead low, even for functions like `np.concatenate`, which could involve thousands or tens of thousands of arguments to check.

Note that for maximum performance, we will probably write the actual implementations of these functions (at least `get_overloaded_types_and_args` and `try_array_function_override`) in C.
import functools
import inspect
import numpy as np
import six
import sys
class ndarray(np.ndarray):
    """Prototype replacement for numpy.ndarray that speaks __array_function__."""

    def __array_function__(self, func, types, args, kwargs):
        # Defer to others if any participating type brings its own
        # __array_function__ rather than this class's implementation.
        has_foreign_override = any(
            hasattr(t, '__array_function__')
            and t.__array_function__ is not ndarray.__array_function__
            for t in types
        )
        if has_foreign_override:
            return NotImplemented
        # Only this implementation is involved, so calling the dispatched
        # function again cannot recurse into a foreign override.
        return func(*args, **kwargs)
def get_overloaded_types_and_args(relevant_args):
    """Find the argument types and arguments relevant for __array_function__.

    Args:
        relevant_args: iterable of candidate arguments, scanned once, left to
            right.

    Returns:
        A pair ``(overloaded_types, overloaded_args)``:

        * ``overloaded_types``: each distinct type defining
          ``__array_function__``, in order of first appearance (``ndarray``
          itself included).
        * ``overloaded_args``: the first argument of each such type whose
          ``__array_function__`` is not ``ndarray.__array_function__``,
          ordered subclasses-before-superclasses and otherwise left to right.
          __array_function__ implementations should be called in this order.
    """
    # Runtime is O(num_arguments * num_unique_types): the seen-type check
    # below is a linear scan of a list.
    seen_types = []
    dispatch_order = []
    for candidate in relevant_args:
        candidate_type = type(candidate)
        if candidate_type in seen_types:
            continue
        try:
            override = candidate_type.__array_function__
        except AttributeError:
            # This type does not take part in the protocol at all.
            continue
        seen_types.append(candidate_type)
        if override is ndarray.__array_function__:
            # ndarray's own implementation is recorded as a type only.
            continue
        # Insert ahead of the first earlier argument whose type this one
        # subclasses; otherwise append at the end.
        slot = next(
            (pos for pos, earlier in enumerate(dispatch_order)
             if issubclass(candidate_type, type(earlier))),
            len(dispatch_order),
        )
        dispatch_order.insert(slot, candidate)
    return seen_types, dispatch_order
def full_name(obj):
    """Return ``obj``'s dotted name: its ``__module__`` joined to ``__qualname__``."""
    return '{}.{}'.format(obj.__module__, obj.__qualname__)
def attempt_augmented_error_message(error, append_message):
    """Best-effort rebuild of ``error`` with ``append_message`` added to its first arg.

    Returns a new exception of the same type, or ``error`` itself when it has
    no args, its first arg cannot be concatenated with the message, or its
    constructor rejects the rebuilt arguments.
    """
    try:
        leading, *trailing = error.args
        return type(error)(leading + append_message, *trailing)
    except Exception:
        # Deliberately best-effort: fall back to the unmodified error.
        return error
def try_array_function_override(func, relevant_arguments, args, kwargs):
    """Try to hand ``func(*args, **kwargs)`` to an ``__array_function__`` override.

    Returns:
        ``(False, None)`` when no argument has a non-default override, or
        ``(True, result)`` for the first override not returning NotImplemented.

    Raises:
        TypeError: every override returned NotImplemented.
        Any exception an override raises, re-raised with a suffix naming the
        overriding type and ``func`` when the error can be rebuilt.
    """
    # TODO: consider simplifying the interface, to only require either `types`
    # (by calling __array_function__ a classmethod) or `overloaded_args` (by
    # dropping `types` from the signature of __array_function__)
    types, overloaded_args = get_overloaded_types_and_args(relevant_arguments)
    if not overloaded_args:
        return False, None

    # __array_function__ is only called on the *first* occurrence of each
    # argument type. This keeps performance reasonable for long argument
    # lists, where each implementation may need to check all argument types.
    for candidate in overloaded_args:
        try:
            outcome = candidate.__array_function__(func, types, args, kwargs)
        except Exception as exc:
            # Make the overriding type visible in the error message/traceback.
            suffix = (" [while calling {!r} implementation of {!r}]"
                      .format(full_name(type(candidate)), full_name(func)))
            augmented = attempt_augmented_error_message(exc, suffix)
            # Python 3 only syntax; Python 2 would need a helper such as six:
            # https://stackoverflow.com/questions/9157210/
            raise augmented.with_traceback(exc.__traceback__) from None
        if outcome is NotImplemented:
            continue
        return True, outcome

    offending_types = [type(arg) for arg in overloaded_args]
    raise TypeError('no implementation found for {} on types that implement '
                    '__array_function__: {}'
                    .format(func, offending_types))
def array_function_dispatch(dispatcher):
    """Decorator factory adding __array_function__ dispatch to a function.

    ``dispatcher`` takes the same arguments as the decorated function and
    returns the arguments to check for overrides. The wrapper itself (not the
    undecorated function) is what overrides receive as ``func``.
    """
    def decorator(implementation):
        @functools.wraps(implementation)
        def new_func(*args, **kwargs):
            relevant_arguments = dispatcher(*args, **kwargs)
            handled, result = try_array_function_override(
                new_func, relevant_arguments, args, kwargs)
            return result if handled else implementation(*args, **kwargs)
        return new_func
    return decorator
### Tests for `get_overloaded_types_and_args`
def return_self(self, *args, **kwargs):
    """Test helper used as an __array_function__: ignore inputs, return ``self``."""
    return self
def return_not_implemented(self, *_args, **_kwargs):
    """Test helper used as an __array_function__ that always defers."""
    result = NotImplemented
    return result
class A:
    """Base test type whose __array_function__ returns the instance."""
    __array_function__ = return_self


class B(A):
    """Subclass of A (re-declares the same override)."""
    __array_function__ = return_self


class C(A):
    """Second subclass of A, a sibling of B."""
    __array_function__ = return_self


class D:
    """Test type unrelated to A."""
    __array_function__ = return_self


a, b, c, d = A(), B(), C(), D()
def get_overloaded_args(relevant_args):
    """Return just the overloaded arguments, in the order overrides are tried."""
    # The collected types are not needed here; discard them explicitly instead
    # of binding an unused local (which also shadowed the `types` module name).
    _, overloaded_args = get_overloaded_types_and_args(relevant_args)
    return overloaded_args
# Plain ints have no __array_function__, so they are ignored.
assert get_overloaded_args([1]) == []
assert get_overloaded_args([a]) == [a]
assert get_overloaded_args([a, 1]) == [a]
# Only the first argument of each type is kept.
assert get_overloaded_args([a, a, a]) == [a]
# Unrelated types (A and D) keep their left-to-right order.
assert get_overloaded_args([a, d, a]) == [a, d]
# Subclasses (B and C derive from A) come before their superclass regardless
# of position; sibling subclasses keep their left-to-right order.
assert get_overloaded_args([a, b]) == [b, a]
assert get_overloaded_args([b, a]) == [b, a]
assert get_overloaded_args([a, b, c]) == [b, c, a]
assert get_overloaded_args([a, c, b]) == [c, b, a]
class SubNDArray(ndarray):
    """ndarray subclass that replaces the inherited __array_function__."""
    __array_function__ = return_self
# Arguments whose __array_function__ is ndarray's own are recorded in the
# types list but never returned as overloaded arguments.
array = np.array(1).view(ndarray)
assert get_overloaded_types_and_args([array]) == ([ndarray], [])
assert get_overloaded_types_and_args([a, array, 1]) == ([A, ndarray], [a])
# A subclass with its own override is returned whatever its position, and the
# plain ndarray argument is still excluded.
subarray = np.array(1).view(SubNDArray)
assert get_overloaded_args([array, subarray]) == [subarray]
assert get_overloaded_args([subarray, array]) == [subarray]
Note that functions like `np.concatenate` are written in C, so we'll need to write these wrappers in C, too, unless we're OK with the performance hit of doing all the wrapping logic in Python.
import numpy as np
def _broadcast_to_dispatcher(array, shape, subok=None):
return (array,)
@array_function_dispatch(_broadcast_to_dispatcher)
def broadcast_to(array, shape, subok=False):
    """Dispatch-enabled wrapper around ``np.broadcast_to``.

    Overrides are looked up on ``array`` only (see
    ``_broadcast_to_dispatcher``); otherwise NumPy's implementation runs.
    """
    return np.broadcast_to(array, shape, subok=subok)
def _concatenate_dispatcher(arrays, axis=None, out=None):
for array in arrays:
yield array
if out is not None:
yield out
@array_function_dispatch(_concatenate_dispatcher)
def concatenate(arrays, axis=0, out=None):
    """Dispatch-enabled wrapper around ``np.concatenate``.

    Every element of ``arrays`` and, if given, ``out`` is checked for an
    override (see ``_concatenate_dispatcher``).
    """
    joined = np.concatenate(arrays, axis=axis, out=out)
    return joined
# verify that we can pickle these functions
# note: using functools.wraps and the decorator appears to be critical!
# pickle stores functions by reference (module + qualified name). functools.wraps
# copies __qualname__ from the wrapped function onto the wrapper, so unpickling
# looks up the module-level name `broadcast_to`, which is the wrapper itself.
import pickle
assert pickle.loads(pickle.dumps(broadcast_to)) is broadcast_to
# Registry: dispatching function -> MyArray implementation, filled by `implements`.
HANDLED_FUNCTIONS = dict()


class MyArray:
    """Minimal duck array overriding only explicitly registered functions."""

    def __array_function__(self, func, types, args, kwargs):
        # Defer for any function without a registered implementation.
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        # Defer when a type that is not MyArray (or a subclass) participates.
        if any(not issubclass(t, MyArray) for t in types):
            return NotImplemented
        handler = HANDLED_FUNCTIONS[func]
        return handler(*args, **kwargs)


def implements(numpy_function):
    """Register an __array_function__ implementation for MyArray objects.

    Returns a decorator that records the decorated function under
    ``numpy_function`` and hands it back unchanged.
    """
    def register(implementation):
        HANDLED_FUNCTIONS[numpy_function] = implementation
        return implementation
    return register
# dummy implementation to show how overloads work with new/unexpected arguments
@implements(concatenate)
def _(arrays):
    """Deliberately narrow: accepts only ``arrays`` (no ``axis``/``out``)."""
    return None
my_array = MyArray()
concatenate([my_array]) # works: dispatches to the registered `_` implementation
# Expected to raise TypeError: dispatch forwards `axis=0` to `_`, which accepts
# only `arrays` (see the traceback output below).
concatenate([my_array], axis=0) # not supported
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-5-02956b4e04d5> in <module>() 24 my_array = MyArray() 25 concatenate([my_array]) # works ---> 26 concatenate([my_array], axis=0) # not supported <ipython-input-1-d50f18203b4d> in new_func(*args, **kwargs) 105 relevant_arguments = dispatcher(*args, **kwargs) 106 success, value = try_array_function_override( --> 107 new_func, relevant_arguments, args, kwargs) 108 if success: 109 return value <ipython-input-1-d50f18203b4d> in try_array_function_override(func, relevant_arguments, args, kwargs) 88 # Would probably need to use six to do this sanely on Python 2: 89 # https://stackoverflow.com/questions/9157210/ ---> 90 raise new_error.with_traceback(error.__traceback__) from None 91 92 if result is not NotImplemented: <ipython-input-1-d50f18203b4d> in try_array_function_override(func, relevant_arguments, args, kwargs) 78 try: 79 result = overloaded_arg.__array_function__( ---> 80 func, types, args, kwargs) 81 except Exception as error: 82 # Ensure the type of the overloaded argument ends up in the <ipython-input-5-02956b4e04d5> in __array_function__(self, func, types, args, kwargs) 7 if not all(issubclass(t, MyArray) for t in types): 8 return NotImplemented ----> 9 return HANDLED_FUNCTIONS[func](*args, **kwargs) 10 11 def implements(numpy_function): TypeError: _() got an unexpected keyword argument 'axis' [while calling '__main__.MyArray' implementation of '__main__.concatenate']
# Expected to raise TypeError before any override is consulted:
# `_concatenate_dispatcher` itself rejects the unknown `new_arg` keyword.
concatenate([my_array], new_arg=True) # not supported by NumPy
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-6-566bad0d5bff> in <module>() ----> 1 concatenate([my_array], new_arg=True) # not supported by NumPy <ipython-input-1-d50f18203b4d> in new_func(*args, **kwargs) 103 @functools.wraps(func) 104 def new_func(*args, **kwargs): --> 105 relevant_arguments = dispatcher(*args, **kwargs) 106 success, value = try_array_function_override( 107 new_func, relevant_arguments, args, kwargs) TypeError: _concatenate_dispatcher() got an unexpected keyword argument 'new_arg'
It's important that the overhead of __array_function__
is minimal in the typical case of no overloads.
# Overhead check with no overloads: NumPy directly vs. the dispatching wrapper.
array = np.array(1)
shape = (2,)
%timeit np.broadcast_to(array, shape)
%timeit broadcast_to(array, shape)
The slowest run took 9.35 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 4.42 µs per loop The slowest run took 5.82 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 6.46 µs per loop
# Same comparison for a C-implemented function with two small inputs.
arrays = [np.array([1]), np.array([2])]
%timeit np.concatenate(arrays)
%timeit concatenate(arrays)
The slowest run took 1221.88 times longer than the fastest. This could mean that an intermediate result is being cached. 1000000 loops, best of 3: 869 ns per loop The slowest run took 10.83 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 3.98 µs per loop
# 20,000 arguments: stresses the per-argument scan in the dispatcher path.
many_arrays = [np.array([1]), np.array([2])] * 10000
%timeit np.concatenate(many_arrays)
%timeit concatenate(many_arrays)
100 loops, best of 3: 3.16 ms per loop 100 loops, best of 3: 16.7 ms per loop
# Profile 100,000 wrapped calls; `-r` makes %prun return a pstats.Stats object.
arrays = [np.array([1]), np.array([2])]
stats = %prun -r for _ in range(100000): concatenate(arrays)
stats.print_stats()
800003 function calls in 0.569 seconds Ordered by: internal time ncalls tottime percall cumtime percall filename:lineno(function) 100000 0.197 0.000 0.235 0.000 <ipython-input-1-d50f18203b4d>:22(get_overloaded_types_and_args) 100000 0.154 0.000 0.154 0.000 {built-in method numpy.core.multiarray.concatenate} 100000 0.069 0.000 0.522 0.000 <ipython-input-1-d50f18203b4d>:103(new_func) 1 0.047 0.047 0.569 0.569 <string>:1(<module>) 300000 0.038 0.000 0.038 0.000 <ipython-input-3-8b2f2296c910>:12(_concatenate_dispatcher) 100000 0.035 0.000 0.270 0.000 <ipython-input-1-d50f18203b4d>:64(try_array_function_override) 100000 0.029 0.000 0.183 0.000 <ipython-input-3-8b2f2296c910>:18(concatenate) 1 0.000 0.000 0.569 0.569 {built-in method builtins.exec} 1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
<pstats.Stats at 0x7fd3456a0e80>
# other micro-benchmarks, for context
# Baseline timings of common NumPy operations on a 10-element array, to put
# the dispatch overhead measured above in perspective.
x = np.arange(10)
%timeit np.asarray(x)
%timeit x[x]
%timeit np.concatenate([x, x])
%timeit np.stack([x, x])
%timeit x.sum()
%timeit np.sum(x)
%timeit np.mean(x)
%timeit np.sin(x)
%timeit np.unique(x)
%timeit np.broadcast_to(x, (1, 10))
%timeit np.transpose(x)
%timeit np.moveaxis(x, 0, -1)
The slowest run took 16.74 times longer than the fastest. This could mean that an intermediate result is being cached. 1000000 loops, best of 3: 365 ns per loop The slowest run took 44.88 times longer than the fastest. This could mean that an intermediate result is being cached. 1000000 loops, best of 3: 242 ns per loop The slowest run took 59.54 times longer than the fastest. This could mean that an intermediate result is being cached. 1000000 loops, best of 3: 938 ns per loop The slowest run took 6.92 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 5.39 µs per loop The slowest run took 34.62 times longer than the fastest. This could mean that an intermediate result is being cached. 1000000 loops, best of 3: 1.68 µs per loop The slowest run took 11.59 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 2.72 µs per loop The slowest run took 8.03 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 7.06 µs per loop The slowest run took 50.44 times longer than the fastest. This could mean that an intermediate result is being cached. 1000000 loops, best of 3: 1.22 µs per loop The slowest run took 9.14 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 6.52 µs per loop The slowest run took 6.86 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 4.58 µs per loop The slowest run took 24.69 times longer than the fastest. This could mean that an intermediate result is being cached. 1000000 loops, best of 3: 644 ns per loop The slowest run took 7.18 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 5.3 µs per loop
def dummy_try_array_function_override(
        func, relevant_arguments, args, kwargs):
    """Stand-in for try_array_function_override that never handles the call.

    Ignores all inputs and reports "not handled", so benchmarks through it
    measure only the surrounding wrapper's cost.
    """
    return (False, None)
def dummy_dispatch(dispatcher):
    """Benchmark twin of array_function_dispatch that never finds an override.

    Keeps the same call structure (dispatcher call plus override attempt) but
    goes through dummy_try_array_function_override, so timings exclude the
    override lookup itself.
    """
    def decorator(func):
        @functools.wraps(func)
        def new_func(*args, **kwargs):
            relevant = dispatcher(*args, **kwargs)
            found, value = dummy_try_array_function_override(
                new_func, relevant, args, kwargs)
            return value if found else func(*args, **kwargs)
        return new_func
    return decorator
def f(x):
    """Baseline no-op function for call-overhead timing."""
    return None
def _dispatcher(x):
return (x,)
@dummy_dispatch(_dispatcher)
def g(x):
    """No-op twin of ``f``, wrapped by dummy_dispatch for overhead timing."""
    return None
# Baseline: a plain Python function call.
%timeit f(1)
The slowest run took 13.60 times longer than the fastest. This could mean that an intermediate result is being cached. 10000000 loops, best of 3: 84.8 ns per loop
# Same no-op through dummy_dispatch: the difference from f(1) is the cost of
# the wrapper, dispatcher call and (dummy) override attempt.
%timeit g(1)
The slowest run took 9.63 times longer than the fastest. This could mean that an intermediate result is being cached. 1000000 loops, best of 3: 567 ns per loop
Micro-benchmark conclusions:
`try_array_function_override` adds about 2-3 µs of overhead per function call. This is probably acceptable for functions written in Python (e.g., `np.broadcast_to`), but could be significant for functions written in C (e.g., `np.concatenate`).