Numba 0.49.0 Release Demo

This notebook contains a demonstration of new features present in the 0.49.0 release of Numba. Whilst release notes are produced as part of the CHANGE_LOG, there's nothing like seeing code in action! This release contains some significant changes to Numba internals and, of course, some exciting new features!

Important updates/information about this release:

  • This release drops support for Python 2 both for users and in the code base itself. It also raises the minimum supported versions of related software as follows:

    • Python >= 3.6
    • NumPy >= 1.15
    • SciPy >= 1.0

    It is still possible to build against NumPy 1.11, but NumPy 1.15 or later is required at runtime.

  • A huge amount of refactoring happened in this release (mainly module movement) to clean up the Numba code base. This refactoring was done to make it easier for users to contribute to the project, for the core developers to maintain the project, and to remove legacy code. The core developers had been waiting for an opportunity to do this for years and decided that coinciding with the retirement of Python 2 was best, as it would cause the least disruption to users.
  • As a result of the above, projects that relied on Numba's internals may have to adjust their imports. There is, however, an import "shim" in place in 0.49 that tries to faithfully replicate the original import locations. If one of these shim locations is used, a warning is issued that explains the refactoring and states the new import location. This gives projects relying on Numba's internals a couple of months to make changes; an example of the shim in action follows below.
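
As a minimal sketch of the shim in action, the following assumes `numba.errors` is one of the relocated modules (in 0.49 it lives at `numba.core.errors`); note the warning is only issued the first time the shim location is imported in a session:

In [ ]:
import warnings

# Old location: still works via the shim, but warns and points to the new home
with warnings.catch_warnings(record=True) as captured:
    warnings.simplefilter("always")
    from numba import errors as old_errors  # pre-0.49 style import

for w in captured:
    print(w.message)

# New, stable location
from numba.core import errors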

The core developers would like to offer their thanks to all users for their understanding and support with respect to these changes. If you need help migrating your code base to 0.49 due to this refactoring, please reach out via the Numba issue tracker or the Gitter channel.

New features:

The new features are split into sections based on use case:

  • For all users
  • For CUDA target users
  • For Numba extension writers/expert users

First, import the necessary functionality from Numba and NumPy...

In [ ]:
from numba import jit, njit, config, __version__, prange
from numba.typed import List
config.NUMBA_NUM_THREADS = 4 # for this demo, pretend there's 4 cores on the machine
from numba.extending import overload
import numba
import numpy as np
assert tuple(int(x) for x in __version__.split('.')[:2]) >= (0, 49)

For all users...

Thread masking

Numerous users have asked for the ability to dynamically control, at runtime, the number of threads Numba uses in parallel regions. Numba 0.49 brings this functionality; it is modelled after OpenMP, as this is a model familiar to many users. Documentation is here.

The API consists of two functions:

  • numba.get_num_threads() - returns the number of threads currently in use.
  • numba.set_num_threads(nthreads) - sets the number of threads to use to nthreads.

These functions are themselves thread and fork safe, and are available to call from both Python and JIT compiled code!

For those interested, the implementation details are here; as a warning, they are somewhat gnarly!

Now, a demonstration:

In [ ]:
from numba import get_num_threads, set_num_threads

# Discover thread mask from Python
print("Number of threads: {}".format(get_num_threads()))

# Set thread mask from Python
set_num_threads(2)

# Check it was set
print("Number of threads: {}".format(get_num_threads()))

@njit
def get_mask():
    print("JIT code, number of threads", get_num_threads())

# Discover thread mask from JIT code
get_mask()

@njit
def set_mask(x):
    set_num_threads(x)
    print("JIT code, number of threads", get_num_threads())

# Set thread mask from JIT code
set_mask(3)

Something more complicated, limiting the number of threads in use:

In [ ]:
@njit(parallel=True)
def thread_limiting():
    n = 5
    mask1 = 3
    mask2 = 2
    
    # np.zeros is parallelised, all threads are in use here
    A = np.zeros((n, mask1))
    
    # only use mask1 threads in this parallel region
    set_num_threads(mask1)
    for i in prange(mask1):
        A[:, i] = i

    # only use mask2 threads in this parallel region
    set_num_threads(mask2)
    A[:, :] = np.sqrt(A)

    return A

print(thread_limiting())

# Uncomment and run this to see the parallel diagnostics for the function above
# thread_limiting.parallel_diagnostics(thread_limiting.signatures[0], level=3)

It should be noted that once in a parallel region, setting the number of threads has no effect on the region that is currently executing; it does, however, affect subsequent parallel region launches. For example:

In [ ]:
mask = config.NUMBA_NUM_THREADS - 1 # create a mask

# some constants based on mask size
N = config.NUMBA_NUM_THREADS
M = 2 * config.NUMBA_NUM_THREADS

@njit(parallel=True)
def child_func(buf, fid):
    M, N = buf.shape
    for i in prange(N): # parallel write into the row slice
        buf[fid, i] = get_num_threads()

@njit(parallel=True)
def parent_func(nthreads):
    acc = 0
    buf = np.zeros((M, N))
    print("Parent: Setting mask to:", nthreads)
    set_num_threads(nthreads) # set threads to mask
    print("Parent: Running parallel loop of size", M)
    for i in prange(M):
        local_mask = 1 + i % mask
        
        # set threads in parent function
        set_num_threads(local_mask)
        
        # only call child_func if your thread mask permits!
        if local_mask < N:
            child_func(buf, local_mask)

        # add up all used threadmasks
        print("prange index", i, ". get_num_threads()", get_num_threads())
        acc += get_num_threads()
    return acc, buf

print("Calling with mask: {} and constants M = {}, N = {}".format(mask, M, N))
got_acc, got_buf = parent_func(mask)
print("got acc = {}".format(got_acc))
# expect sum of local_masks in prange(M) loop
print("expect acc = {}".format(np.sum(1 + np.arange(M) % mask)))
# `buf` should only be written to in rows whose index is a thread mask value
# less than N (the `local_mask < N` check forbids other writes); the contents
# of each written row is the thread mask that was in effect
print(got_buf)

First-class function types

For quite some time Numba has been able to pass around Numba JIT decorated functions as objects; these, however, have been seen by Numba as different types even when they have identical signatures. Numba 0.49 brings a new experimental feature that makes function objects first-class types, such that functions with the same signature can be seen as being "of the same type" for the purposes of type inference. Further, cfuncs, JIT functions, and functions based on a new "Wrapper address protocol" are all supported to some degree. Documentation is here.

An example:

In [ ]:
@njit("intp(intp)")
def foo(x):
    return x + 1

@njit("intp(intp)")
def bar(x):
    return x + 2

@njit("intp(intp)")
def baz(x):
    return x + 3

@njit
def apply(arg, *functions):
    for fn in functions: # to iterate over a container it must contain "all the same types"
        arg = fn(arg)
    return arg

apply(10, foo, bar, baz)
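
The function type can also be spelled out explicitly in a signature via the experimental numba.types.FunctionType. A minimal sketch follows, assuming (per the documentation) that FunctionType is constructed from a signature; the explicit annotation is optional, and inference as above usually suffices:

In [ ]:
from numba import types

# a function type matching foo/bar/baz above: intp -> intp
fn_ty = types.FunctionType(types.intp(types.intp))

@njit(types.intp(fn_ty, types.intp))
def invoke(fn, x):
    return fn(x)

print(invoke(foo, 3), invoke(bar, 3), invoke(baz, 3))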

Typed list update

Numba's typed.List container has been enhanced with the ability to construct a new instance directly from an iterable, saving a lot of boilerplate code. A quick demonstration:

In [ ]:
from numba.typed import List

print(List(range(10)))

x = [4., 6., 2., 1.]
print(List(x))

# also works in JIT code
@njit
def list_ctor(x):
    return List(x), List((1, 2, 3, 4))

list_ctor(np.arange(10.))

Support for ord and chr

For users wanting to encode/decode strings, particularly those of the ASCII variety, ord and chr are now supported:

In [ ]:
@njit
def demo_ord_chr():
    alphabet = 'abcdefghijklmnopqrstuvwxyz'
    lord = List()
    lchr = List()
    for idx, char in enumerate(alphabet, ord('A')):
        lord.append(ord(char))
        lchr.append(chr(idx))
    return lord, lchr

demo_ord_chr()

Checking if a function is JIT wrapped

A common question from writers of extension libraries that consume Numba functions is "How do I know if a function my application receives as an argument is already Numba JIT wrapped?". Numba 0.49 answers this with the numba.extending.is_jitted function:

In [ ]:
def some_func(x):
    return x + 1

def consumer(func, *args):
    if not numba.extending.is_jitted(func):
        print("Not JIT wrapped, will wrap and compile!")
        func = njit(func)
    return func(*args)

consumer(some_func, 10)

Newly supported NumPy functions/features

This release contains support for direct iteration over np.ndarrays and one newly supported NumPy function, np.isnat, all contributed by members of the Numba community.

A quick demo of the above:

In [ ]:
NAT = np.datetime64('NaT')
dt = np.dtype('<M8')

@njit
def demo_numpy():
    a = np.empty((5, 3, 2), dt)
    # iterate with ndindex
    for x in np.ndindex(a.shape):
        if np.random.random() < 0.5:
            a[x] = NAT
            
    count = 0
    # now iterate directly
    for twoDarr in a:
        for oneDarr in twoDarr:
            for item in oneDarr:
                if np.isnat(item):
                    count += 1
    
    # use ufunc
    ufunc_count = np.isnat(a).sum()
    
    assert count == ufunc_count
    

demo_numpy()

Using tuples in parallel regions

Due to long-standing issues in the internal implementation of parallel regions (namely, that they are based on Generalized Universal Functions), functions with parallel=True have not supported tuple "arguments" to these regions. This is a bit of a technical detail, but it is now fixed, so common things like expressing loop nest iteration limits from an array's shape work as expected.

In [ ]:
@njit(parallel=True)
def demo_tuple_in_prange(A):
    for i in prange(A.shape[0]):
        for j in range(A.shape[1]):
            for k in range(A.shape[2]):
                A[i, j, k] = i + j + k

x = 4
y = 3
z = 2
A = np.empty((x, y, z))
demo_tuple_in_prange(A)
print(A)

For CUDA target users...

All kernels require launch configurations

Prior to Numba 0.49, if a user forgot to specify a launch configuration for a CUDA kernel, a default configuration of one thread and one block was used. This led to hard-to-explain behaviours: for example, code that worked only by virtue of running in this minimal configuration, or code that exhibited strange performance characteristics.

As a result, in Numba 0.49 it is now a requirement that all CUDA kernel launches be explicitly configured, both in the CUDA simulator and on real hardware. Example:

In [ ]:
config.ENABLE_CUDASIM = 1
from numba import cuda

@cuda.jit
def kernel(x):
    print("In the kernel", cuda.threadIdx)

# bad launch, no configuration given
try:
    kernel(np.arange(10))
except ValueError as e:
    print(e)
    
# good launch, configuration specified
kernel[2, 4](np.arange(10))

External Memory Management (EMM) Plugin interface

Whilst it is not possible to demonstrate this feature in the current notebook, Numba 0.49 gains an External Memory Management (EMM) Plugin interface. When multiple CUDA-aware libraries are used together, it may be preferable for Numba to defer to another library for memory management. The EMM Plugin interface facilitates this by enabling Numba to use another CUDA-aware library for all allocations and deallocations. Documentation for this feature is here; a skeleton sketch of a plugin follows.
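
The cell below contains a heavily abridged skeleton of a plugin, left commented out as it cannot run here. It is a sketch only: external_lib is a hypothetical stand-in for a CUDA-aware allocator, several abstract methods are elided, and a real CUDA installation (not the simulator enabled above) would be required.

In [ ]:
# Sketch only! `external_lib` is hypothetical; needs real CUDA hardware to run.
# from numba import cuda
#
# class ExternalLibMemoryManager(cuda.HostOnlyCUDAMemoryManager):
#     """Delegates device allocations to an external library; host-side
#     allocation and pinning are inherited from HostOnlyCUDAMemoryManager."""
#
#     def memalloc(self, size):
#         # allocate `size` bytes of device memory via the external library,
#         # wrap the pointer in a numba.cuda.MemoryPointer and attach a
#         # finalizer that hands the memory back to the library
#         raise NotImplementedError
#
#     def initialize(self):
#         pass  # set up any pools/state the external library needs
#
#     @property
#     def interface_version(self):
#         return 1
#
# # Register the plugin before the first CUDA context is created:
# cuda.set_memory_manager(ExternalLibMemoryManager)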

For developers of Numba extensions...

There are three changes that may be of interest to those working on Numba extensions or with Numba IR:

  1. Numba transforms its IR to SSA form.
  2. Debug dumps now have syntax highlighting.
  3. Disassembly CFGs are now available (a commented-out sketch follows this list; full details are in the documentation).
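
As a flavour of item 3, the cell below sketches inspecting a disassembly CFG; it is left commented out because it relies on the optional r2pipe and radare2 dependencies, which may not be installed:

In [ ]:
# Sketch only: needs the optional r2pipe/radare2 (and graphviz) dependencies.
# @njit
# def cfg_demo(a, b):
#     return a + b
#
# cfg_demo(1, 2)  # trigger compilation
# # render the CFG of the disassembly for the compiled signature
# cfg_demo.inspect_disasm_cfg(signature=cfg_demo.signatures[0])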

Static Single Assignment form

Numba 0.49 contains the start of an important change to Numba's intermediate representation (IR). The change is essentially that the IR is now coerced into static single assignment (SSA) form immediately before type inference is performed. This fixes a number of bugs and makes it considerably easier to write more advanced optimisation passes. It is hoped that SSA form can be extended further up the compilation pipeline as time allows.

A quick demonstration that shows SSA form and the new syntax highlighted dumps in action:

In [ ]:
config.COLOR_SCHEME = 'light_bg' # colour scheme highlighting for a light background
config.HIGHLIGHT_DUMPS = '1' # request dump highlighting 
config.DEBUG_PRINT_WRAP = 'reconstruct_ssa' # print IR both sides of the SSA reconstruction pass

@njit
def demo_ssa(x):
    if x > 2:
        a = 12
    elif x > 4:
        a = 20
    else:
        a = 3
    return a

print(demo_ssa(5))

# switch it off again!
config.DEBUG_PRINT_WRAP = ''