hgrecco / numbakit-ode

Leveraging numba to speed up ODE integration

API design: prevent unneeded recompilation #28

Open maurosilber opened 2 years ago

maurosilber commented 2 years ago

From: https://github.com/hgrecco/numbakit-ode/issues/12#issuecomment-1099136091

The solver shouldn't recompile when only the initial conditions y0 change. However, this is hidden by the current API, which accepts both jitted and non-jitted functions. In the latter case, the function is jitted inside Solver.__init__, producing a new jitted function for every Solver. As far as numba is concerned, that is a "different" function, so (some) parts of the Solver are recompiled.

Instead, we could:

  1. raise an error when the passed function is not jitted
  2. document this behaviour, since the API could still be "misused" as Solver(numba.njit(func), ...), which jits func anew for every Solver

A similar issue arises with functions that depend on parameters (func(t, y, p)). A closure rhs(t, y) is generated inside Solver.__init__, and numba considers it a different function every time, even if the same parameters are used.

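We could provide a helper that builds the closure once, outside Solver.__init__, so that the same rhs can be reused across Solver instances. The helper itself has to be a plain Python function: numba supports inner functions inside jitted code, but a jitted function cannot return one.

import numba

def closure(func, params):
    # func is a jitted func(t, y, p). numba captures params as a
    # compile-time constant, so build a new rhs if params change.
    @numba.njit
    def rhs(t, y):
        return func(t, y, params)
    return rhs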

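Later, we could change the internal implementation to call func(t, y, p) instead of func(t, y), passing an array p of parameters where needed. Because numba specializes on argument types rather than values, new parameter values would then reuse the already compiled code.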

hgrecco commented 2 years ago

I think we could start by raising a warning, and then move to raising an error when we change to func(t, y, p).