hgrecco / numbakit-ode

Leveraging numba to speed up ODE integration

Speed up solver instantiation #5

Open hgrecco opened 4 years ago

hgrecco commented 4 years ago

Instantiating the solver takes quite a long time due to compilation. It would be interesting to improve the code (without losing clarity) so that numba is able to use the compiled code cache.

hgrecco commented 4 years ago

With current Numba (0.51.2) the following code:

import numba as nb

@nb.njit(cache=True)
def func(t, y):
    return y

@nb.njit(cache=True)
def stepper(f, a, b):
    return 3 * f(a, b)

print(stepper(func, 4., 2.))

fails with the following error:

Traceback (most recent call last):
[...]
TypeError: cannot pickle 'weakref' object

due to cache=True.

If this works in the near future, instantiation time could drop dramatically, since compilation would only be done once. This might require some code reorganization.

For example, right now the stepper is organized in two layers:

import numba


def step_builder(*outer_args):
    """Build a stepper.

    This outer function should only contain attributes
    associated with the solver class, not with the solver instance.
    """

    @numba.njit
    def _step(*inner_args):
        """Perform a single step.

        This inner function should only contain attributes
        associated with the solver instance, not with the solver class.
        """

        # code to step

    return _step
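
To see what the two-layer layout does at the Python level, here is a minimal pure-Python sketch (numba is deliberately left out so it runs anywhere). Each call to a builder returns a new function object, and the class-level values reach the inner function through closure cells rather than through its arguments. `step_builder(order)` and its body are hypothetical stand-ins, not the numbakit-ode code.

```python
def step_builder(order):
    """Hypothetical builder: ``order`` stands in for a class-level attribute."""

    def _step(y, h):
        # ``order`` is not an argument: it is captured from the enclosing scope.
        return y + order * h

    return _step


step_a = step_builder(1)
step_b = step_builder(1)

# Each builder call creates a new function object...
print(step_a is step_b)                                # False
# ...that shares the same Python code object...
print(step_a.__code__ is step_b.__code__)              # True
# ...but carries its own closure cell holding the class-level value.
print([c.cell_contents for c in step_a.__closure__])   # [1]
```

Numba documents that closure variables are frozen at compile time, so in the real stepper such captured values become part of each compiled `_step` rather than runtime arguments. How numba's on-disk cache treats these closures is not shown by this sketch.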

step_builder is called at class instantiation and _step at each step. This separation keeps the code very clear and organized, but it might not work with caching, in which case the code will need to be flattened into something like:

import numba


@numba.njit
def _step(inner_args, outer_args):
    """Perform a single step.

    This single function receives both the attributes associated with
    the solver instance (inner_args) and those associated with the
    solver class (outer_args) as explicit arguments.
    """

    # code to step

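
To make the flattened layout concrete, here is a minimal sketch that uses a generic explicit Runge-Kutta step as a stand-in. The Butcher tableau (`a`, `b`, `c`) plays the role of the class-level outer_args, and `f`, `t`, `y`, `h` the instance-level inner_args; nothing is captured from an enclosing scope. `rk_step`, its argument order, and the `RK4_*` constants are illustrative assumptions, not numbakit-ode's API. It is plain Python so it runs without numba; a numba-compatible variant (e.g. using arrays) is what would carry `@numba.njit(cache=True)` in the flattened design.

```python
import math


def rk_step(f, t, y, h, a, b, c):
    """One explicit Runge-Kutta step with the tableau passed explicitly.

    Hypothetical example: class-level data (a, b, c) and instance-level
    data (f, t, y, h) are all plain arguments.
    """
    k = []
    for i in range(len(b)):
        # Stage i only uses stages already computed (explicit method).
        yi = y + h * sum(a[i][j] * k[j] for j in range(i))
        k.append(f(t + c[i] * h, yi))
    return y + h * sum(bi * ki for bi, ki in zip(b, k))


# Classic RK4 tableau: these would be the class-level constants.
RK4_A = ((0.0, 0.0, 0.0, 0.0),
         (0.5, 0.0, 0.0, 0.0),
         (0.0, 0.5, 0.0, 0.0),
         (0.0, 0.0, 1.0, 0.0))
RK4_B = (1 / 6, 1 / 3, 1 / 3, 1 / 6)
RK4_C = (0.0, 0.5, 0.5, 1.0)

# y' = y, y(0) = 1, one step of size 0.1; the exact answer is exp(0.1).
y1 = rk_step(lambda t, y: y, 0.0, 1.0, 0.1, RK4_A, RK4_B, RK4_C)
print(abs(y1 - math.exp(0.1)) < 1e-6)  # True
```

Because every input is an argument, the same compiled function could serve every solver class; only the tableau values passed in change.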
hgrecco commented 3 years ago

Relevant numba issues:

maurosilber commented 3 years ago

Some other cache-related, possibly relevant numba issues:

Some are quite old and might be duplicates of newer issues, such as the ones mentioned above.

Illviljan commented 1 year ago

The failing example appears to work now:

import numba as nb

@nb.njit(cache=True)
def func(t, y):
    return y

@nb.njit(cache=True)
def stepper(f, a, b):
    return 3 * f(a, b)

print(f"{nb.__version__ = }")
print(stepper(func, 4., 2.))

Output:

nb.__version__ = '0.56.4'
6.0

maurosilber commented 1 year ago

But it still produces a different cached version of stepper on each run (with the script above saved as cache.py):

> python cache.py && ls __pycache__/cache.stepper*.nbc | wc -l
1
> python cache.py && ls __pycache__/cache.stepper*.nbc | wc -l
2

Environment:

libllvm11                 11.1.0               hfa12f05_5    conda-forge
llvmlite                  0.39.1          py310h1e34944_1    conda-forge
numba                     0.56.4          py310h3124f1e_0    conda-forge
python                    3.10.8          h3ba56d0_0_cpython    conda-forge
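
The shell check above can also be scripted from Python, which is handy on platforms without `ls` and `wc`. This is a minimal sketch: `count_cache_entries` is a hypothetical helper, and the `<module>.<function>*.nbc` pattern is copied from the command above rather than taken from numba's documentation.

```python
from pathlib import Path


def count_cache_entries(cache_dir, module, func):
    """Count numba cache data files (*.nbc) matching ``module.func*``.

    Mirrors ``ls __pycache__/cache.stepper*.nbc | wc -l``; the glob
    pattern is taken from that command.
    """
    return sum(1 for _ in Path(cache_dir).glob(f"{module}.{func}*.nbc"))


if __name__ == "__main__":
    # Run after each `python cache.py`; if the cache is reused,
    # this count should not grow from run to run.
    print(count_cache_entries("__pycache__", "cache", "stepper"))
```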