Numba 0.46.0 Release Demo for Library Developers

This Notebook contains demonstrations of new features in the 0.46 release of Numba that are intended for use by library developers/compiler engineers.

🚨🐉🚨 These are advanced features, dragons be 'ere! 🚨🐉🚨

Features demonstrated in this notebook include:

  • Inlining functions at the Numba IR level via the inline kwarg on the numba.jit family of decorators and on numba.extending.overload
  • Customising the compiler: changing the default compiler, implementing a new compiler pipeline, and implementing a new compiler pass

Other new features are present in this release but are not demonstrated here; see the 0.46 changelog for the full list.

First, import the necessary functions and check the Numba version:

In [ ]:
from numba import jit, njit, config, __version__, errors
from numba.extending import overload
import numpy as np
assert tuple(int(x) for x in __version__.split('.')[:2]) >= (0, 46)
config.SHOW_HELP = False # switch off help messages

Inlining

Numba gains a lot from LLVM itself being able to inline functions, and Numba's internals are geared towards making that easy. However, numerous use cases have arisen where it would be useful to be able to inline a function at the Numba IR level. Numba 0.46 adds support for doing this via the keyword argument inline, which can be supplied to the numba.jit family of decorators and also to numba.extending.overload; documentation is here.

As a motivating use case, the following function can obviously be compiled without issue:

In [ ]:
from numba.typed import List

@njit
def foo():
    l = List()
    for i in range(10):
        l.append(i * 123.45)
    return l

foo()

This minor variation on the above cannot be compiled: the type of the List() in bar cannot be inferred because type inference cannot "see" across the function call into baz, where it becomes apparent that the type must be ListType[float64].

In [ ]:
@njit
def baz(l):
    for i in range(10):
        l.append(i * 123.45)

@njit
def bar():
    l = List()
    baz(l)
    return l

try:
    bar()
except errors.TypingError as e:
    print(e)

Something similar to the above use case was the exact reason the ability to perform inlining was explored. The following demonstrates how to resolve the above situation: supplying the kwarg inline='always' to the called function forces its body to be inlined at the call site in the caller, so there is no longer a type inference issue.

In [ ]:
@njit(inline='always')
def baz(l):
    for i in range(10):
        l.append(i * 123.45)

@njit
def bar():
    l = List()
    baz(l)
    return l

bar() # works fine

# baz got inlined, bar was effectively seen as:
# def bar():
#     l = List()
#     for i in range(10):
#         l.append(i * 123.45)
#     return l
#
# which is the same as foo above

Inlining options

To make the inlining capability as flexible as possible, three options were added for the kwarg:

  • 'never' - never inline (default)
  • 'always' - always inline
  • a callable - returns True to inline, False to not inline

An example using all of the above follows (it also uses the new environment variable/config option DEBUG_PRINT_AFTER to show the IR; docs are here):

In [ ]:
from numba import njit, ir
import numba

# enable printing of the IR post legalization, i.e. just before it is lowered
numba.config.DEBUG_PRINT_AFTER="ir_legalization"


@njit(inline='never')
def never_inline():
    return 100


@njit(inline='always')
def always_inline():
    return 200


def sentinel_cost_model(expr, caller_info, callee_info):
    # this cost model will return True (i.e. do inlining) if either:
    # a) the callee IR contains an `ir.Const(37)`
    # b) the caller IR contains an `ir.Const(13)` logically prior to the call
    #    site

    # check the callee
    for blk in callee_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 37:
                        return True

    # check the caller
    before_expr = True
    for blk in caller_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Expr):
                    if stmt.value == expr:
                        before_expr = False
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 13:
                        return True & before_expr
    return False


@njit(inline=sentinel_cost_model)
def maybe_inline1():
    # Will not inline based on the callee IR with the declared cost model
    # The following is ir.Const(300).
    return 300


@njit(inline=sentinel_cost_model)
def maybe_inline2():
    # Will inline based on the callee IR with the declared cost model
    # The following is ir.Const(37).
    return 37


@njit
def foo():
    a = never_inline()  # will never inline
    b = always_inline()  # will always inline

    # will not inline as the function does not contain a magic constant known to
    # the cost model, and the IR up to the call site does not contain a magic
    # constant either
    d = maybe_inline1()

    # declare this magic constant to trigger inlining of maybe_inline1 in a
    # subsequent call
    magic_const = 13

    # will inline due to above constant declaration
    e = maybe_inline1()

    # will inline as the maybe_inline2 function contains a magic constant known
    # to the cost model
    c = maybe_inline2()

    return a + b + c + d + e + magic_const


foo()

Note that in the IR above, because dead code elimination is not performed by default, superfluous statements are present.

Further, the same inline kwarg is implemented for the numba.extending.overload decorator; documentation and examples are here. A short sketch follows the next cell.

In [ ]:
numba.config.DEBUG_PRINT_AFTER="" # disable debug print again
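
As a brief sketch of the overload form (hedged: the clamp function below is a made-up helper, not part of Numba; only the use of the inline kwarg is the point), the same inlining behaviour can be requested for an overload implementation:

In [ ]:
def clamp(x, lo, hi):
    # pure-Python stub; in nopython mode the overload below supplies the implementation
    return min(max(x, lo), hi)

@overload(clamp, inline='always')
def ol_clamp(x, lo, hi):
    def impl(x, lo, hi):
        return min(max(x, lo), hi)
    return impl

@njit
def use_clamp(x):
    return clamp(x, 0.0, 10.0)

use_clamp(123.4) # the overload body is inlined into use_clamp at the Numba IR level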

Customising the compiler

In Numba 0.46 the main compiler pipeline was significantly reworked to make it more easily extendable and to permit users to essentially build their own custom compiler frontends. This change is based on a design similar to that found in LLVM. Full documentation is here.

Changing the default compiler

For a large number of releases the Numba @jit family of decorators has permitted the definition of a custom compiler pipeline via the kwarg pipeline_class. This has not changed, but the type of the class passed as the value has: Numba 0.46 now requires a class inheriting from numba.compiler.CompilerBase to be passed as the value, which is considerably more flexible than the previously accepted pipeline class.

The default compiler used by Numba is the numba.compiler.Compiler class, which itself makes use of pre-canned pipelines defined in numba.compiler.DefaultPassBuilder via the methods:

  • .define_nopython_pipeline() for the nopython mode pipeline
  • .define_objectmode_pipeline() for the object-mode pipeline
  • .define_interpreted_pipeline() for the interpreted pipeline

Creating a new custom compiler requires extending from the numba.compiler.CompilerBase class and overriding the .define_pipelines() method, e.g.:

In [ ]:
from numba.compiler import CompilerBase, DefaultPassBuilder
class CustomCompiler(CompilerBase): # custom compiler extends from CompilerBase

    def define_pipelines(self):
        # define a new set of pipelines (just one in this case) and for demonstration purposes
        # reuse an existing pipeline from the DefaultPassBuilder, namely the "nopython" pipeline
        pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
        # return as an iterable, any number of pipelines may be defined!
        return [pm]

Using the custom compiler is just a question of supplying it via the aforementioned pipeline_class kwarg, for example:

In [ ]:
@jit(pipeline_class=CustomCompiler)
def foo(x):
    return x + 1

foo(10)

The next example won't work with the CustomCompiler: only the nopython mode pipeline is available in the CustomCompiler, and this function creates a Python object.

In [ ]:
@jit(pipeline_class=CustomCompiler)
def foo(x):
    return x + 1, object()

from numba import errors
try:
    foo(10)
except errors.TypingError as e:
    print(str(e))
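
As a sketch of how the above could be handled (hedged: this roughly mirrors what the default numba.compiler.Compiler does rather than being a prescribed recipe, and it relies on the internal state.status.can_fallback attribute), a custom compiler can add the object-mode pipeline as a fallback after the nopython one:

In [ ]:
class FallbackCompiler(CompilerBase):

    def define_pipelines(self):
        dpb = DefaultPassBuilder
        # try the nopython pipeline first
        pms = [dpb.define_nopython_pipeline(self.state)]
        # if object-mode fallback is permitted for this dispatcher, also define the
        # object-mode pipeline so functions that fail type inference can still compile
        if self.state.status.can_fallback:
            pms.append(dpb.define_objectmode_pipeline(self.state))
        return pms

@jit(pipeline_class=FallbackCompiler)
def foo(x):
    return x + 1, object()

foo(10) # compiles via the object-mode pipeline (expect an object-mode fallback warning)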

Implementing a new pipeline

Numba has a large number of pre-defined passes available for use; they are categorised as being:

  • untyped, i.e. do not require type information, these are found in numba.untyped_passes
  • typed, i.e. require type information, these are found in numba.typed_passes
  • object mode, i.e. require object mode, these are found in numba.object_mode_passes

For reference, these are the passes registered in the 0.46 code base:

In [ ]:
for x in numba.compiler_machinery._pass_registry._registry.keys():
    print(x)

Let's implement a new pipeline that:

  • analyses the bytecode
  • rewrites semantic constants
  • does dead branch pruning
  • runs type inference
  • does dead code elimination
  • runs legalisation checks on the IR
  • lowers the IR to machine code

and use it in a new custom compiler. The pipeline management code is found in numba.compiler_machinery.

In [ ]:
from numba.compiler_machinery import PassManager

from numba.untyped_passes import (TranslateByteCode, FixupArgs, IRProcessing, DeadBranchPrune,
                                  RewriteSemanticConstants)

from numba.typed_passes import (NopythonTypeInference, DeadCodeElimination, IRLegalization,
                                NoPythonBackend)


def gen_pipeline():
    """ pipeline generation function, it need not be a function, pipelines are often
    defined directly in `ClassExtendingCompilerBase.define_pipelines` but it'll be used
    in a later example for another purpose.
    """
    # create a new PassManager to handle the passes for the pipeline
    pm = PassManager("custom_pipeline")
    
    # untyped
    pm.add_pass(TranslateByteCode, "analyzing bytecode")
    pm.add_pass(IRProcessing, "processing IR")
    pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
    pm.add_pass(DeadBranchPrune, "dead branch pruning")
    
    # typed
    pm.add_pass(NopythonTypeInference, "nopython frontend")
    pm.add_pass(DeadCodeElimination, "DCE")

    # legalise
    pm.add_pass(IRLegalization, "ensure IR is legal prior to lowering")

    # lower
    pm.add_pass(NoPythonBackend, "nopython mode backend")

    # finalise the contents
    pm.finalize()
    return pm


class NewPipelineCompiler(CompilerBase): 

    def define_pipelines(self):
        return [gen_pipeline()]

Now use the NewPipelineCompiler in a deliberately contrived example to demonstrate the effect of certain passes.

In [ ]:
numba.config.DEBUG_PRINT_AFTER="ir_processing,rewrite_semantic_constants,dead_branch_prune,dead_code_elimination"


@jit(pipeline_class=NewPipelineCompiler)
def foo(arr):
    if arr.ndim == 1:
        return 100
    else:
        return 200

x = np.arange(10) # 1d array input, x.ndim = 1
foo(x)

In the output above, the following can be seen:

  • The ir_processing pass produces the initial IR.
  • The rewrite_semantic_constants pass replaces the expression:
    • $0.2 = getattr(value=arr, attr=ndim) with $0.2 = const(int, 1)
  • The dead_branch_prune pass spotted that the block with label 14 is dead and removed it because:
     $0.2 = const(int, 1)                     ['$0.2']
     del arr                                  []
     $const0.3 = const(int, 1)                ['$const0.3']
     $0.4 = $0.2 == $const0.3
    evaluates to $0.4 always being True, and as a result its use as the predicate in branch $0.4, 10, 14 means the 10 branch will always be taken and 14 is dead.
  • The dead_code_elimination pass removed all the statements which were dead (had no effect).

In the final output there are now two blocks, labels 0 and 10. Block 0 has only one statement, an unconditional jump to 10. In the next section a new pass is going to be written to simplify the control flow graph in such situations, as it's clear that the blocks can be fused.

Implementing a new compiler pass

Implementing a new compiler pass involves writing a class that inherits from numba.compiler_machinery.CompilerPass. It must be registered with the pass registry before use, and the registration process declares some information about what the pass will do in certain scenarios. Documentation for this feature is here.

Continuing with the above example, Numba has a function numba.ir_utils.simplify_CFG which does the control flow graph simplification alluded to in the final paragraph above. In the following this function is wrapped in a compiler pass and then used in a new pipeline.

In [ ]:
from numba.ir_utils import simplify_CFG
from numba.compiler_machinery import register_pass, FunctionPass

# Register this pass with the compiler framework, declaring that it can mutate the
# control flow graph and that it is not an analysis_only pass (it potentially mutates
# the IR). The pass inherits from FunctionPass, the base class for passes operating
# on functions.
@register_pass(mutates_CFG=True, analysis_only=False)
class SimplifyCFG(FunctionPass):
    _name = "simplify_cfg" # the common name for the pass

    def __init__(self):
        FunctionPass.__init__(self)
        
    # implement the method to do the work, "state" is the internal compiler
    # state from the CompilerBase instance.
    def run_pass(self, state):
        # get the IR blocks
        blks = state.func_ir.blocks
        # run the simplification
        new_blks = simplify_CFG(blks)
        # update the reference to the block state
        state.func_ir.blocks = new_blks
        
        # return whether the IR was mutated (here, CFG change implies IR change)
        mutated = blks != new_blks
        return mutated


# define a new compiler
class NewPipelineWSimplifyCFGCompiler(CompilerBase):

    def define_pipelines(self):
        # generate the same pipeline as in the previous example
        pm = gen_pipeline()
        
        # add the new pass after DeadCodeElimination
        pm.add_pass_after(SimplifyCFG, DeadCodeElimination)
        
        # re-finalize the pipeline since the above mutated it
        pm.finalize()
        return [pm]

Now re-run the foo function with the updated custom compiler, which includes the new pass in its pipeline. Also, print the IR after dead code elimination (the end of the output from the last example) and after the new SimplifyCFG pass.

In [ ]:
numba.config.DEBUG_PRINT_AFTER="dead_code_elimination,simplify_cfg"

@jit(pipeline_class=NewPipelineWSimplifyCFGCompiler)
def foo(arr):
    if arr.ndim == 1:
        return 100
    else:
        return 200

x = np.arange(10) # 1d array input, x.ndim = 1
foo(x)

It can be seen in the above that the CFG has been simplified after the new simplify_cfg pass has run; the IR is now a single block.
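
As a final, hedged sketch (the PassManager.passes attribute is internal and assumed here to hold (pass class, description) pairs, so this is illustrative rather than a stable API), the extended pipeline can be inspected to confirm that the new pass sits directly after DeadCodeElimination:

In [ ]:
pm = gen_pipeline()
pm.add_pass_after(SimplifyCFG, DeadCodeElimination)
pm.finalize()
# print the pass ordering; each entry is assumed to be (pass class, description)
for pass_cls, description in pm.passes:
    print(pass_cls._name.ljust(35), description)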