This Notebook contains demonstrations of new features in the 0.46 release of Numba that are intended for use by library developers/compiler engineers.
Features demonstrated in this notebook include:
Other new features present but not demonstrated here include:
- Cmodules and associated helper functions. Documentation here.
- jitdecorator. Documentation here.
First, import the necessary packages and check the version:

```python
from numba import jit, njit, config, __version__, errors
from numba.extending import overload
import numpy as np

assert tuple(int(x) for x in __version__.split('.')[:2]) >= (0, 46)
config.SHOW_HELP = False  # switch off help messages
```
Numba gains a lot from LLVM itself being able to inline functions, and Numba's internals are geared towards making that easy. However, numerous use cases have arisen where it would be useful to inline a function at the Numba IR level. Numba 0.46 adds support for doing this via the keyword argument `inline`, which can be supplied to the `numba.jit` family of decorators and also to `numba.extending.overload`; documentation is here.
A motivating use case: the following function obviously compiles without issue:

```python
from numba.typed import List

@njit
def foo():
    l = List()
    for i in range(10):
        l.append(i * 123.45)
    return l

foo()
```
This minor variation on the above cannot be compiled: the type of `l` in `bar` cannot be inferred, as type inference cannot "see" across the function call into `baz`, where it becomes apparent what the type must be.

```python
@njit
def baz(l):
    for i in range(10):
        l.append(i * 123.45)

@njit
def bar():
    l = List()
    baz(l)
    return l

try:
    bar()
except errors.TypingError as e:
    print(e)
```
Something similar to the above use case was the exact reason the ability to perform inlining was explored. The following demonstrates how to resolve the situation: supplying the kwarg `inline='always'` to the called function forces its body to be inlined at the call site in the caller, so there is no longer a type inference issue.

```python
@njit(inline='always')
def baz(l):
    for i in range(10):
        l.append(i * 123.45)

@njit
def bar():
    l = List()
    baz(l)
    return l

bar()  # works fine

# baz got inlined, bar was effectively seen as:
#
# def bar():
#     l = List()
#     for i in range(10):
#         l.append(i * 123.45)
#     return l
#
# which is the same as foo above
```
To make the inlining capability as flexible as possible, three options were added for the kwarg:

- `'never'` - never inline (default)
- `'always'` - always inline
- a callable - a cost model function receiving the call expression along with information about the caller and callee IR, returning `True` to inline
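The callable cost model's decision logic can be sketched in pure Python using simplified stand-ins for Numba's `ir.Assign`/`ir.Const` records (the stand-in classes and names here are hypothetical, for illustration only; the real callable receives richer caller/callee information):

```python
# Sketch of how an `inline` cost-model callable might decide, using simple
# stand-in records instead of Numba's real ir.Assign/ir.Const classes.
from collections import namedtuple

Const = namedtuple("Const", ["value"])
Assign = namedtuple("Assign", ["value"])

def contains_const(blocks, magic):
    """Return True if any block assigns the given constant."""
    for blk in blocks.values():
        for stmt in blk:
            if isinstance(stmt, Assign) and isinstance(stmt.value, Const):
                if stmt.value.value == magic:
                    return True
    return False

def cost_model(expr, caller_blocks, callee_blocks):
    # inline only if the callee contains the magic constant 37
    return contains_const(callee_blocks, 37)

# A callee whose IR assigns the constant 37 would be inlined...
callee = {0: [Assign(Const(37))]}
# ...while one assigning 300 would not.
other = {0: [Assign(Const(300))]}
```

The real cost model in the example below works the same way in spirit, but walks actual Numba IR blocks.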
An example using all of the above follows (it also uses the new environment variable/config option `DEBUG_PRINT_AFTER` to show the IR; docs are here):

```python
from numba import njit, ir
import numba

# enable printing of the IR post legalization, i.e. just before it is lowered
numba.config.DEBUG_PRINT_AFTER = "ir_legalization"

@njit(inline='never')
def never_inline():
    return 100

@njit(inline='always')
def always_inline():
    return 200

def sentinel_cost_model(expr, caller_info, callee_info):
    # this cost model will return True (i.e. do inlining) if either:
    # a) the callee IR contains an `ir.Const(37)`
    # b) the caller IR contains an `ir.Const(13)` logically prior to the call
    #    site

    # check the callee
    for blk in callee_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 37:
                        return True

    # check the caller
    before_expr = True
    for blk in caller_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Expr):
                    if stmt.value == expr:
                        before_expr = False
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 13:
                        return True & before_expr
    return False

@njit(inline=sentinel_cost_model)
def maybe_inline1():
    # Will not inline based on the callee IR with the declared cost model.
    # The following is ir.Const(300).
    return 300

@njit(inline=sentinel_cost_model)
def maybe_inline2():
    # Will inline based on the callee IR with the declared cost model.
    # The following is ir.Const(37).
    return 37

@njit
def foo():
    a = never_inline()   # will never inline
    b = always_inline()  # will always inline

    # will not inline as the function does not contain a magic constant known
    # to the cost model, and the IR up to the call site does not contain a
    # magic constant either
    d = maybe_inline1()

    # declare this magic constant to trigger inlining of maybe_inline1 in a
    # subsequent call
    magic_const = 13

    # will inline due to above constant declaration
    e = maybe_inline1()

    # will inline as the maybe_inline2 function contains a magic constant
    # known to the cost model
    c = maybe_inline2()

    return a + b + c + d + e + magic_const

foo()
```
Note that, as dead code elimination is not performed by default, there are superfluous statements present in the above IR.
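What dead code elimination would remove can be sketched in a few lines of pure Python, assuming a toy IR where each statement is a `(target, names_it_reads)` pair (an invented simplification, not Numba's actual `DeadCodeElimination` pass):

```python
# Toy dead-statement elimination over a flat list of assignments.
# Each statement is (target, set_of_names_it_reads); `live` holds the
# names that must survive (e.g. the returned value).
def eliminate_dead(stmts, live):
    live = set(live)
    kept = []
    # walk backwards: a statement is live only if its target is needed later
    for target, reads in reversed(stmts):
        if target in live:
            kept.append((target, reads))
            live.discard(target)
            live |= reads  # whatever this statement reads becomes live
    kept.reverse()
    return kept

stmts = [
    ("a", set()),    # a = 1
    ("b", {"a"}),    # b = a + 1  (dead: b is never used)
    ("c", {"a"}),    # c = a * 2
]
```

With `live={"c"}`, the assignment to `b` is dropped as superfluous, which is exactly the kind of statement a real DCE pass would eliminate from the IR above.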
Further, the same `inline` kwarg is implemented for the `numba.extending.overload` decorator; documentation and examples are here.

```python
numba.config.DEBUG_PRINT_AFTER = ""  # disable debug print again
```
In Numba 0.46 the main compiler pipeline was significantly reworked to make it more easily extendable and to permit users to essentially build their own custom compiler frontends. This change is based on a design similar to that found in LLVM. Full documentation is here.
For a large number of releases the Numba `@jit` family of decorators has permitted the definition of a custom compiler pipeline via the kwarg `pipeline_class`. This has not changed, but the type of the class passed as the value has: Numba 0.46 now requires a `numba.compiler.CompilerBase` subclass to be passed as the value, which is a much more flexible class than the aforementioned pipeline.
The default compiler used by Numba is the `numba.compiler.Compiler` class, which itself makes use of pre-canned pipelines defined in `numba.compiler.DefaultPassBuilder` by the methods:

- `.define_nopython_pipeline()` for the nopython mode pipeline
- `.define_objectmode_pipeline()` for the object-mode pipeline
- `.define_interpreted_pipeline()` for the interpreted pipeline
Creating a new custom compiler requires extending the `numba.compiler.CompilerBase` class and overriding the `.define_pipelines()` method, e.g.:

```python
from numba.compiler import CompilerBase, DefaultPassBuilder

class CustomCompiler(CompilerBase):  # custom compiler extends from CompilerBase

    def define_pipelines(self):
        # define a new set of pipelines (just one in this case) and for
        # demonstration purposes reuse an existing pipeline from the
        # DefaultPassBuilder, namely the "nopython" pipeline
        pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
        # return as an iterable, any number of pipelines may be defined!
        return [pm]
```
Using the custom compiler is just a question of supplying it via the aforementioned `pipeline_class` kwarg, for example:

```python
@jit(pipeline_class=CustomCompiler)
def foo(x):
    return x + 1

foo(10)
```
The next example won't work with the `CustomCompiler`, because only the nopython mode pipeline is available in the `CustomCompiler` and this function contains a Python object.

```python
@jit(pipeline_class=CustomCompiler)
def foo(x):
    return x + 1, object()

from numba import errors

try:
    foo(10)
except errors.TypingError as e:
    print(str(e))
```
Numba has a large number of pre-defined passes available for use; they are categorised as being:

- untyped, i.e. they do not require type information; these are found in `numba.untyped_passes`
- typed, i.e. they require type information; these are found in `numba.typed_passes`
- object mode, i.e. they require object mode; these are found in `numba.object_mode_passes`
For reference, these are the ones in the code base for 0.46.
```python
for x in numba.compiler_machinery._pass_registry._registry.keys():
    print(x)
```
Let's implement a new pipeline that runs a handful of the untyped passes (bytecode translation, IR processing, semantic constant rewriting and dead branch pruning), performs nopython type inference and dead code elimination, then legalises and lowers the result, and use it in a new custom compiler. The pipeline management code is found in `numba.compiler_machinery`.

```python
from numba.compiler_machinery import PassManager

from numba.untyped_passes import (TranslateByteCode, FixupArgs, IRProcessing,
                                  DeadBranchPrune, RewriteSemanticConstants)

from numba.typed_passes import (NopythonTypeInference, DeadCodeElimination,
                                IRLegalization, NoPythonBackend)

def gen_pipeline():
    """ pipeline generation function, it need not be a function, pipelines are
    often defined directly in `ClassExtendingCompilerBase.define_pipelines`,
    but it'll be used in a later example for another purpose.
    """
    # create a new PassManager to handle the passes for the pipeline
    pm = PassManager("custom_pipeline")

    # untyped
    pm.add_pass(TranslateByteCode, "analyzing bytecode")
    pm.add_pass(IRProcessing, "processing IR")
    pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
    pm.add_pass(DeadBranchPrune, "dead branch pruning")

    # typed
    pm.add_pass(NopythonTypeInference, "nopython frontend")
    pm.add_pass(DeadCodeElimination, "DCE")

    # legalise
    pm.add_pass(IRLegalization, "ensure IR is legal prior to lowering")

    # lower
    pm.add_pass(NoPythonBackend, "nopython mode backend")

    # finalise the contents
    pm.finalize()
    return pm

class NewPipelineCompiler(CompilerBase):

    def define_pipelines(self):
        return [gen_pipeline()]
```
Now use the `NewPipelineCompiler` in a deliberately contrived example to demonstrate the effect of certain passes.

```python
numba.config.DEBUG_PRINT_AFTER = ("ir_processing,rewrite_semantic_constants,"
                                  "dead_branch_prune,dead_code_elimination")

@jit(pipeline_class=NewPipelineCompiler)
def foo(arr):
    if arr.ndim == 1:
        return 100
    else:
        return 200

x = np.arange(10)  # 1d array input, x.ndim = 1
foo(x)
```
In the output above, the following can be seen:

- The `ir_processing` pass produces the initial IR.
- The `rewrite_semantic_constants` pass replaces the expression `$0.2 = getattr(value=arr, attr=ndim)` with `$0.2 = const(int, 1)`.
- The `dead_branch_prune` pass spotted that the block with `label 14` is dead and removed it: since `$0.2 = const(int, 1)` and `$const0.3 = const(int, 1)`, the comparison `$0.4 = $0.2 == $const0.3` is always `True`, and as a result its use as the predicate in `branch $0.4, 10, 14` means the `10` branch will always be taken.
- The `dead_code_elimination` pass removed all the statements which were dead (had no effect).
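The branch-pruning step itself can be sketched in pure Python over a toy block structure (the block layout here is invented for illustration; it is not Numba's real IR):

```python
# Toy constant-predicate branch pruning: when a branch predicate is a known
# constant, replace the branch with a jump and drop blocks that become
# unreachable. This mirrors in spirit what dead_branch_prune did above.
def prune(blocks, entry, const_preds):
    # blocks: label -> ("branch", pred, true_lbl, false_lbl)
    #                  | ("jump", lbl) | ("return", value)
    term = blocks[entry]
    if term[0] == "branch" and term[1] in const_preds:
        taken = term[2] if const_preds[term[1]] else term[3]
        blocks[entry] = ("jump", taken)
    # keep only blocks still reachable from the entry
    reachable, stack = set(), [entry]
    while stack:
        lbl = stack.pop()
        if lbl in reachable:
            continue
        reachable.add(lbl)
        t = blocks[lbl]
        if t[0] == "jump":
            stack.append(t[1])
        elif t[0] == "branch":
            stack.extend([t[2], t[3]])
    return {lbl: t for lbl, t in blocks.items() if lbl in reachable}

# the shape of the example above: branch $0.4, 10, 14 with $0.4 known True
blocks = {
    0: ("branch", "$0.4", 10, 14),
    10: ("return", 100),
    14: ("return", 200),
}
```

With `$0.4` known to be `True`, block `14` is dropped and block `0` degenerates to an unconditional jump to `10`, matching the IR printed above.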
In the final output there are now two blocks: label `0` has only one statement, an unconditional jump to `10`. In the next section a new pass is going to be written to simplify the control flow graph in such situations, as it's clear that the blocks can be fused.
Implementing a new compiler pass involves writing a class that inherits from `numba.compiler_machinery.CompilerPass`. It must be registered with the pass registry before use and, through the process of registration, declare some information about what it will do in certain scenarios. Documentation for this feature is here.
Continuing with the above example: Numba has a function, `numba.ir_utils.simplify_CFG`, which does the control flow graph simplification alluded to in the final paragraph above. In the following, this function is wrapped in a compiler pass and then used in a new pipeline.

```python
from numba.ir_utils import simplify_CFG
from numba.compiler_machinery import register_pass, FunctionPass

# Register this pass with the compiler framework, declare that it can mutate
# the control flow graph and that it is not an analysis_only pass (it
# potentially mutates the IR).
@register_pass(mutates_CFG=True, analysis_only=False)
class SimplifyCFG(FunctionPass):
    # inherit from FunctionPass, the base class for passes operating on
    # functions
    _name = "simplify_cfg"  # the common name for the pass

    def __init__(self):
        FunctionPass.__init__(self)

    # implement the method to do the work, "state" is the internal compiler
    # state from the CompilerBase instance.
    def run_pass(self, state):
        # get the IR blocks
        blks = state.func_ir.blocks
        # run the simplification
        new_blks = simplify_CFG(blks)
        # update the reference to the block state
        state.func_ir.blocks = new_blks
        # return whether the IR was mutated (here, a CFG change implies an IR
        # change)
        mutated = blks != new_blks
        return mutated

# define a new compiler
class NewPipelineWSimplifyCFGCompiler(CompilerBase):

    def define_pipelines(self):
        # generate the same pipeline as in the previous example
        pm = gen_pipeline()
        # add the new pass after DeadCodeElimination
        pm.add_pass_after(SimplifyCFG, DeadCodeElimination)
        # re-finalize the pipeline since the above mutated it
        pm.finalize()
        return [pm]
```
Now re-run the `foo` function with the updated custom compiler, which includes the new pass in its pipeline. Also print the IR after dead code elimination (the end of the output from the last example) and now after the new `simplify_cfg` pass:

```python
numba.config.DEBUG_PRINT_AFTER = "dead_code_elimination,simplify_cfg"

@jit(pipeline_class=NewPipelineWSimplifyCFGCompiler)
def foo(arr):
    if arr.ndim == 1:
        return 100
    else:
        return 200

x = np.arange(10)  # 1d array input, x.ndim = 1
foo(x)
```
It can be seen in the above that the CFG has been simplified after the new `simplify_cfg` pass has run; the IR is now a single block.
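The block fusion performed here can be sketched over a toy representation in which each block is a `(body, jump-target)` pair (an invented layout mimicking in spirit what `numba.ir_utils.simplify_CFG` achieves, not Numba's actual data structures):

```python
# Toy CFG simplification: fuse a block into its jump target when that
# target has no other predecessors.
def fuse_blocks(blocks):
    # blocks: label -> (list_of_statements, jump_target_or_None)
    changed = True
    while changed:
        changed = False
        for lbl, (body, tgt) in list(blocks.items()):
            if tgt is None or lbl not in blocks:
                continue
            # count how many blocks jump to the target
            preds = sum(1 for _, (_, t) in blocks.items() if t == tgt)
            if preds == 1 and tgt != lbl:
                # absorb the target block's body and terminator
                tbody, ttgt = blocks.pop(tgt)
                blocks[lbl] = (body + tbody, ttgt)
                changed = True
                break  # restart iteration over the mutated dict
    return blocks

# the situation from the previous example: block 0 only jumps to block 10
blocks = {
    0: ([], 10),
    10: (["return 100"], None),
}
```

Running this over the two-block example fuses everything into a single block, which is exactly the shape of the IR printed after `simplify_cfg` above.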