import functools
import inspect
import numpy as np
import six
import sys


class ndarray(np.ndarray):
    """Updated version of numpy.ndarray.

    Its ``__array_function__`` returns ``NotImplemented`` whenever any type in
    ``types`` defines an ``__array_function__`` other than this one; otherwise
    it calls ``func`` directly with the original arguments.
    """

    def __array_function__(self, func, types, args, kwargs):
        # Cannot handle items that have __array_function__ other than our own.
        for t in types:
            if (hasattr(t, '__array_function__')
                    and t.__array_function__ is not ndarray.__array_function__):
                return NotImplemented

        # Arguments contain no overrides, so we can safely call the
        # overloaded function again.
        return func(*args, **kwargs)


def get_overloaded_types_and_args(relevant_args):
    """Collect the types and arguments on which to call __array_function__.

    Parameters
    ----------
    relevant_args : iterable
        Arguments to inspect. It is iterated exactly once, so a generator
        (as returned by a dispatcher) is fine.

    Returns
    -------
    overloaded_types : list of type
        Unique types defining ``__array_function__``, in order of first
        appearance. This includes this module's ``ndarray``.
    overloaded_args : list
        The first argument of each unique type whose ``__array_function__``
        is not ``ndarray.__array_function__``. Each new argument is inserted
        before the first previously collected argument whose type it
        subclasses; otherwise it is appended. ``__array_function__``
        implementations should be called in this order.
    """
    # Runtime is O(num_arguments * num_unique_types)
    overloaded_types = []
    overloaded_args = []

    for arg in relevant_args:
        arg_type = type(arg)
        if arg_type in overloaded_types:
            continue
        try:
            array_function = arg_type.__array_function__
        except AttributeError:
            continue
        overloaded_types.append(arg_type)

        if array_function is not ndarray.__array_function__:
            # Subclasses go before their superclasses so they get the first
            # chance to handle the call.
            index = len(overloaded_args)
            for i, old_arg in enumerate(overloaded_args):
                if issubclass(arg_type, type(old_arg)):
                    index = i
                    break
            overloaded_args.insert(index, arg)

    return overloaded_types, overloaded_args


def full_name(obj):
    """Return ``obj``'s ``module.qualname`` string (used in error messages)."""
    return f'{obj.__module__}.{obj.__qualname__}'


def attempt_augmented_error_message(error, append_message):
    """Attempt to recreate an error with an appended message.

    Returns a new exception of the same type whose first argument is
    ``error.args[0] + append_message``; the remaining args are unchanged.
    The original ``error`` is returned as-is in three cases:

    - it has no args;
    - its first arg is not a string;
    - its type cannot be rebuilt from the modified args.
    """
    if not error.args or not isinstance(error.args[0], str):
        return error
    try:
        return type(error)(error.args[0] + append_message, *error.args[1:])
    except Exception:
        # Some exception constructors do not accept their own ``args`` back.
        return error


def try_array_function_override(func, relevant_arguments, args, kwargs):
    """Try to dispatch ``func`` to an ``__array_function__`` override.

    Parameters
    ----------
    func : callable
        The public (dispatching) function being called.
    relevant_arguments : iterable
        Arguments to check for overrides, as produced by a dispatcher.
    args, kwargs : tuple, dict
        The original call arguments, forwarded to each override.

    Returns
    -------
    (bool, object)
        ``(True, result)`` if an override returned something other than
        ``NotImplemented``. ``(False, None)`` if no argument has a
        non-default ``__array_function__``.

    Raises
    ------
    TypeError
        If overrides exist but every one returns ``NotImplemented``.
    Exception
        Any ``Exception`` raised by an override is re-raised. When possible,
        its message is extended with the override's type and ``func``.
    """
    # TODO: consider simplifying the interface, to only require either `types`
    # (by calling __array_function__ a classmethod) or `overloaded_args` (by
    # dropping `types` from the signature of __array_function__)
    types, overloaded_args = get_overloaded_types_and_args(relevant_arguments)
    if not overloaded_args:
        return False, None

    for overloaded_arg in overloaded_args:
        # Note that we're only calling __array_function__ on the *first*
        # occurrence of each argument type. This is necessary for reasonable
        # performance with a possibly long list of overloaded arguments, for
        # which each __array_function__ implementation might reasonably need
        # to check all argument types.
        try:
            result = overloaded_arg.__array_function__(
                func, types, args, kwargs)
        except Exception as error:
            # Ensure the type of the overloaded argument ends up in the
            # traceback message.
            message = (f' [while calling {full_name(type(overloaded_arg))!r}'
                       f' implementation of {full_name(func)!r}]')
            new_error = attempt_augmented_error_message(error, message)
            raise new_error.with_traceback(error.__traceback__) from None

        if result is not NotImplemented:
            return True, result

    raise TypeError(
        f'no implementation found for {func} on types that implement '
        f'__array_function__: {[type(arg) for arg in overloaded_args]}')


def array_function_dispatch(dispatcher):
    """Wrap a function for dispatch with the __array_function__ protocol.

    ``dispatcher`` is called with the same arguments as the wrapped function.
    It returns an iterable of the arguments to check for overrides.
    """
    def decorator(func):
        @functools.wraps(func)
        def new_func(*args, **kwargs):
            relevant_arguments = dispatcher(*args, **kwargs)
            # Overrides receive the public wrapper (new_func), not the raw
            # implementation, so they can key registries on it.
            success, value = try_array_function_override(
                new_func, relevant_arguments, args, kwargs)
            if success:
                return value
            return func(*args, **kwargs)
        return new_func
    return decorator


