smartjit @jit
decorator accepts the same set of arguments as @numba.jit
or @numba.njit
with the addition of two new keyword arguments:
use_jit: Callback
- Callback function which returns a
smart_jit.Action
, determining whether to use JIT compilation.
Action.INTERPRETER
will cause the function to always be interpreted, while
Action.JIT
will cause the function to always be JIT-compiled. If a callback function is passed, it is evaluated on each function call, and the result determines whether that call is JIT-compiled or interpreted.
warn_on_fallback: bool
- When enabled, a warning is emitted whenever a call falls back to the interpreter instead of running JIT-compiled code. This can be useful for debugging. Default is
False
.
We also implement an Enum, named Action
, which contains the set of possible actions one can return from the use_jit
callable:
Action.INTERPRETER
: Fall back to execution in the interpreter
Action.JIT
: JIT compile and execute
Action.RAISE_EXCEPTION
: Raise a no-match
TypeError
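The three actions can be pictured with a small stand-alone sketch. It does not import smart_jit: the Action class below is a stand-in with assumed member values, and dispatch, py_func and compiled_func are hypothetical names used only for illustration. The ValueError for an unexpected value is also an assumption.

```python
from enum import Enum, auto


class Action(Enum):
    """Stand-in for smart_jit.Action (member values are an assumption)."""
    INTERPRETER = auto()
    JIT = auto()
    RAISE_EXCEPTION = auto()


def dispatch(action, py_func, compiled_func, *args):
    """Hypothetical helper: act on the value returned by a use_jit callback."""
    if action is Action.INTERPRETER:
        # Fall back to the plain Python function.
        return py_func(*args)
    if action is Action.JIT:
        # Run the compiled overload (here just another callable).
        return compiled_func(*args)
    if action is Action.RAISE_EXCEPTION:
        # Report that no definition matches the argument types.
        types = ", ".join(type(a).__name__ for a in args)
        raise TypeError(f"No matching definition for argument type(s) {types}")
    # Anything else is not a valid action (exception type is an assumption).
    raise ValueError(f"unexpected value returned from use_jit: {action!r}")


def add_one(x):
    return x + 1


print(dispatch(Action.INTERPRETER, add_one, add_one, 4))
```

Both callables are the same function here, so the sketch only shows which branch is taken, not any speed difference.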
from smart_jit import jit, Action
import numpy as np

def use_jit_sum_fast(A):
    # use jit compilation when length of A is greater than 100_000
    if len(A) > 100_000:
        return Action.JIT
    return Action.INTERPRETER

@jit(fastmath=True, use_jit=use_jit_sum_fast, warn_on_fallback=True)
def sum_fast(A):
    acc = 0.0
    # with fastmath, the reduction can be vectorized as floating point
    # reassociation is permitted.
    for x in A:
        acc += np.sqrt(x)
    return acc
A_small = np.arange(1_000, dtype=np.float64)
A_big = np.arange(1_000_000, dtype=np.float64)
In [1]: sum_fast(A_small) # interpreter
/Users/guilhermeleobas/git/numba-smartjit/smartjit.py:45: NumbaInterpreterModeWarning: sum_fast not using JIT
warnings.warn(msg, NumbaInterpreterModeWarning)
Out[1]: 21065.833110879048
In [2]: sum_fast(A_big) # will trigger jit compilation + execution
Out[2]: 666666166.4588218
In the example above, calling sum_fast
with A_big
triggered jit compilation, whereas calling with A_small
didn’t.
Note that once sum_fast
is compiled for A_big
, calling sum_fast
again for A_small
will now call the jitted version of sum_fast
, since now there is an overload that matches the provided argument:
In [3]: sum_fast.signatures
Out[3]: [(array(float64, 1d, C),)]
In [4]: sum_fast(A_small)
Out[4]: 21065.83311087906
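The behaviour above (an existing overload wins before the callback is consulted) can be modelled with a toy dispatcher. This is a sketch of the described ordering only, not the smart_jit source: ToyDispatcher, big_only and the log strings are hypothetical, "compiling" merely records the argument types, and Python types stand in for Numba signatures.

```python
from enum import Enum, auto


class Action(Enum):
    """Stand-in for smart_jit.Action (member values are an assumption)."""
    INTERPRETER = auto()
    JIT = auto()


class ToyDispatcher:
    """Toy model: an overload is just a remembered tuple of argument types."""

    def __init__(self, py_func, use_jit):
        self.py_func = py_func
        self.use_jit = use_jit
        self.signatures = []  # argument-type tuples that already have an overload
        self.log = []         # which path each call took

    def __call__(self, *args):
        sig = tuple(type(a) for a in args)
        if sig in self.signatures:
            # A matching overload exists: use it without asking the callback.
            self.log.append("jit (existing overload)")
        elif self.use_jit(*args) is Action.JIT:
            # The callback asked for JIT: "compile" and remember the signature.
            self.signatures.append(sig)
            self.log.append("jit (compiled now)")
        else:
            self.log.append("interpreter")
        return self.py_func(*args)


def big_only(xs):
    # Request JIT only for inputs longer than three elements.
    return Action.JIT if len(xs) > 3 else Action.INTERPRETER


total = ToyDispatcher(sum, big_only)
total([1, 2])        # small input, no overload yet
total([1, 2, 3, 4])  # big input, triggers "compilation"
total([1, 2])        # small input again, overload now matches
print(total.log)
# ['interpreter', 'jit (compiled now)', 'jit (existing overload)']
```

Both lists have the same type, just as A_small and A_big share the signature array(float64, 1d, C), which is why the third call takes the compiled path.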
It is also possible to provide signatures ahead-of-time to the @jit
decorator:
from smart_jit import jit, Action

def use_jit(a, b):
    # fallback to interpreter mode
    return Action.INTERPRETER

@jit(['int64(int64, int64)', 'float64(float64, float64)'],
     use_jit=use_jit, warn_on_fallback=True)
def add(a, b):
    return a + b
In [1]: add.signatures
Out[1]: [(int64, int64), (float64, float64)]
In [2]: add(2, 3)
Out[2]: 5
In [3]: add(2.2, 4.4)
Out[3]: 6.6000000000000005
Calling with argument types that match none of the declared signatures applies the action returned by the use_jit
function.
In [4]: add('hello', ', world')
/Users/guilhermeleobas/git/numba-smartjit/smart_jit.py:62: NumbaInterpreterModeWarning: add(unicode_type, unicode_type) not using JIT
warnings.warn(msg, NumbaInterpreterModeWarning)
Out[4]: 'hello, world'
In [5]: add.signatures
Out[5]: [(int64, int64), (float64, float64)]
This differs from other decorators in Numba, which raise a TypeError
when no matching signature is found.
from numba import njit

@njit('int32(int32, int32)')
def fn(a, b):
    return a
In [1]: fn('hello', 'world')
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[1], line 1
----> 1 fn('hello', 'world')
File ~/git/numba/numba/core/dispatcher.py:703, in _DispatcherBase._explain_matching_error(self, *args, **kws)
700 args = [self.typeof_pyval(a) for a in args]
701 msg = ("No matching definition for argument type(s) %s"
702 % ', '.join(map(str, args)))
--> 703 raise TypeError(msg)
TypeError: No matching definition for argument type(s) unicode_type, unicode_type
It is possible to raise an exception when use_jit
is called with unexpected types. This can be achieved by returning Action.RAISE_EXCEPTION
from the callback:
from smart_jit import smart_jit, Action

def use_jit(a):
    if isinstance(a, int):
        return Action.JIT
    elif isinstance(a, str):
        return Action.RAISE_EXCEPTION
    else:
        return Action.INTERPRETER

@smart_jit(use_jit=use_jit, warn_on_fallback=True)
def double(a):
    return a + a
In [1]: double(3)
Out[1]: 6
In [2]: double(4.4)
/Users/guilhermeleobas/git/numba-smartjit/smart_jit.py:62: NumbaInterpreterModeWarning: double(float64) not using JIT
warnings.warn(msg, NumbaInterpreterModeWarning)
Out[2]: 8.8
In [3]: double.signatures
Out[3]: [(int64,)]
In [4]: double('hello')
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[4], line 1
----> 1 double('hello')
File ~/git/numba-smartjit/smart_jit.py:133, in SmartJitDispatcher.__call__(self, *args, **kwargs)
131 return self._fallback_interpreter(*args, **kwargs)
132 elif jit_action == Action.RAISE_EXCEPTION:
--> 133 self._explain_matching_error(*args, **kwargs)
134 else:
135 msg = (
136 'Invalid value returned from "use_jit" keyword. Expected '
137 'one of "INTERPRETER, JIT_COMPILER, RAISE_EXCEPTION" '
138 f'but got "{jit_action}"'
139 )
File ~/git/numba/numba/core/dispatcher.py:703, in _DispatcherBase._explain_matching_error(self, *args, **kws)
700 args = [self.typeof_pyval(a) for a in args]
701 msg = ("No matching definition for argument type(s) %s"
702 % ', '.join(map(str, args)))
--> 703 raise TypeError(msg)
TypeError: No matching definition for argument type(s) unicode_type
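A callback like the one above can also be generated from a list of accepted types. The sketch below is one possible pattern, not part of smart_jit: jit_only_for is a hypothetical helper, and Action is a local stand-in so the snippet runs without the library (with smart_jit installed, its own Action would be imported instead).

```python
from enum import Enum, auto


class Action(Enum):
    """Stand-in for smart_jit.Action (member values are an assumption)."""
    INTERPRETER = auto()
    JIT = auto()
    RAISE_EXCEPTION = auto()


def jit_only_for(*allowed):
    """Build a use_jit callback: JIT for the allowed types, TypeError otherwise."""
    def use_jit(*args):
        if all(isinstance(a, allowed) for a in args):
            return Action.JIT
        # Any other argument type is rejected instead of being interpreted.
        return Action.RAISE_EXCEPTION
    return use_jit


numeric_only = jit_only_for(int, float)
print(numeric_only(3))        # Action.JIT
print(numeric_only("hello"))  # Action.RAISE_EXCEPTION
```

Passing such a callback as use_jit would give strict behaviour close to plain Numba for unsupported types, while still allowing on-demand compilation for the accepted ones.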
If present, cached functions are loaded on demand. When executing a function, smart_jit
will check if there is a function in cache that matches the signature before calling use_jit
.
from smart_jit import jit, Action

def use_jit(a):
    print(f'called "use_jit" with {a}')
    return Action.JIT

@jit(use_jit=use_jit, cache=True)
def incr(a):
    return a + 1
Calling for the first time will trigger JIT compilation and caching:
$ ipython -i example.py
In [1]: incr(4)
called "use_jit" with <class 'int'>
Out[1]: 5
Calling the same function a second time will use the cached overload:
$ ipython -i example.py
In [1]: incr(4)
Out[1]: 5
In [2]: # But only if the signature was previously cached
In [3]: incr(1.23)
called "use_jit" with <class 'float'>
Out[3]: 2.23
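The two sessions above can be imitated with a toy model in which a JSON file stands in for Numba's on-disk cache. Everything here is an assumption made for illustration (make_session, the file name, storing type names rather than compiled code); it only shows that the callback is skipped when the signature is already cached.

```python
import json
import os
import tempfile


def make_session(cache_path, calls):
    """Return a callable modelling incr() inside one interpreter session."""

    def use_jit(a):
        calls.append(type(a).__name__)  # record that the callback ran
        return "JIT"

    def incr(a):
        # Load the signatures stored by earlier sessions, if any.
        cached = []
        if os.path.exists(cache_path):
            with open(cache_path) as fh:
                cached = json.load(fh)
        sig = type(a).__name__
        # Cache miss: consult the callback, then persist the new signature.
        if sig not in cached and use_jit(a) == "JIT":
            cached.append(sig)
            with open(cache_path, "w") as fh:
                json.dump(cached, fh)
        return a + 1

    return incr


calls = []
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "incr.cache.json")
    make_session(path, calls)(4)  # first session: miss, use_jit runs
    second = make_session(path, calls)
    second(4)                     # second session: hit, use_jit skipped
    second(1.23)                  # new signature: use_jit runs again
print(calls)
# ['int', 'float']
```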
It is possible to track whether a function is using JIT compilation/execution with the help of event listeners. Numba provides an API for listening to certain events that happen inside the compiler. For the @smart_jit
work, I’ve implemented two new event kinds (jit_execution
and interpreter_execution
) that are notified when jit or interpreter execution happens. Example:
from smart_jit import jit, Action
from numba.core import event

class CustomListener(event.Listener):
    def on_start(self, event):
        print(f'Start {event.kind}...')

    def on_end(self, event):
        print(f'End {event.kind}...')

def int_jit(a):
    if isinstance(a, int):
        return Action.JIT
    return Action.INTERPRETER

@jit(use_jit=int_jit)
def incr(a):
    return a + 1
In [1]: listener = CustomListener()
   ...: with event.install_listener("jit_execution", listener):
   ...:     incr(4)
   ...:
Start jit_execution...
End jit_execution...
Calling incr
with a float value will not trigger the jit_execution
event, but will trigger interpreter_execution
:
In [2]: with event.install_listener("jit_execution", listener):
   ...:     incr(1.23)
   ...:
In [3]: with event.install_listener("interpreter_execution", listener):
   ...:     incr(1.23)
   ...:
Start interpreter_execution...
End interpreter_execution...
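Instead of printing, a listener could tally how many calls took each path. The sketch below keeps to the standard library so it runs without Numba: CountingMixin is a hypothetical class that relies only on the kind attribute used in the example above, and SimpleNamespace objects stand in for real events.

```python
from collections import Counter
from types import SimpleNamespace


class CountingMixin:
    """Counts finished events per kind; relies only on `event.kind`."""

    def __init__(self):
        self.counts = Counter()

    def on_start(self, event):
        pass  # nothing to record until the event finishes

    def on_end(self, event):
        self.counts[event.kind] += 1


# Stand-in events so the snippet runs without Numba installed.
counter = CountingMixin()
for kind in ("jit_execution", "interpreter_execution", "jit_execution"):
    fake = SimpleNamespace(kind=kind)
    counter.on_start(fake)
    counter.on_end(fake)

print(dict(counter.counts))
# {'jit_execution': 2, 'interpreter_execution': 1}
```

To attach this to Numba, one option would be a subclass that also inherits from event.Listener, installed under each of the two kinds with event.install_listener; that combination is a suggestion and has not been checked against smart_jit.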
All limitations of Numba @jit
persist.