hgrecco opened this issue 4 years ago
With the current Numba release (0.51.2), the following code:
```python
import numba as nb


@nb.njit(cache=True)
def func(t, y):
    return y


@nb.njit(cache=True)
def stepper(f, a, b):
    return 3 * f(a, b)


print(stepper(func, 4., 2.))
```
fails with the following error, caused by `cache=True`:

```
Traceback (most recent call last):
[...]
TypeError: cannot pickle 'weakref' object
```
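Numba's on-disk cache serializes data about each compiled function with `pickle`, and an object that holds a `weakref.ref` cannot be pickled. Which numba object holds the weakref here is not visible in the elided traceback, so the minimal sketch below only reproduces the same class of error in plain Python; `CompiledFunc` and `pickle_error` are hypothetical names, not numba's internals.

```python
import pickle
import weakref


class CompiledFunc:
    """Hypothetical stand-in for a compiled function object (not numba's class)."""


def pickle_error(obj):
    """Return None if `obj` pickles cleanly, otherwise the TypeError message."""
    try:
        pickle.dumps(obj)
    except TypeError as exc:
        return str(exc)
    return None


target = CompiledFunc()
# Metadata made only of plain values pickles fine.
plain = {"name": "func", "argtypes": ("float64", "float64")}
# The same metadata with a weak reference buried inside does not.
with_ref = {"name": "func", "owner": weakref.ref(target)}

print(pickle_error(plain))     # None
print(pickle_error(with_ref))  # a "cannot pickle ..." message that mentions weakref
```

The exact wording of the second message depends on the Python version.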
If this gets fixed, instantiation time could drop dramatically, since compilation would happen only once. This might require some code reorganization.
For example, right now the stepper is organized in two layers:
```python
import numba


def step_builder(*outer_args):
    """Build a stepper.

    This outer function should only contain attributes
    associated with the solver class, not with the solver instance.
    """

    @numba.njit
    def _step(*inner_args):
        """Perform a single step.

        This inner function should only contain attributes
        associated with the solver instance, not with the solver class.
        """
        # code to step

    return _step
```
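As a concrete illustration, here is a minimal sketch of this two-layer pattern next to a flattened variant that receives every value as an argument. The explicit Euler step for dy/dt = -k*y, the hypothetical name `step_flat`, and treating the decay rate `k` as the class-level value are illustrative assumptions, not the solver's actual code. The sketch uses `numba.njit` when numba is installed and a no-op stand-in otherwise; `cache=True` is omitted so it also runs as a standalone snippet.

```python
try:
    from numba import njit
except ImportError:
    # No-op stand-in so the sketch also runs without numba installed.
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]  # used as bare @njit
        return lambda func: func  # used as @njit(...)


# Two-layer layout: the class-level value k is captured by a closure.
def step_builder(k):
    @njit
    def _step(t, y, h):
        # One explicit Euler step for dy/dt = -k * y.
        return t + h, y + h * (-k * y)

    return _step


# Flattened layout: k is passed explicitly on every call, so the jitted
# function is a plain module-level function.
@njit
def step_flat(t, y, h, k):
    # Same explicit Euler step, with k as an ordinary argument.
    return t + h, y + h * (-k * y)


step = step_builder(0.5)
print(step(0.0, 1.0, 0.1))            # closure version
print(step_flat(0.0, 1.0, 0.1, 0.5))  # flattened version, same values
```

One difference worth noting: every call to `step_builder` creates a new Python function and therefore a new dispatcher, which is compiled on its first call with `k` frozen in as a constant, whereas `step_flat` is compiled once per argument-type signature and reused for every `k`.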
`step_builder` is called at class instantiation and `_step` at each step. This separation makes the code very clear and organized, but it might not work with caching, in which case the code would need to be flattened to something like:
```python
import numba

@numba.njit
def _step(inner_args, outer_args):
    """Perform a single step.
    Takes both the instance attributes (inner_args) and the class
    attributes (outer_args) as explicit tuple arguments.
    """
    # code to step
```
Relevant numba issues:
Some other possibly relevant, cache-related numba issues:
Some are quite old and might be duplicates of newer issues such as the ones mentioned above.
The failing example appears to work now:
```python
import numba as nb


@nb.njit(cache=True)
def func(t, y):
    return y


@nb.njit(cache=True)
def stepper(f, a, b):
    return 3 * f(a, b)


print(f"{nb.__version__ = }")
print(stepper(func, 4., 2.))
```

```
nb.__version__ = '0.56.4'
6.0
```
But it still produces a new cached version of `stepper` on each run (with the script above saved as `cache.py`):

```
> python cache.py && ls __pycache__/cache.stepper*.nbc | wc -l
1
> python cache.py && ls __pycache__/cache.stepper*.nbc | wc -l
2
```
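The shell check above can also be scripted, for example as a regression test that runs the script several times and asserts that the number of cache files stops growing. Below is a minimal sketch of the counting part, assuming (as in the session above) that numba writes its cache files to `__pycache__` next to `cache.py` and that the data files for `stepper` match `cache.stepper*.nbc`; `count_cache_files` is a hypothetical helper, and the demo uses made-up file names in a throwaway directory.

```python
import tempfile
from pathlib import Path


def count_cache_files(cache_dir, module, func_name):
    """Count numba cache data files (*.nbc) for module.func_name in cache_dir.

    Mirrors the shell check `ls __pycache__/cache.stepper*.nbc | wc -l`.
    """
    return len(list(Path(cache_dir).glob(f"{module}.{func_name}*.nbc")))


# Self-contained demo on a throwaway directory with made-up file names.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("cache.stepper-8.py310.1.nbc",
                 "cache.stepper-8.py310.2.nbc",
                 "cache.stepper-8.py310.nbi",   # index file, not counted
                 "cache.func-4.py310.1.nbc"):   # different function
        (Path(tmp) / name).touch()
    print(count_cache_files(tmp, "cache", "stepper"))  # 2
```

Calling this helper before and after running `python cache.py` a second time (e.g. via `subprocess.run([sys.executable, "cache.py"])`) and comparing the two counts would turn the observation above into an automated check.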
Environment:
```
libllvm11   11.1.0   hfa12f05_5          conda-forge
llvmlite    0.39.1   py310h1e34944_1     conda-forge
numba       0.56.4   py310h3124f1e_0     conda-forge
python      3.10.8   h3ba56d0_0_cpython  conda-forge
```
Instantiating the solver takes quite a long time due to compilation. It would be worthwhile to improve the code (without losing clarity) so that numba is able to use its compiled-code cache.