Numba 0.47.0 Release Demo

This notebook contains a demonstration of new features present in the 0.47.0 release of Numba. Whilst release notes are produced as part of the CHANGE_LOG, there's nothing like seeing code in action! This release contains a large number of exciting new features!

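Demonstrations of new features include:

  • Bounds checking
  • Dynamic function definition
  • Support for map, filter, reduce
  • Support for list.sort()/sorted with key
  • Initial support for basic try/except
  • Iterating over mixed type containers
  • Newly supported NumPy functions/features
  • New unicode string features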

First, import the necessary names from Numba and NumPy...

In [ ]:
from numba import jit, njit, config, __version__, errors
from numba.errors import NumbaPendingDeprecationWarning
import warnings
# we're going to ignore a couple of deprecation warnings
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)
from numba.extending import overload
config.SHOW_HELP = 0
import numba
import numpy as np
assert tuple(int(x) for x in __version__.split('.')[:2]) >= (0, 47)

Bounds checking

The long-awaited support for bounds checking has been added in this release; the associated documentation is available in the Numba documentation. Here's a demonstration:

In [ ]:
config.FULL_TRACEBACKS = 1
@njit(boundscheck=True)
def OOB_access(x):
    sz = len(x)
    a = x[0] # fine, first element of x
    a += x[sz - 1] # fine, last element of x
    a += x[sz] # oops, out of bounds!

try:
    OOB_access(np.ones(10))
except IndexError as e:
    print(type(e), e)

Setting config.FULL_TRACEBACKS (or the equivalent environment variable, NUMBA_FULL_TRACEBACKS) forces the index, axis and dimension size to be printed to the terminal (assuming a terminal was used to invoke Python). For example, the terminal that launched this notebook now has:

debug: IndexError: index 10 is out of bounds for axis 0 with size 10

on it. A future release will enhance this feature to include the out-of-bounds access information in the error message itself.
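
Both settings can also be requested through the environment rather than in code. The following is a minimal sketch, assuming the variable names NUMBA_BOUNDSCHECK and NUMBA_FULL_TRACEBACKS from Numba's environment variable documentation; the helper name numba_debug_env is made up for this example. It only builds an environment mapping and does not import Numba itself:

```python
import os

def numba_debug_env(base=None):
    """Return a copy of the environment with Numba debug settings added.

    Hypothetical helper for illustration only. The variables need to be in
    place before Numba is imported, because Numba reads them at import time.
    """
    env = dict(os.environ if base is None else base)
    env["NUMBA_BOUNDSCHECK"] = "1"      # enable bounds checking globally
    env["NUMBA_FULL_TRACEBACKS"] = "1"  # counterpart of config.FULL_TRACEBACKS = 1
    return env

env = numba_debug_env()
print(env["NUMBA_BOUNDSCHECK"], env["NUMBA_FULL_TRACEBACKS"])
```

The returned mapping could, for instance, be passed as the env argument of subprocess.run when launching a script or notebook server that uses Numba.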

Dynamic function definition

The 0.47.0 release adds a new capability to Numba: dynamic function generation. Essentially, functions (closures) defined in a JIT-decorated function can now "escape" the function they are defined in and be used as arguments in subsequent function calls. For example:

In [ ]:
# takes a function and calls it with argument arg, multiplies the result by 7
@njit
def consumer(function, arg):
    return function(arg) * 7

_GLOBAL = 5

@njit
def generator_func():
    _FREEVAR = 10

    def escapee(x): # closure, 'a' is a local, '_FREEVAR' is a freevar, '_GLOBAL' is global
        a = 9
        return x * _FREEVAR + a * _GLOBAL

    # data argument for the consumer call
    x = np.arange(5)

    # escapee function is passed to the consumer function along with its argument
    return consumer(escapee, x)

generator_func()
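
To make the local/freevar/global terminology in the comment above concrete, here is a plain-Python sketch (no @njit, so it says nothing about how Numba handles the closure internally) that inspects the same closure with the interpreter's own introspection attributes. Returning the closure from generator_func is a simplification made for this sketch only:

```python
_GLOBAL = 5

def generator_func():
    _FREEVAR = 10

    def escapee(x):
        a = 9
        return x * _FREEVAR + a * _GLOBAL

    # Plain Python only: hand the closure back so it can be inspected.
    return escapee

escapee = generator_func()
code = escapee.__code__

print("args and locals:", code.co_varnames)   # contains 'x' and 'a'
print("free variables: ", code.co_freevars)   # contains '_FREEVAR'
print("captured values:", [c.cell_contents for c in escapee.__closure__])
print("escapee(2) =", escapee(2))              # 2 * 10 + 9 * 5
```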
    

Support for map, filter, reduce

The ability to create dynamic functions led to being able to write support for map, filter and reduce. This makes it possible to write more "pythonic" code in Numba :-)

In [ ]:
import operator
from functools import reduce
from numba.typed import List

@njit
def demo_map_filter_reduce():

    # This will be used in map
    def mul_n(x, multiplier):
        return x * multiplier
    
    # This will be used in filter
    V = 20
    def greater_than_V(x):
        return x > V # captures V from freevars
    
    # this will be used in reduce
    reduce_lambda = lambda x, y: (x * 2) + y

    a = [x ** 2 for x in range(10)]    
    n = len(a)           
    return reduce(reduce_lambda, filter(greater_than_V, map(mul_n, a, range(n))))

demo_map_filter_reduce()
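
As a sanity check, the same pipeline can be run in plain Python with the intermediate stages exposed. This is a reference sketch without @njit (the function name reference_map_filter_reduce is made up here); the jitted demo above is expected to return the same final value:

```python
from functools import reduce

def reference_map_filter_reduce():
    a = [x ** 2 for x in range(10)]
    # map with two iterables: each square is multiplied by its index
    scaled = list(map(lambda x, m: x * m, a, range(len(a))))
    # filter: keep values greater than 20
    kept = [v for v in scaled if v > 20]
    # reduce: fold left to right with (x * 2) + y
    total = reduce(lambda x, y: (x * 2) + y, kept)
    return scaled, kept, total

scaled, kept, total = reference_map_filter_reduce()
print(scaled)  # [0, 1, 8, 27, 64, 125, 216, 343, 512, 729]
print(kept)    # [27, 64, 125, 216, 343, 512, 729]
print(total)   # 10629
```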

Support for list.sort()/sorted with key

A further extension born from the ability to create dynamic functions is support for the key argument to list.sort and sorted. A quick demonstration:

In [ ]:
@njit
def demo_sort_sorted(chars):

    def key(x):
        return x.upper()
          
    x = chars[:]
    x.sort()
    print("sorted:", ''.join(x))

    x = chars[:]
    x.sort(reverse=True)
    print("sorted backwards:", ''.join(x))

    x = chars[:]
    x.sort(key=key)
    print("sorted key=x.upper():", ''.join(x))
    
    print("sorted(), reversed:", ''.join(sorted(x, reverse=True)))
    
    def numba_order(x):
        return 'NUMBA🐍numba⚡'.index(x)
    
    x = chars[:]
    x.sort(key=numba_order)
    print("sorted key=numba_order:", ''.join(x))
    
# let's sort a list of characters
input_list = ['m','M','a','N','n','u','⚡','🐍','B','b','U','A']
demo_sort_sorted(input_list)
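
For comparison, here is a plain-Python sketch of two of the orderings above: the default sort, which compares characters by code point, and the numba_order key, which ranks each character by its position in a reference string. It runs in the interpreter only, so it shows what the key function means rather than how Numba implements it:

```python
chars = ['m', 'M', 'a', 'N', 'n', 'u', '⚡', '🐍', 'B', 'b', 'U', 'A']

def numba_order(x):
    # rank of each character = its position in the reference string
    return 'NUMBA🐍numba⚡'.index(x)

by_code_point = ''.join(sorted(chars))
by_rank = ''.join(sorted(chars, key=numba_order))

print(by_code_point)  # ABMNUabmnu⚡🐍
print(by_rank)        # NUMBA🐍numba⚡
```

Every character in the input has a distinct rank under numba_order, so the result of that sort does not depend on how ties are broken.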

Initial support for basic try/except

Numba 0.47.0 has some basic support for the use of try/except in JIT compiled functions. This is a long-awaited feature that has been requested many times. Support is limited at present to two use cases, a bare except and except Exception; see the Numba documentation for details.

In [ ]:
@njit
def demo_try_bare_except(a, b):

    try:
        c = a / b
        return c
    except:
        print("caught exception")
        return -1
    
print("ok input:", demo_try_bare_except(5., 10.))
print("div by zero input:", demo_try_bare_except(5, 0))

The class Exception can also be caught. Let's mix this with the new bounds checking support:

In [ ]:
@njit(boundscheck=True)
def demo_try_except_exception(array, index):

    try:
        return array[index]
    except Exception:
        print("caught exception")
        return -1
    
x = np.ones(5)
print("ok input:", demo_try_except_exception(x, 0))
print("OOB access:", demo_try_except_exception(x, 10))
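
Both handlers work for these demos because, in plain Python, the two failures involved are ordinary subclasses of Exception. A small interpreter-only sketch (the helper exception_raised_by is made up for this example, and a plain list stands in for the NumPy array):

```python
def exception_raised_by(func):
    """Return the class of the exception raised by func(), or None."""
    try:
        func()
    except Exception as exc:  # `as exc` is used here for plain-Python inspection only
        return type(exc)
    return None

values = [1.0] * 5
div_error = exception_raised_by(lambda: 5 / 0)
oob_error = exception_raised_by(lambda: values[10])
no_error = exception_raised_by(lambda: values[0])

print(div_error.__name__, issubclass(div_error, Exception))  # ZeroDivisionError True
print(oob_error.__name__, issubclass(oob_error, Exception))  # IndexError True
print(no_error)                                              # None
```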

User defined exception classes also work:

In [ ]:
class UserDefinedException(Exception):
    def __init__(self, some_arg):
        self._some_arg = some_arg    

@njit(boundscheck=True)
def demo_try_except_ude():

    try:
        raise UserDefinedException(123)
    except Exception:
        return "caught UDE!"
    
print(demo_try_except_ude())

Iterating over mixed type containers

As users of Numba are very aware, Numba has to be able to work out the type of every variable in a function to be able to compile it (the function must be statically typable!). Prior to Numba 0.47.0, tuples of heterogeneous type could not be iterated over: the type of the induction variable in the loop could not be statically computed and, further, the loop body would need a different set of types for each type in the tuple. For example, this doesn't work:

In [ ]:
from numba import literal_unroll

@njit
def does_not_work():
    tup = (1, 'a', 2j)
    for i in tup:
        print(i) # Numba cannot work out type of `i`, it changes each loop iteration

print("Typing problem")
try:
    does_not_work()
except errors.TypingError as e:
    print(e)

In Numba 0.47.0 a new function, numba.literal_unroll, is introduced. The function itself does nothing much; it's just a token to tell the Numba compiler that its argument needs special treatment when used as an iterable. When it is applied in situations like the following, the body of the loop is "versioned" based on the types in the tuple, so that Numba can statically work out the types for each iteration and compilation succeeds. Here's a working version of the failing example above:

In [ ]:
# use special function `numba.literal_unroll`
@njit
def works():
    tup = (1, 'a', 2j)
    for i in literal_unroll(tup):
        print(i) # literal_unroll tells the compiler to version the loop body based on type.


print("Apply literal_unroll():")
works()
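
One way to picture loop body "versioning" is as a loop over the tuple's indices in which each index selects its own copy of the body, so every copy only ever sees one element type. The sketch below is a hand-written, plain-Python illustration of that idea (the function versioned_loop is made up here); it is not Numba's actual transformation, which works on the compiler's internal representation:

```python
def versioned_loop(tup):
    """Hand-unrolled picture of a loop body specialised per element type."""
    seen = []
    for idx in range(len(tup)):
        if idx == 0:
            i_int = tup[0]          # this copy of the body only handles an int
            seen.append(("int", i_int))
        elif idx == 1:
            i_str = tup[1]          # this copy only handles a str
            seen.append(("str", i_str))
        elif idx == 2:
            i_complex = tup[2]      # this copy only handles a complex
            seen.append(("complex", i_complex))
    return seen

result = versioned_loop((1, 'a', 2j))
print(result)  # [('int', 1), ('str', 'a'), ('complex', 2j)]
```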

A more involved example might be a tuple of locally defined functions (which are all different types by virtue of the Numba type system) that are iterated over:

In [ ]:
@njit
def fruit_cookbook():
    def get_apples(x):
        return ['apple' for _ in range(x * 3)]
    def get_oranges(x):
        return ['orange' for _ in range(x * 4)]
    def get_bananas(x):
        return ['banana' for _ in range(x * 2)]

    ingredients = (get_apples, get_oranges, get_bananas)
    
    def fruit_salad(scale):
        shopping_list = []
        for ingredient in literal_unroll(ingredients):
            shopping_list.extend(ingredient(scale))
        return shopping_list
    
    print(fruit_salad(2))

fruit_cookbook()

Finally, because Numba has string and integer literal support, it's possible to dispatch on these values at compile time and version the loop body with value-based specialisations:

In [ ]:
from numba import types

# function stub to overload
def dt(value):
    pass

@overload(dt, inline='always')
def ol_dt(li):
    # dispatch based on a string literal
    if isinstance(li, types.StringLiteral):
        value = li.literal_value
        if value == "apple":
            def impl(li):
                return 1
        elif value == "orange":
            def impl(li):
                return 2
        elif value == "banana":
            def impl(li):
                return 3
        return impl

    # dispatch based on an integer literal
    elif isinstance(li, types.IntegerLiteral):
        value = li.literal_value
        if value == 0xca11ab1e:
            def impl(li):
                # close over the literal value :)
                return 0x5ca1ab1e + value
            return impl

@njit
def unroll_and_dispatch_on_literal():
    acc = 0
    for t in literal_unroll(('apple', 'orange', 'banana', 0xca11ab1e)):
        acc += dt(t)
    return acc

print(unroll_and_dispatch_on_literal())
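
The value the demo should print can be predicted with a plain-Python stand-in for the overload. In this sketch (dt_reference is a made-up name) the dispatch happens at run time on ordinary values, whereas the overload above makes the same decisions at compile time on literal types:

```python
def dt_reference(value):
    """Run-time stand-in for the compile-time dispatch done by the overload."""
    if isinstance(value, str):
        return {"apple": 1, "orange": 2, "banana": 3}[value]
    if value == 0xca11ab1e:
        return 0x5ca1ab1e + value
    raise TypeError(f"no specialisation for {value!r}")

acc = 0
for t in ('apple', 'orange', 'banana', 0xca11ab1e):
    acc += dt_reference(t)

print(acc, hex(acc))
```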

It's hoped that in a future version of Numba the token function literal_unroll will not be needed and loop body versioning opportunities will be automatically identified.

Newly supported NumPy functions/features

This release contains a number of newly supported NumPy functions, all written by contributors from the Numba community:

  • np.arange now supports the dtype keyword argument.

  • Also now supported are:

    • np.lcm
    • np.gcd

A quick demo of the above:

In [ ]:
@njit
def demo_numpy():
    a = np.arange(5, dtype=np.uint8)
    b = np.lcm(a, 2)
    c = np.gcd(a, 3)
    
    return a, b, c

demo_numpy()
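
The element-wise values of b and c can be cross-checked with the standard library alone. This is a reference sketch over plain Python integers (the lcm helper is written out here rather than taken from NumPy), so it says nothing about the dtypes of the arrays Numba returns:

```python
from math import gcd

def lcm(x, y):
    # least common multiple; defined as 0 when either argument is 0
    return 0 if x == 0 or y == 0 else abs(x * y) // gcd(x, y)

a = list(range(5))
b = [lcm(v, 2) for v in a]
c = [gcd(v, 3) for v in a]

print(a)  # [0, 1, 2, 3, 4]
print(b)  # [0, 2, 2, 6, 4]
print(c)  # [3, 1, 1, 3, 1]
```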

New unicode string features

A large number of unicode string features/enhancements were added in 0.47.0, namely:

  • str.index()
  • str.rindex()
  • start/end parameters for str.find()
  • str.rpartition()
  • str.lower()

and a lot of querying functions:

  • str.isalnum()
  • str.isalpha()
  • str.isascii()
  • str.isidentifier()
  • str.islower()
  • str.isprintable()
  • str.isspace()
  • str.istitle()

In [ ]:
@njit
def demo_string_enhancements(arg):
    
    print("index:", arg.index("🐍")) # index of snake
    print("rindex:", arg.rindex("🐍")) # rindex of snake
    print("find:", arg.find("🐍", start=2, end=6)) # find snake with start+end
    print("rpartition:", arg.rpartition("🐍")) # rpartition snake
    print("lower:", arg.lower()) # lower snake

    print("isalnum:", 'abc123'.isalnum(), '🐍'.isalnum())
    print("isalpha:", 'abc'.isalpha(), '123'.isalpha())
    print("isascii:", 'abc'.isascii(), '🐍'.isascii())
    print("isidentifier:", '1'.isidentifier(), 'var'.isidentifier())
    print("islower:", 'SHOUT'.islower(), 'whisper'.islower())
    print("isprintable:", '\x07'.isprintable(), 'BEL'.isprintable())
    print("isspace:", ' '.isspace(), '_'.isspace())
    print("istitle:", "Titlestring".istitle(), "notTitlestring".istitle())

    
arg = "N🐍u🐍M🐍b🐍A⚡"
demo_string_enhancements(arg)
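
For reference, the same calls in the plain interpreter give the values below. One difference to be aware of in this sketch: CPython's str.find takes start and end positionally only, so they are passed without keywords here.

```python
arg = "N🐍u🐍M🐍b🐍A⚡"
snake = "🐍"

reference = {
    "index": arg.index(snake),            # first snake
    "rindex": arg.rindex(snake),          # last snake
    "find": arg.find(snake, 2, 6),        # first snake within arg[2:6]
    "rpartition": arg.rpartition(snake),  # split around the last snake
    "lower": arg.lower(),
}

for name, value in reference.items():
    print(name + ":", value)

print("isidentifier:", '1'.isidentifier(), 'var'.isidentifier())  # False True
print("isprintable:", '\x07'.isprintable(), 'BEL'.isprintable())  # False True
print("istitle:", "Titlestring".istitle(), "notTitlestring".istitle())  # True False
```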