NEP 18 example implementation

Author: Stephan Hoyer ([email protected])

Date: June 10, 2018

See the NEP for full context and details.

Implementation of __array_function__ machinery

Our goals here are:

  1. Correctness
  2. Performance for the typical case of no overloads
  3. Performance for large numbers of arguments
    • This is important for overloading functions like np.concatenate, which could involve thousands or tens of thousands of arguments to check.

Note that for maximum performance, we will probably write the actual implementations of these functions (at least get_overloaded_types_and_args and try_array_function_override) in C.

In [0]:
import functools
import inspect
import numpy as np
import six
import sys


class ndarray(np.ndarray):
    """Updated version of numpy.ndarray."""

    def __array_function__(self, func, types, args, kwargs):
        # Defer whenever some other type brings its own __array_function__:
        # we only know how to handle plain ndarray (and types that do not
        # participate in the protocol at all).
        if any(hasattr(t, '__array_function__') and
               t.__array_function__ is not ndarray.__array_function__
               for t in types):
            return NotImplemented

        # No overrides among the arguments, so the overloaded function can
        # safely be called again with the original arguments.
        return func(*args, **kwargs)


def get_overloaded_types_and_args(relevant_args):
    """Returns a list of arguments on which to call __array_function__.

    __array_function__ implementations should be called in order on the return
    values from this function.
    """
    # Runtime is O(num_arguments * num_unique_types).
    seen_types = []
    dispatch_args = []
    for arg in relevant_args:
        kind = type(arg)
        if kind in seen_types:
            continue  # only the first occurrence of each type matters
        try:
            implementation = kind.__array_function__
        except AttributeError:
            continue  # type does not participate in the protocol
        seen_types.append(kind)
        # The default ndarray implementation never overrides, so it is
        # tracked in seen_types but contributes no dispatch argument.
        if implementation is ndarray.__array_function__:
            continue
        # Insert before the first argument whose type is a superclass of
        # this one, so subclasses get the first chance at handling the call.
        position = next(
            (i for i, previous in enumerate(dispatch_args)
             if issubclass(kind, type(previous))),
            len(dispatch_args))
        dispatch_args.insert(position, arg)

    return seen_types, dispatch_args


def full_name(obj):
    """Dotted module-qualified name of *obj*, e.g. ``'json.dumps'``."""
    return '{}.{}'.format(obj.__module__, obj.__qualname__)
  

def attempt_augmented_error_message(error, append_message):
    """Attempt to recreate an error with an appended message."""
    original_args = error.args
    try:
        rebuilt = type(error)(
            original_args[0] + append_message, *original_args[1:])
    except Exception:
        # Reconstruction can fail, e.g. for exceptions with no args, a
        # non-string first argument, or constructors that require extra
        # arguments; in that case hand back the error unchanged.
        return error
    return rebuilt
  

def try_array_function_override(func, relevant_arguments, args, kwargs):
    """Dispatch *func* through __array_function__ if any argument overrides it.

    Returns a ``(success, value)`` pair: ``(True, result)`` when some
    implementation handled the call, ``(False, None)`` when no argument
    provides an override.
    """
    # TODO: consider simplifying the interface, to only require either `types`
    # (by calling __array_function__ a classmethod) or `overloaded_args` (by
    # dropping `types` from the signature of __array_function__)
    types, overloaded_args = get_overloaded_types_and_args(relevant_arguments)
    if not overloaded_args:
        return False, None

    for overloaded_arg in overloaded_args:
        # __array_function__ is invoked only on the *first* occurrence of
        # each argument type. Implementations may reasonably need to check
        # all argument types, so one call per type keeps dispatch affordable
        # even for very long argument lists.
        try:
            result = overloaded_arg.__array_function__(
                func, types, args, kwargs)
        except Exception as error:
            # Tag the failing implementation's type onto the error so it
            # shows up in the traceback.
            note = (" [while calling {!r} implementation of {!r}]"
                    .format(full_name(type(overloaded_arg)),
                            full_name(func)))
            augmented = attempt_augmented_error_message(error, note)
            # Would probably need to use six to do this sanely on Python 2:
            # https://stackoverflow.com/questions/9157210/
            raise augmented.with_traceback(error.__traceback__) from None

        if result is not NotImplemented:
            return True, result

    raise TypeError('no implementation found for {} on types that implement '
                    '__array_function__: {}'
                    .format(func, list(map(type, overloaded_args))))


