Performance issue with SDE solver #517

Open pierreguilmin opened 3 days ago

pierreguilmin commented 3 days ago


When solving the (trivial) SDE $d y_t = -y_t\ dt + 0.2\ dW_t$, the Diffrax Euler solver is ~200x slower than a naive for loop. Am I doing something wrong? The speed difference is consistent across various SDEs, solvers, time steps dt, and number of trajectories, and it appears to be specific to SDE solvers.
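
For reference, both implementations below are meant to compute the Euler-Maruyama discretisation of this SDE. The step index $n$, the step size $\Delta t$ and the increment $\Delta W_n$ are notation introduced here for illustration, not taken from the code:

$$
y_{n+1} = y_n - y_n\,\Delta t + 0.2\,\Delta W_n, \qquad \Delta W_n \sim \mathcal{N}(0, \Delta t).
$$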

import diffrax as dx
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt

# === simulation parameters
key = jax.random.PRNGKey(42)
t0 = 0
t1 = 1
y0 = 1.0
ndt = 101
dt = (t1 - t0) / (ndt - 1)
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.2

# === diffrax euler
brownian_motion = dx.VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=key)
solver = dx.Euler()
terms = dx.MultiTerm(dx.ODETerm(drift), dx.ControlTerm(diffusion, brownian_motion))
saveat = dx.SaveAt(ts=jnp.linspace(t0, t1, ndt))

def diffrax_simu():
    return dx.diffeqsolve(terms, solver, t0, t1, dt0=dt, y0=y0, saveat=saveat).ys

# === homemade euler
def homemade_simu():
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt,))

    def step(y, dW):
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dW
        return y + dy, y

    return jax.lax.scan(step, 1.0, dWs)[-1]

# === plot a single trajectory from each solver
y = diffrax_simu()
plt.plot(y)
y = homemade_simu()
plt.plot(y)

# === benchmark
%timeit diffrax_simu().block_until_ready()
%timeit homemade_simu().block_until_ready()

5.39 ms ± 261 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
19.7 μs ± 899 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

lockwo commented 3 days ago

I get them to be a lot closer by using UnsafeBrownianPath, which has less overhead than VBT. Diffrax is still a bit slower with this change on my machine, but the difference is smaller (and probably due to other overheads that diffrax does to enable more features).

There's also some risky (but often useful) changes to UBP we've made internally that I've been meaning to put in the fork, so you can definitely do a fair amount with modifications to UBP (being able to get through all 3 stated requirements).

patrick-kidger commented 3 days ago

Yup, VBT is often the cause of poor SDE performance. Really we need some kind of LRU caching to make it behave properly, but that doesn't seem to be easy in JAX -- I'm pretty sure it'd require both a new primitive ('cached_call_p') and a new transform. That's a fairly advanced project for someone to take on!

In the meantime I recommend UBP as the go-to for these kinds of normal 'just solve an SDE' applications.
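
To give a rough feel for why a tree-based Brownian path is expensive per query, and why caching would help, here is a simplified sketch of the virtual-Brownian-tree idea in plain JAX. It is not Diffrax's implementation: the function name `brownian_tree_eval`, the key-derivation scheme and the final linear interpolation are all assumptions made for illustration.

```python
import math

import jax
import jax.numpy as jnp


def brownian_tree_eval(key, t, t0=0.0, t1=1.0, tol=1e-3):
    """Approximate W(t) for a Brownian path on [t0, t1] determined by `key`."""
    # Number of bisections needed to shrink [t0, t1] below `tol`.
    depth = math.ceil(math.log2((t1 - t0) / tol))

    key, sub = jax.random.split(key)
    w_end = jnp.sqrt(t1 - t0) * jax.random.normal(sub)  # W(t1); W(t0) = 0

    def bisect(_, carry):
        s, u, w_s, w_u, k = carry
        k_left, k_right, k_mid = jax.random.split(k, 3)
        m = 0.5 * (s + u)
        # Brownian-bridge midpoint: mean of the endpoints, variance (u - s) / 4.
        w_m = 0.5 * (w_s + w_u) + jnp.sqrt(0.25 * (u - s)) * jax.random.normal(k_mid)
        left = t < m
        s, u = jnp.where(left, s, m), jnp.where(left, m, u)
        w_s, w_u = jnp.where(left, w_s, w_m), jnp.where(left, w_m, w_u)
        # Descend into the child interval that contains t, using that child's key.
        return s, u, w_s, w_u, jnp.where(left, k_left, k_right)

    init = (jnp.float32(t0), jnp.float32(t1), jnp.float32(0.0), w_end, key)
    s, u, w_s, w_u, _ = jax.lax.fori_loop(0, depth, bisect, init)
    # Linear interpolation inside the final, sub-tolerance interval.
    return w_s + (w_u - w_s) * (t - s) / (u - s)


key = jax.random.PRNGKey(0)
w = jax.jit(brownian_tree_eval)(key, 0.3)
```

With `tol=1e-3` on `[0, 1]` the loop runs 10 levels, so every single query pays for 10 key splits and 10 normal draws, whereas the naive loop draws one normal per step. Queries at nearby times walk almost the same path through the tree, which is the work a cache could avoid repeating.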

lockwo commented 3 days ago

I think a lot of people get turned off by the Unsafe in the name, maybe worth adding a sentence like this to the docs ("In the meantime I recommend UBP as the go-to for these kinds of normal 'just solve an SDE' applications.").

gautierronan commented 3 days ago

Thanks. Indeed using UBP does help but I understand it's quite restricted in terms of usage.

> Diffrax is still a bit slower with this change on my machine, but the difference is smaller (and probably due to other overheads that diffrax does to enable more features).

It seems there is still a factor ~10-20 difference (irrespective of number of time steps) between the homemade solver and diffrax with UBP. I would have naively thought that any irrelevant computation would be jitted away. Could you elaborate on what diffrax with UBP does compared to the naive solver?

Diffrax (VBT): 7.51 ms ± 18.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Diffrax (UBP): 637 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Naive:         28.5 µs ± 147 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

lockwo commented 2 days ago

Diffrax has a lot more checking/shaping/logging than the default implementation. You can see it reflected in the jaxprs:
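
One way to make that comparison concrete is sketched below for the homemade solver only; it is an illustration, not the code used for the comparison above. The helper `count_eqns` is a made-up name, and the equation count is only a crude proxy for program size.

```python
import jax
import jax.numpy as jnp

dt = 0.01
key = jax.random.PRNGKey(42)


def homemade_simu(y0):
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (101,))

    def step(y, dW):
        return y + (-y * dt + 0.2 * dW), y

    return jax.lax.scan(step, y0, dWs)[-1]


def count_eqns(jaxpr_like):
    """Count equations, recursing into nested jaxprs (scan bodies, jitted calls, ...)."""
    jaxpr = getattr(jaxpr_like, "jaxpr", jaxpr_like)  # unwrap a ClosedJaxpr
    total = 0
    for eqn in jaxpr.eqns:
        total += 1
        for param in eqn.params.values():
            nested = param if isinstance(param, (tuple, list)) else (param,)
            for item in nested:
                if hasattr(item, "eqns") or hasattr(getattr(item, "jaxpr", None), "eqns"):
                    total += count_eqns(item)
    return total


closed = jax.make_jaxpr(homemade_simu)(jnp.float32(1.0))
print(closed)              # the full program, for reading by eye
print(count_eqns(closed))  # a crude size metric for side-by-side comparison
```

Running the same two calls on a function that wraps `dx.diffeqsolve` should give the Diffrax side of the comparison.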

I believe most of this comes from the UBP, since if I do

def homemade_simu():
    ts = jnp.linspace(t0, t1, ndt)

    def step(y, t):
        dw = brownian_motion.evaluate(t, t + dt)
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dw
        return y + dy, y

    return jax.lax.scan(step, 1.0, ts)[-1]

I see the times are pretty much the same. Perhaps this does indicate that there is room to cut down the UBP-related overhead.

patrick-kidger commented 2 days ago

FWIW I think the speed difference here does seem unacceptably large. This seems like it should be improved.

Starting with the low-hanging fruit to be sure we're doing more of an equal comparison: can you try setting EQX_ON_ERROR=nan and diffeqsolve(throw=False), to disable all error checks? Those are fairly slow.

Also, can you try using stepsize_controller=StepTo(...). By default Diffrax does not recompile if the number of steps changes (e.g. because t1 changes), but a lax.scan implementation does. Diffrax pays a small amount of runtime cost for this generality. Using StepTo instead bakes in the discretisation in the same way as a lax.scan.
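
The recompilation trade-off can be seen without Diffrax at all. The sketch below is plain JAX and does not reflect how Diffrax is implemented internally; `scan_solve`, `while_solve` and `trace_counts` are names made up for illustration. A Python-side counter is incremented each time a function is traced, which happens once per compilation.

```python
import jax
import jax.numpy as jnp

trace_counts = {"scan": 0, "while": 0}


@jax.jit
def scan_solve(y0, dWs):
    trace_counts["scan"] += 1  # Python side effect: runs only while tracing
    dt = 1.0 / dWs.shape[0]    # the number of steps is baked into the program

    def step(y, dW):
        return y - y * dt + 0.2 * dW, None

    y1, _ = jax.lax.scan(step, y0, dWs)
    return y1


@jax.jit
def while_solve(y0, t1, key):
    trace_counts["while"] += 1
    dt = 0.01  # fixed step size; the number of steps depends on the runtime value t1

    def cond(carry):
        t, _, _ = carry
        return t < t1

    def body(carry):
        t, y, k = carry
        k, sub = jax.random.split(k)
        dW = jnp.sqrt(dt) * jax.random.normal(sub)
        return t + dt, y - y * dt + 0.2 * dW, k

    _, y1, _ = jax.lax.while_loop(cond, body, (jnp.float32(0.0), y0, key))
    return y1


key = jax.random.PRNGKey(0)
y0 = jnp.float32(1.0)
for n in (100, 200):  # two different step counts
    scan_solve(y0, jnp.sqrt(1.0 / n) * jax.random.normal(key, (n,)))
for t1 in (1.0, 2.0):  # two different end times, hence step counts
    while_solve(y0, t1, key)
print(trace_counts)
```

The scan version is traced once per distinct number of steps, while the while-loop version is traced once and reused, at the price of a loop whose trip count XLA cannot know ahead of time.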

lockwo commented 2 days ago

With throw=False, EQX_ON_ERROR=nan and StepTo, this is what I see

code

```python
import os
os.environ["EQX_ON_ERROR"] = "nan"

import diffrax as dx
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt

# === simulation parameters
key = jax.random.PRNGKey(42)
t0 = 0
t1 = 1
y0 = 1.0
ndt = 101
dt = (t1 - t0) / (ndt - 1)
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.2
steps = jnp.linspace(t0, t1, ndt)

brownian_motion = dx.UnsafeBrownianPath(shape=(), key=key)
solver = dx.Euler()
terms = dx.MultiTerm(dx.ODETerm(drift), dx.ControlTerm(diffusion, brownian_motion))
saveat = dx.SaveAt(steps=True)

@jax.jit
def diffrax_simu():
    return dx.diffeqsolve(
        terms, solver, t0, t1, dt0=None, y0=y0, saveat=saveat,
        adjoint=dx.DirectAdjoint(), throw=False,
        stepsize_controller=dx.StepTo(ts=steps),
    ).ys

@jax.jit
def homemade_simu():
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt,))

    def step(y, dW):
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dW
        return y + dy, y

    return jax.lax.scan(step, 1.0, dWs)[-1]

y = diffrax_simu().block_until_ready()
plt.plot(y)
y = homemade_simu().block_until_ready()
plt.plot(y)

%timeit _ = diffrax_simu().block_until_ready()
%timeit _ = homemade_simu().block_until_ready()
```

(diffrax top, custom bottom)

2.18 ms ± 351 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
109 µs ± 25.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

(without any of those things I had):

2.43 ms ± 666 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
110 µs ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

(all on CPU, just a slower CPU, but the 20-30x slowdown seems of the same scale)

patrick-kidger commented 2 days ago

So you definitely don't want DirectAdjoint: this is actually really slow and should be avoided if possible. (It exists to handle some autodiff edge cases, I'd love to remove it sometime...) Use the default instead.

Make sure you include an argument (say y0) to both jitted functions -- XLA may have different behavior around constant folding.

I'd also try with and without SaveAt(steps=True). (And adjusting the scan appropriately.) I think this should be equivalent either way but I'm not 100% certain.

With all of the above in, then at that point there shouldn't actually be that much difference between the two implementations. (And if there is then we should figure out what.)

lockwo commented 2 days ago

The default actually errors with UBP, which is why I changed to DirectAdjoint:

ValueError: `adjoint=RecursiveCheckpointAdjoint()` does not support `UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` instead.

patrick-kidger commented 2 days ago

Ah, right. I've just checked and in the case of an unsafe SDE we do actually arrange for DirectAdjoint to do a scan so that should be fine:

(In retrospect I think we could have arranged for the default adjoint to also do the same thing, that might be a small usability improvement.)

Anyway, that's everything off the top of my head -- I might be forgetting something but with these settings then I think Diffrax should be doing something similar to the simple lax.scan implementation. But clearly we're missing something!

(EDIT: we still have one discrepancy I have just noticed: generating the Brownian samples in advance vs on-the-fly.)

If you'd like to dig into this then it might be time to stare at some jaxprs or HLO for the two programs. If you want to do this at the jaxpr level then you might find eqxi.finalise_jaxpr (and friends) to be a useful set of tools here:

Many primitives exist just to add e.g. an autodiff rule, so we can simplify our jaxprs down to what actually gets lowered by ignoring that and tracing through their impl rules instead.
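
For the HLO side, a minimal sketch using JAX's ahead-of-time lowering API on the homemade solver is shown here. Following the advice earlier in the thread, `y0` is passed as an argument rather than closed over. Only the homemade function is shown; the same `lower` / `compile` calls should apply to a jitted wrapper around `dx.diffeqsolve`.

```python
import jax
import jax.numpy as jnp

dt = 0.01
key = jax.random.PRNGKey(42)


@jax.jit
def homemade_simu(y0):
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (101,))

    def step(y, dW):
        return y + (-y * dt + 0.2 * dW), y

    return jax.lax.scan(step, y0, dWs)[-1]


lowered = homemade_simu.lower(jnp.float32(1.0))
stablehlo_text = lowered.as_text()                  # what JAX hands to XLA
optimized_text = lowered.compile().as_text() or ""  # after XLA's optimisation passes

# Line counts are only a rough indicator; diffing the two texts is more informative.
print(len(stablehlo_text.splitlines()), len(optimized_text.splitlines()))
```

Comparing the optimised text of the two programs shows what survives compilation, which is the part that actually costs runtime.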

lockwo commented 2 days ago

DirectAdjoint does slow things down, but it doesn't account for the whole gap. If I switch to a branch that allows for UBP + recursive adjoint, it's faster but there is still a ~4x gap. If I account for the fact that UBP has to split keys but the other doesn't, I get the gap down to around ~1.1-1.2x (which maybe isn't ideal, but seems much more reasonable to me given there's probably some other if statements/logging that might exist).

x = Timer(lambda : diffrax_simu(y0).block_until_ready())
x = Timer(lambda : homemade_simu(y0).block_until_ready())

with (above things, NAN, steps, function input, stepto, max steps, etc. all that) and direct adjoint: 0.002462916076183319 0.0005935421213507652

w/ checkpoint adjoint (on an internal branch that had some UBP changes to work with checkpoint): 0.002062791958451271 0.0005716248415410519

w/ both splitting keys: 0.0019747079350054264 0.001669874880462885

Code changed to:

def homemade_simu(yy):

    def step(y1, dW):
        y, k = y1
        k, subkey = jax.random.split(k)
        dw = jnp.sqrt(dt) * jax.random.normal(subkey)
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dw
        return (y + dy, k), y

    return jax.lax.scan(step, (yy, key), steps)[-1]


patrick-kidger commented 19 hours ago

Aha, interesting! Good to have more-or-less gotten to the bottom of the cause of this.


  1. I'd be curious to see what your version of RecursiveCheckpointAdjoint does, and how that compares to the unsafe-SDE-branch of DirectAdjoint.
  2. I suppose generating the Brownian samples in advance, rather than on-the-fly, is very plausibly much faster. (Although I note that it will be more memory-intensive.) Off the top of my head I'm not immediately sure how to arrange it so that the case of using a constant step size controller and an UnsafeBrownianPath could make it possible to precompute things.

On point 2, I suspect the solution may require allowing the control to have additional state. (Which is also what we'd need to make VBT faster.) Perhaps it's time to bite that bullet and allow for that to happen. Happy to hear suggestions on this one!

lockwo commented 11 hours ago

  1. That is something I want to investigate as well (and I also want to push more of it to a fork for others to check); admittedly it will take me a little while to get to.
  2. Would it be possible to add a "precompute" flag (or something to that effect) to UBP? It would generate the noise ahead of time (with the size determined by the max steps or user input), without requiring a stateful approach. This might (if the dt multiplication is still done in the loop) also be compatible with adaptive solvers that don't reject steps ("previsible", I think James calls them).
  3. I am in general an advocate of stateful controls (also discussed in #490), although I haven't thought much more on it since the discussion in that issue (which is very similar to how my stateful UBP is implemented).
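
A rough plain-JAX sketch of the "precompute" idea from point 2: all unit normals are drawn before the loop, and only the scaling by the step size happens inside it. This is not a proposal for the UBP API; `solve_with_precomputed_noise` is a hypothetical name, and the step sizes here are fixed in advance rather than chosen by a controller.

```python
import jax
import jax.numpy as jnp


def solve_with_precomputed_noise(y0, ts, key):
    """Euler-Maruyama for dy = -y dt + 0.2 dW on a (possibly non-uniform) grid ts."""
    dts = jnp.diff(ts)
    zs = jax.random.normal(key, dts.shape)  # all unit normals drawn up front

    def step(y, inputs):
        dt, z = inputs
        dW = jnp.sqrt(dt) * z  # scaling by the step size stays inside the loop
        y_next = y - y * dt + 0.2 * dW
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, (dts, zs))
    return ys


ts = jnp.linspace(0.0, 1.0, 101)
ys = solve_with_precomputed_noise(jnp.float32(1.0), ts, jax.random.PRNGKey(42))
```

Because each sample is tied to a step index rather than to a time, a step size chosen inside the loop could still consume the next sample, as long as steps are never rejected.
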

patrick-kidger commented 1 hour ago

  1. Okay, lmk what you find.
  2. I'm not sure. The way the controls are called at the moment is with the t, not the step index. We'd also have to have a way to pass the number of steps etc to the control. FWIW I'd probably lean towards not having a flag and just always doing this when possible.
  3. I think to do this 'properly' we might need to have AbstractSolver.step also accept the control state, and then pipe it through appropriately. Then also return the updated state. Unfortunately I think we're looking at a hard break to both the control and the solver APIs here, but c'est la vie.
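
To make point 3 a little more concrete, here is a toy sketch of a step function that accepts a control state and returns the updated one. Everything in it (`NoiseState`, `init_control`, `evaluate_control`, `euler_step`, `solve`) is hypothetical and written in plain JAX; it is not the Diffrax API and not a design that has been agreed on.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class NoiseState(NamedTuple):
    zs: jax.Array     # unit normals drawn before the solve
    index: jax.Array  # position of the next unused sample


def init_control(key, max_steps):
    return NoiseState(jax.random.normal(key, (max_steps,)), jnp.zeros((), jnp.int32))


def evaluate_control(state, t0, t1):
    """Return the increment W(t1) - W(t0) and the advanced state."""
    dW = jnp.sqrt(t1 - t0) * state.zs[state.index]
    return dW, state._replace(index=state.index + 1)


def euler_step(y, t0, t1, control_state):
    """One Euler-Maruyama step of dy = -y dt + 0.2 dW that threads the control state."""
    dW, control_state = evaluate_control(control_state, t0, t1)
    return y - y * (t1 - t0) + 0.2 * dW, control_state


def solve(y0, ts, key):
    def body(carry, t_pair):
        y, state = carry
        y, state = euler_step(y, t_pair[0], t_pair[1], state)
        return (y, state), y

    init = (y0, init_control(key, ts.shape[0] - 1))
    (_, final_state), ys = jax.lax.scan(body, init, (ts[:-1], ts[1:]))
    return ys, final_state


ys, final_state = solve(jnp.float32(1.0), jnp.linspace(0.0, 1.0, 101), jax.random.PRNGKey(42))
```

The control is still called with times, but the state carries the step index, which is one way around the "t, not the step index" mismatch mentioned in point 2.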