def return_self(self, *args, **kwargs):
    """Test override: always 'handles' the call by returning ``self``."""
    return self


def return_not_implemented(self, *args, **kwargs):
    """Test override: always defers."""
    return NotImplemented


class A:
    __array_function__ = return_self


class B(A):
    __array_function__ = return_self


class C(A):
    __array_function__ = return_self


class D:
    __array_function__ = return_self


a = A()
b = B()
c = C()
d = D()


def get_overloaded_args(relevant_args):
    """Return only the ``overloaded_args`` part (test helper)."""
    types, args = get_overloaded_types_and_args(relevant_args)
    return args


assert get_overloaded_args([1]) == []
assert get_overloaded_args([a]) == [a]
assert get_overloaded_args([a, 1]) == [a]
assert get_overloaded_args([a, a, a]) == [a]
assert get_overloaded_args([a, d, a]) == [a, d]
assert get_overloaded_args([a, b]) == [b, a]
assert get_overloaded_args([b, a]) == [b, a]
assert get_overloaded_args([a, b, c]) == [b, c, a]
assert get_overloaded_args([a, c, b]) == [c, b, a]


class SubNDArray(ndarray):
    __array_function__ = return_self


array = np.array(1).view(ndarray)
assert get_overloaded_types_and_args([array]) == ([ndarray], [])
assert get_overloaded_types_and_args([a, array, 1]) == ([A, ndarray], [a])

subarray = np.array(1).view(SubNDArray)
assert get_overloaded_args([array, subarray]) == [subarray]
assert get_overloaded_args([subarray, array]) == [subarray]

import numpy as np


def _broadcast_to_dispatcher(array, shape, subok=False):
    """Return the arguments of ``broadcast_to`` checked for overrides."""
    return (array,)


@array_function_dispatch(_broadcast_to_dispatcher)
def broadcast_to(array, shape, subok=False):
    """Dispatching wrapper around ``np.broadcast_to``."""
    # Pass subok by keyword so the call does not depend on NumPy's
    # positional parameter order.
    return np.broadcast_to(array, shape, subok=subok)


def _concatenate_dispatcher(arrays, axis=None, out=None):
    """Yield each input array, then ``out`` if given, for override checks."""
    yield from arrays
    if out is not None:
        yield out


@array_function_dispatch(_concatenate_dispatcher)
def concatenate(arrays, axis=0, out=None):
    """Dispatching wrapper around ``np.concatenate``."""
    return np.concatenate(arrays, axis=axis, out=out)


# verify that we can pickle these functions
# note: using functools.wraps and the decorator appears to be critical!
import pickle assert pickle.loads(pickle.dumps(broadcast_to)) is broadcast_to HANDLED_FUNCTIONS = {} class MyArray: def __array_function__(self, func, types, args, kwargs): if func not in HANDLED_FUNCTIONS: return NotImplemented if not all(issubclass(t, MyArray) for t in types): return NotImplemented return HANDLED_FUNCTIONS[func](*args, **kwargs) def implements(numpy_function): """Register an __array_function__ implementation for MyArray objects.""" def decorator(func): HANDLED_FUNCTIONS[numpy_function] = func return func return decorator # dummy implementation to show how overloads work with new/unexpected arguments @implements(concatenate) def _(arrays): pass my_array = MyArray() concatenate([my_array]) # works concatenate([my_array], axis=0) # not supported concatenate([my_array], new_arg=True) # not supported by NumPy array = np.array(1) shape = (2,) %timeit np.broadcast_to(array, shape) %timeit broadcast_to(array, shape) arrays = [np.array([1]), np.array([2])] %timeit np.concatenate(arrays) %timeit concatenate(arrays) many_arrays = [np.array([1]), np.array([2])] * 10000 %timeit np.concatenate(many_arrays) %timeit concatenate(many_arrays) arrays = [np.array([1]), np.array([2])] stats = %prun -r for _ in range(100000): concatenate(arrays) stats.print_stats() # other micro-benchmarks, for context x = np.arange(10) %timeit np.asarray(x) %timeit x[x] %timeit np.concatenate([x, x]) %timeit np.stack([x, x]) %timeit x.sum() %timeit np.sum(x) %timeit np.mean(x) %timeit np.sin(x) %timeit np.unique(x) %timeit np.broadcast_to(x, (1, 10)) %timeit np.transpose(x) %timeit np.moveaxis(x, 0, -1) def dummy_try_array_function_override( func, relevant_arguments, args, kwargs): return False, None def dummy_dispatch(dispatcher): def decorator(func): @functools.wraps(func) def new_func(*args, **kwargs): relevant_arguments = dispatcher(*args, **kwargs) success, value = dummy_try_array_function_override( new_func, relevant_arguments, args, kwargs) if success: return value return 
func(*args, **kwargs) return new_func return decorator def f(x): pass def _dispatcher(x): return (x,) @dummy_dispatch(_dispatcher) def g(x): pass %timeit f(1) %timeit g(1)