def array_function_dispatch(dispatcher):
    """Wrap a function for dispatch with the __array_function__ protocol."""
    def decorator(implementation):
        @functools.wraps(implementation)
        def public_api(*args, **kwargs):
            # The dispatcher picks out just the arguments that might carry
            # an __array_function__ override.
            relevant_args = dispatcher(*args, **kwargs)
            success, value = try_array_function_override(
                public_api, relevant_args, args, kwargs)
            # Fall back to the default implementation when nothing overrode.
            return value if success else implementation(*args, **kwargs)
        return public_api
    return decorator

Unit tests for get_overloaded_types_and_args

In [0]:
def return_self(self, *args, **kwargs):
    """An __array_function__ stub that unconditionally claims the call."""
    del args, kwargs  # accepted but ignored
    return self

def return_not_implemented(self, *args, **kwargs):
    """An __array_function__ stub that always defers."""
    del self, args, kwargs  # accepted but ignored
    return NotImplemented

class A:
    # Base type that overrides __array_function__.
    __array_function__ = return_self

class B(A):
    # Subclass of A; shares the same override function object.
    __array_function__ = return_self

class C(A):
    # A second, independent subclass of A.
    __array_function__ = return_self

class D:
    # Unrelated type with its own override.
    __array_function__ = return_self

# Instances exercised by the assertions below.
a = A()
b = B()
c = C()
d = D()

def get_overloaded_args(relevant_args):
    """Just the overloaded-arguments half of the result."""
    return get_overloaded_types_and_args(relevant_args)[1]

# Plain objects never overload; duplicate types collapse to one entry.
assert get_overloaded_args([1]) == []
assert get_overloaded_args([a]) == [a]
assert get_overloaded_args([a, 1]) == [a]
assert get_overloaded_args([a, a, a]) == [a]
assert get_overloaded_args([a, d, a]) == [a, d]
# Subclasses are moved ahead of their parents, regardless of position.
assert get_overloaded_args([a, b]) == [b, a]
assert get_overloaded_args([b, a]) == [b, a]
assert get_overloaded_args([a, b, c]) == [b, c, a]
assert get_overloaded_args([a, c, b]) == [c, b, a]

class SubNDArray(ndarray):
    __array_function__ = return_self

# Plain ndarrays are recognized as types but never treated as overloaded.
array = np.array(1).view(ndarray)
assert get_overloaded_types_and_args([array]) == ([ndarray], [])
assert get_overloaded_types_and_args([a, array, 1]) == ([A, ndarray], [a])

# ...but ndarray subclasses with their own implementation are overloaded.
subarray = np.array(1).view(SubNDArray)
assert get_overloaded_args([array, subarray]) == [subarray]
assert get_overloaded_args([subarray, array]) == [subarray]

Example function implementations

Note that functions like np.concatenate are written in C, so we'll need to write these wrappers in C, too, unless we're OK with the performance hit of doing all the wrapping logic in Python.

In [0]:
import numpy as np


def _broadcast_to_dispatcher(array, shape, subok=None):
    return (array,)

@array_function_dispatch(_broadcast_to_dispatcher)
def broadcast_to(array, shape, subok=False):
    """Overridable wrapper around np.broadcast_to."""
    return np.broadcast_to(array, shape, subok=subok)


def _concatenate_dispatcher(arrays, axis=None, out=None):
    for array in arrays:
        yield array
    if out is not None:
        yield out
    
@array_function_dispatch(_concatenate_dispatcher)
def concatenate(arrays, axis=0, out=None):
    """Overridable wrapper around np.concatenate."""
    result = np.concatenate(arrays, axis=axis, out=out)
    return result
In [0]:
# verify that we can pickle these functions
# note: using functools.wraps and the decorator appears to be critical!
import pickle
round_tripped = pickle.loads(pickle.dumps(broadcast_to))
assert round_tripped is broadcast_to

MyArray implementations

In [5]:
# Maps overloaded NumPy-style functions to their MyArray implementations.
HANDLED_FUNCTIONS = {}

