hgrecco / numbakit-ode

Leveraging numba to speed up ODE integration

Dense output is less smooth than expected #11

Closed · astrojuanlu closed this issue 3 years ago

astrojuanlu commented 3 years ago

We are considering switching to numbakit-ode in poliastro (https://github.com/poliastro/poliastro/issues/1042). However, we'd like to ask whether support for dense output and events is within the scope of the package. I acknowledge that it can be challenging to do this in numba, though - a "no, we won't support this in the near future" answer is also acceptable :)

hgrecco commented 3 years ago

I think we need to separate the discussion of these two things.

Regarding dense output: in a way, numbakit-ode already provides a "dense output", just not as a separate object. Briefly, you can evaluate at any t and the solver will provide an optimally interpolated result, depending on the method. This is done automatically under the hood when you use the run method.
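For example (a minimal sketch using the same API as the snippets further down this thread):

import numpy as np
import nbkode

def f(t, y):
    return -y  # simple exponential decay

solver = nbkode.DOP853(f, 0.0, np.array([1.0]))
ts, ys = solver.run(np.linspace(0, 10, 500))  # any time points work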

Events are trickier, but they would also be nice to have for the original purpose of numbakit-ode (biological simulations). It would be nice to define the API first. Maybe we can move this discussion to another issue?

astrojuanlu commented 3 years ago

Fair enough - let me repurpose this issue then:

About dense output: perhaps I'm misusing the library or misunderstanding something, but I don't see the desired effect. This code snippet, for example:

import numpy as np
import matplotlib.pyplot as plt
import nbkode

def pop(t, z, p):
    """Lotka-Volterra right-hand side."""
    x, y = z
    a, b, c, d = p
    return np.array([a*x - b*x*y, -c*y + d*x*y])

z0 = np.array([10, 5], dtype=float)
t0 = 0
p = np.array([1.5, 1, 3, 1])

solver8 = nbkode.DOP853(pop, t0, z0, params=p)

# Evaluate at 300 points...
ts3 = np.linspace(0, 15, 300)
_, zs3 = solver8.run(ts3)

solver8 = nbkode.DOP853(pop, t0, z0, params=p)  # Avoids https://github.com/hgrecco/numbakit-ode/issues/12

# ...and at 30,000 points.
ts3_alt = np.linspace(0, 15, 30_000)
_, zs3_alt = solver8.run(ts3_alt)

#plt.plot(ts3, zs3, 'x', mew=3)
plt.plot(ts3_alt, zs3_alt, '+', mew=1)
plt.xlabel('t')
plt.legend(['x', 'y'], shadow=True)
plt.title('Lotka-Volterra System')
plt.show()
(Figures: full plot and close-up of the sampled solution.)

This produces essentially the same output regardless of the number of points in the t array passed to .run().

On the other hand, scipy.integrate.solve_ivp(..., dense_output=True) does have the desired effect, whereas scipy.integrate.solve_ivp(..., dense_output=False) is similar to what I obtain with numbakit-ode.

Tightening the relative and absolute tolerances to, say, 1e-9 and 1e-12, does produce a smooth solution in both cases. But this is not necessary in the scipy.integrate.solve_ivp(..., dense_output=True) case.
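For reference, this is the SciPy comparison I have in mind (reusing pop, z0, ts3 and p from the snippet above):

from scipy.integrate import solve_ivp

# dense_output=True returns a piecewise interpolant in sol.sol that can be
# evaluated at arbitrary points within the integration interval.
sol_dense = solve_ivp(pop, (0, 15), z0, args=(p,), dense_output=True)
zs_dense = sol_dense.sol(ts3)  # shape (2, 300)

# dense_output=False only returns the values at the requested t_eval points.
sol_plain = solve_ivp(pop, (0, 15), z0, args=(p,), t_eval=ts3, dense_output=False)
zs_plain = sol_plain.y  # also shape (2, 300)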

hgrecco commented 3 years ago

I am not able to reproduce. Running this:

# One fresh solver per run, as in the snippet above (avoids issue #12)
solver8a = nbkode.DOP853(pop, t0, z0, params=p)
solver8b = nbkode.DOP853(pop, t0, z0, params=p)

_, zs3 = solver8a.run(ts3)
_, zs3_alt = solver8b.run(ts3_alt)

print(nbkode.__version__)
print(len(ts3), len(zs3))
print(len(ts3_alt), len(zs3_alt))

results in this

0.4
300 300
30000 30000

Note: I am running numbakit-ode 0.4, but I am pretty sure this behavior has not changed.

astrojuanlu commented 3 years ago

Sure, the length of the arrays is consistent (as you showed), but I don't see the smooth, interpolated solution I'd expect from dense output (or at least it doesn't match the visual result one obtains by specifying dense_output=True in SciPy). However, it might well be that I'm misunderstanding what dense output is.

hgrecco commented 3 years ago

A few things we already know:

  1. nbkode.DOP853 closely matches the SciPy pure Python implementation at the steps defined by the algorithm (the step method, not run), and we have a test for that.
  2. The base solver class provides a default, simple _interpolate method. The idea is that each subclass should override it in a way that is more appropriate for the method (in the same way that the DenseOutput subclasses of SciPy work). This has not been done yet for nbkode.DOP853, but doing it will only change y, not t.

I need to look at the solve_ivp implementation in SciPy; I am not sure whether it matches the class-based implementation.

maurosilber commented 3 years ago

Adding to the previous comment, there are two* types of interpolators for Runge-Kutta methods: some use only the coefficients already calculated for the last step, and others, such as DOP853's, need some extra coefficients. Right now we are only using the first 12 coefficients of DOP853, as the rest are for the interpolator. I'm not sure what a good way to implement this would be, as some methods (step and skip) don't need the interpolator, and hence don't need the extra coefficients.

*For methods of order less than 3, we can use the values y0, y1, f(y0) and f(y1), which every method is already calculating and saving to a buffer, to perform a Hermite interpolation. Currently, the default interpolator is a linear interpolator.
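As an illustrative sketch (plain NumPy-style Python, not the actual nbkode code), the cubic Hermite interpolant built from those four values looks like this:

def hermite_interpolate(t, t0, t1, y0, y1, f0, f1):
    """Cubic Hermite interpolation between (t0, y0) and (t1, y1),
    using the derivatives f0 = f(y0) and f1 = f(y1)."""
    h = t1 - t0
    x = (t - t0) / h  # normalized position within the step, in [0, 1]
    # Standard cubic Hermite basis functions.
    h00 = 2 * x**3 - 3 * x**2 + 1
    h10 = x**3 - 2 * x**2 + x
    h01 = -2 * x**3 + 3 * x**2
    h11 = x**3 - x**2
    return h00 * y0 + h * h10 * f0 + h01 * y1 + h * h11 * f1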

Implementing dense output could be done by changing the buffer to have unbounded capacity for methods that use the default interpolator. For methods that have a specialized interpolator, we would have to save the interpolating coefficients instead of the values of the derivative at each step.

hgrecco commented 3 years ago

I think we need to distinguish between the interpolator object returned to the user and the interpolation done to fill the values in the returned vector.

In SciPy's solve_ivp (if I understand correctly), what is returned to the user (as an OdeSolution in sol if dense_output=True) is a split interpolant. Briefly, it contains a series of interpolants that together cover the whole range between t0 and tf. These interpolants are used to calculate the values returned to the user in (t, y) at the points t in t_eval.
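Continuing the solve_ivp example from earlier in the thread, the split structure is visible directly on the OdeSolution object (ts and interpolants are standard SciPy attributes):

sol = solve_ivp(pop, (0, 15), z0, args=(p,), dense_output=True)

# OdeSolution glues together one local interpolant per accepted step:
print(len(sol.sol.interpolants))  # number of pieces
print(sol.sol.ts)                 # the breakpoints between the pieces
zs = sol.sol(ts3)                 # evaluation dispatches to the right piece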

In numbakit-ode we neither output nor generate this split interpolant, but we aim to do the equivalent of t_eval. As mentioned by @maurosilber, we now have a simple linear interpolant, but the infrastructure is there to override it in each Solver class with a better one. My impression is that we should do this and defer the decision on building a split interpolant for the future.

hgrecco commented 3 years ago

Today I spent a few minutes implementing the "correct" interpolator for RK45. It was dead easy; I just adapted some code out of SciPy. Briefly, I added an extra return value (P) to _step_args, fixed _step to accept it, and implemented _interpolate.

--- a/nbkode/runge_kutta/explicit.py
+++ b/nbkode/runge_kutta/explicit.py
@@ -87,7 +87,7 @@ class FSAL(AdaptiveRungeKutta, ERK, abstract=True):
         step_update = cls._step_update

         @numba.njit
-        def _step(rhs, cache, h, K, options):
+        def _step(rhs, cache, h, K, options, *args):
             while True:
                 t, y = fixed_step(rhs, cache, h, K)
                 error = step_error(h, K, E)
@@ -225,6 +225,36 @@ class RungeKutta45(FSAL):
         [-71 / 57600, 0, 71 / 16695, -71 / 1920, 17253 / 339200, -22 / 525, 1 / 40]
     )

+    @property
+    def _step_args(self):
+        return super()._step_args + (self.P,)
+
+    @staticmethod
+    @numba.njit()
+    def _interpolate(t_eval, rhs, cache, *args):
+        h, K, options, P = args
+        Q = K.T.dot(P)
+        t_old = cache.ts[-2]
+        y_old = cache.ys[-2]
+        t = cache.ts[-1]
+        delta = t - t_old
+        order = Q.shape[1] - 1
+        t_eval = np.asarray(t_eval)
+        x = (t_eval - t_old) / delta
+        if t_eval.ndim == 0:
+            p = np.tile(x, order + 1)
+            p = np.cumprod(p)
+        else:
+            p = np.tile(x, (order + 1, 1))
+            p = np.cumprod(p, axis=0)
+        y = delta * np.dot(Q, p)
+        if y.ndim == 2:
+            y += y_old[:, None]
+        else:
+            y += y_old
+
+        return y
+

But there is a catch: in this implementation, certain preparation operations are done each time the interpolator is called, while they could be done only once, when the interpolator is built. For this particular case it is just Q = K.T.dot(P), but I wonder whether there are other interpolators for which this is more expensive.
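To make the trade-off concrete, here is a sketch (plain Python closures, ignoring numba constraints, with made-up names) of what "building" the interpolator once per step would hoist out of the evaluation path:

import numpy as np

def build_interpolator(K, P, t_old, t_new, y_old):
    # Done once, when the step is accepted:
    Q = K.T.dot(P)  # the preparation that _interpolate above redoes per call
    delta = t_new - t_old

    def evaluate(t_eval):
        # Done on every evaluation within the step:
        x = (t_eval - t_old) / delta
        p = np.cumprod(np.full(Q.shape[1], x))  # [x, x**2, ..., x**(order+1)]
        return y_old + delta * Q.dot(p)

    return evaluate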

To give some context: if the user does not require the interpolator as an output but asks for specific time points to evaluate (t_eval), an interpolator is built for each step if there is at least one t_eval between that step and the previous one. It is reasonable to think that an interpolator will be called multiple times within the same step.

So I see three options:

  1. Use the approach shown here.
  2. Create a cache for the interpolator and pass it around.
  3. Create an njitted class which holds the cache and the interpolator (like SciPy's DenseOutput, but let's use another name if we go for this option).

I think we should avoid (2), as it will lead to unmaintainable code. (1) is probably faster when only a few inter-step time points are calculated; (3) will be faster in the opposite case, even though it needs to instantiate an object of that jitclass. (3) would also allow creating a piecewise interpolator in the future and providing it to the user.

hgrecco commented 3 years ago

We need to go for (1), as creating an object is way too slow for the time being:

import numba as nb
from numba.experimental import jitclass  # was nb.jitclass in older numba versions

class Interpolator:
    """Linear interpolator through (x1, y1) and (x2, y2)."""

    def __init__(self, x1, y1, x2, y2):
        self.x1 = x1
        self.y1 = y1
        self.x2 = x2
        self.y2 = y2
        self.m = (y2 - y1) / (x2 - x1)

    def evaluate(self, x):
        return self.y1 + self.m * (x - self.x1)

JInterpolator = jitclass([
    ("x1", nb.types.float64),
    ("y1", nb.types.float64),
    ("x2", nb.types.float64),
    ("y2", nb.types.float64),
    ("m", nb.types.float64),
])(Interpolator)

def evaluate(x, x1, y1, x2, y2):
    """Same computation as Interpolator.evaluate, as a plain function."""
    m = (y2 - y1) / (x2 - x1)
    return y1 + m * (x - x1)

jevaluate = nb.njit(evaluate)

%%timeit
jip = JInterpolator(1., 1., 10., 9.)

12.8 µs ± 118 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%%timeit
jip.evaluate(5.)

744 ns ± 6.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%%timeit
jevaluate(5., 1., 1., 10., 9.)

200 ns ± 8.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

hgrecco commented 3 years ago

I have a new approach that combines all of the above in a way that I think is maintainable and performant.

Briefly:

  1. We create an Interpolator object for RK23, RK45, DOP853 and others.
  2. We pass this interpolator around in _step_args.
  3. This interpolator is updated in place every time it is needed (i.e. no new copies are made).

This interpolator is stored in a class variable named _interpolator, which is defined in core.Solver (although it is not necessary for all solvers).
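In pattern form, the idea is something like this (a sketch with made-up names, not the actual nbkode classes):

import numpy as np
from numba.experimental import jitclass
from numba import types

@jitclass([
    ("t_old", types.float64),
    ("t_new", types.float64),
    ("y_old", types.float64[:]),
    ("Q", types.float64[:, :]),
])
class StepInterpolator:
    """Instantiated once per solver, then mutated at every accepted step."""

    def __init__(self, ndim, order):
        self.t_old = 0.0
        self.t_new = 0.0
        self.y_old = np.zeros(ndim)
        self.Q = np.zeros((ndim, order + 1))

    def update(self, t_old, t_new, y_old, Q):
        # No new object is created: the state is overwritten in place.
        self.t_old = t_old
        self.t_new = t_new
        self.y_old = y_old
        self.Q = Q

    def evaluate(self, t):
        # y_i = y_old_i + h * sum_j Q[i, j] * x**(j + 1)
        h = self.t_new - self.t_old
        x = (t - self.t_old) / h
        out = self.y_old.copy()
        for i in range(self.Q.shape[0]):
            px = x
            for j in range(self.Q.shape[1]):
                out[i] += h * self.Q[i, j] * px
                px *= x
        return out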

The interpolators are adapted from SciPy, with a few important differences:

  1. They rely strongly on the fact that many variables (e.g. the cache, K) are updated in place.
  2. They can only move forward.
  3. They are not bulletproof, and therefore they should never become part of the public API (maybe we should prefix them with _).

In addition, I think that if we replace the default linear interpolator with a spline that makes use of the cache history, it will be fine.

As you can see in the tests, the results in general match the SciPy results quite closely (relative difference ~1e-16). There is still a difference observed for RK23 at larger times (probably due to a bug in the update function).

hgrecco commented 3 years ago

This does not change the public API. I will merge this over the weekend if there are no complaints.

hgrecco commented 3 years ago

I have merged this. If this is still a problem, maybe we can restart it in a new, more specific issue.