class MyArray:
    """Array-like that implements selected functions via a registry."""

    def __array_function__(self, func, types, args, kwargs):
        try:
            handler = HANDLED_FUNCTIONS[func]
        except KeyError:
            return NotImplemented
        # Defer unless every participating type is one we understand.
        if any(not issubclass(t, MyArray) for t in types):
            return NotImplemented
        return handler(*args, **kwargs)

def implements(numpy_function):
    """Register an __array_function__ implementation for MyArray objects."""
    def register(implementation):
        # Record the mapping, then hand the function back unchanged so it
        # can still be called directly or stacked with other decorators.
        HANDLED_FUNCTIONS[numpy_function] = implementation
        return implementation
    return register

    
# dummy implementation to show how overloads work with new/unexpected arguments
# (it accepts only the positional `arrays` argument, nothing else)
@implements(concatenate)
def _(arrays):
    pass

my_array = MyArray()
concatenate([my_array])  # works
# Raises TypeError (see traceback below): the registered implementation
# does not accept an `axis` keyword, demonstrating how overloads surface
# unexpected arguments.
concatenate([my_array], axis=0)  # not supported
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-5-02956b4e04d5> in <module>()
     24 my_array = MyArray()
     25 concatenate([my_array])  # works
---> 26 concatenate([my_array], axis=0)  # not supported

<ipython-input-1-d50f18203b4d> in new_func(*args, **kwargs)
    105             relevant_arguments = dispatcher(*args, **kwargs)
    106             success, value = try_array_function_override(
--> 107                 new_func, relevant_arguments, args, kwargs)
    108             if success:
    109                 return value

<ipython-input-1-d50f18203b4d> in try_array_function_override(func, relevant_arguments, args, kwargs)
     88             # Would probably need to use six to do this sanely on Python 2:
     89             # https://stackoverflow.com/questions/9157210/
---> 90             raise new_error.with_traceback(error.__traceback__) from None
     91 
     92         if result is not NotImplemented:

<ipython-input-1-d50f18203b4d> in try_array_function_override(func, relevant_arguments, args, kwargs)
     78         try:
     79             result = overloaded_arg.__array_function__(
---> 80                 func, types, args, kwargs)
     81         except Exception as error:
     82             # Ensure the type of the overloaded argument ends up in the

<ipython-input-5-02956b4e04d5> in __array_function__(self, func, types, args, kwargs)
      7         if not all(issubclass(t, MyArray) for t in types):
      8             return NotImplemented
----> 9         return HANDLED_FUNCTIONS[func](*args, **kwargs)
     10 
     11 def implements(numpy_function):

TypeError: _() got an unexpected keyword argument 'axis' [while calling '__main__.MyArray' implementation of '__main__.concatenate']
In [6]:
concatenate([my_array], new_arg=True)  # not supported by NumPy
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-566bad0d5bff> in <module>()
----> 1 concatenate([my_array], new_arg=True)  # not supported by NumPy

<ipython-input-1-d50f18203b4d> in new_func(*args, **kwargs)
    103         @functools.wraps(func)
    104         def new_func(*args, **kwargs):
--> 105             relevant_arguments = dispatcher(*args, **kwargs)
    106             success, value = try_array_function_override(
    107                 new_func, relevant_arguments, args, kwargs)

TypeError: _concatenate_dispatcher() got an unexpected keyword argument 'new_arg'

Micro benchmarks

It's important that the overhead of __array_function__ is minimal in the typical case of no overloads.

In [7]:
array = np.array(1)
shape = (2,)
%timeit np.broadcast_to(array, shape)
%timeit broadcast_to(array, shape)
The slowest run took 9.35 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 4.42 µs per loop
The slowest run took 5.82 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.46 µs per loop
In [8]:
arrays = [np.array([1]), np.array([2])]
%timeit np.concatenate(arrays)
%timeit concatenate(arrays)
The slowest run took 1221.88 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 869 ns per loop
The slowest run took 10.83 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3.98 µs per loop
In [9]:
many_arrays = [np.array([1]), np.array([2])] * 10000
%timeit np.concatenate(many_arrays)
%timeit concatenate(many_arrays)
100 loops, best of 3: 3.16 ms per loop
100 loops, best of 3: 16.7 ms per loop
In [10]:
arrays = [np.array([1]), np.array([2])]
stats = %prun -r for _ in range(100000): concatenate(arrays)
stats.print_stats()
          800003 function calls in 0.569 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   100000    0.197    0.000    0.235    0.000 <ipython-input-1-d50f18203b4d>:22(get_overloaded_types_and_args)
   100000    0.154    0.000    0.154    0.000 {built-in method numpy.core.multiarray.concatenate}
   100000    0.069    0.000    0.522    0.000 <ipython-input-1-d50f18203b4d>:103(new_func)
        1    0.047    0.047    0.569    0.569 <string>:1(<module>)
   300000    0.038    0.000    0.038    0.000 <ipython-input-3-8b2f2296c910>:12(_concatenate_dispatcher)
   100000    0.035    0.000    0.270    0.000 <ipython-input-1-d50f18203b4d>:64(try_array_function_override)
   100000    0.029    0.000    0.183    0.000 <ipython-input-3-8b2f2296c910>:18(concatenate)
        1    0.000    0.000    0.569    0.569 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}


Out[10]:
<pstats.Stats at 0x7fd3456a0e80>
In [11]:
# other micro-benchmarks, for context
x = np.arange(10)
%timeit np.asarray(x)
%timeit x[x]
%timeit np.concatenate([x, x])
%timeit np.stack([x, x])
%timeit x.sum()
%timeit np.sum(x)
%timeit np.mean(x)
%timeit np.sin(x)
%timeit np.unique(x)
%timeit np.broadcast_to(x, (1, 10))
%timeit np.transpose(x)
%timeit np.moveaxis(x, 0, -1)
The slowest run took 16.74 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 365 ns per loop
The slowest run took 44.88 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 242 ns per loop
The slowest run took 59.54 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 938 ns per loop
The slowest run took 6.92 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 5.39 µs per loop
The slowest run took 34.62 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 1.68 µs per loop
The slowest run took 11.59 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 2.72 µs per loop
The slowest run took 8.03 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 7.06 µs per loop
The slowest run took 50.44 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 1.22 µs per loop
The slowest run took 9.14 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.52 µs per loop
The slowest run took 6.86 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 4.58 µs per loop
The slowest run took 24.69 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 644 ns per loop
The slowest run took 7.18 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 5.3 µs per loop
In [0]:
def dummy_try_array_function_override(func, relevant_arguments, args, kwargs):
    """No-op stand-in for try_array_function_override: never overrides."""
    del func, relevant_arguments, args, kwargs  # ignored
    return False, None

def dummy_dispatch(dispatcher):
    """Like array_function_dispatch, but with the override check stubbed out.

    Used to measure the baseline cost of the wrapper machinery itself.
    """
    def decorator(implementation):
        @functools.wraps(implementation)
        def wrapper(*args, **kwargs):
            relevant = dispatcher(*args, **kwargs)
            success, value = dummy_try_array_function_override(
                wrapper, relevant, args, kwargs)
            return value if success else implementation(*args, **kwargs)
        return wrapper
    return decorator
  
def f(x):
    """Plain no-op function: undecorated baseline for the call benchmark."""
    return None

def _dispatcher(x):
    return (x,)
  
@dummy_dispatch(_dispatcher)
def g(x):
    """No-op function wrapped with the dummy dispatch machinery."""
    return None
In [13]:
%timeit f(1)
The slowest run took 13.60 times longer than the fastest. This could mean that an intermediate result is being cached.
10000000 loops, best of 3: 84.8 ns per loop
In [14]:
%timeit g(1)
The slowest run took 9.63 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 567 ns per loop

Micro benchmark conclusions:

  • Adding overloads with try_array_function_override adds about 2-3 µs of overhead per function call.
    • This is fine for functions written in Python (e.g., np.broadcast_to), but could be significant for functions written in C (e.g., np.concatenate).
    • It's unclear how bad performance degradation would be if we wrote this in C.
  • The explicit decorator array_function_dispatch is really clean and just as fast as calling try_array_function_override directly.
    • The only downside is that the use of functools.wraps means that decorated functions lose an inspectable signature on Python 2. But this is probably worth it, given how soon NumPy will be Python 3 only.