patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

[question] Computational complexity of integrating / backpropagating through SDE #186

Open ciupakabra opened 2 years ago

ciupakabra commented 2 years ago

Hi, I was a bit surprised by the difference in speed between backpropagating through ODEs and through SDEs, but couldn't find any discussion of the time complexity of either in the documentation, and a quick look didn't turn up any issues addressing this. In particular, consider this piece of code, where we vary which terms are integrated:

import optax
import jax
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx
import diffrax as dx
from tqdm import tqdm

def sample(model, num_steps, gamma, dim, key):
    drift, diff = model

    def f(t, y, args):
        return drift(jnp.concatenate([t[None], y]))

    def g(t, y, args):
        return diff(jnp.concatenate([t[None], y]))**2

    control = dx.VirtualBrownianTree(
        t0=0,
        t1=1,
        tol=1/(2*num_steps),
        shape=(dim,),
        key=key,
    )

    drift_term = dx.ODETerm(f)
    diffusion_term = dx.WeaklyDiagonalControlTerm(g, control)
    terms = dx.MultiTerm(drift_term, diffusion_term)
    solver = dx.Euler()
    y0 = jnp.zeros(dim)

    ts = jnp.linspace(0, 1, num_steps + 1)
    saveat = dx.SaveAt(ts=ts)

    sol = dx.diffeqsolve(
        terms,
        # diffusion_term,
        # drift_term,
        solver,
        0,
        1,
        1/num_steps,
        y0,
        saveat=saveat,
        max_steps=num_steps + 1,
    )

    return sol.ys

def loss(drift, num_steps, gamma, dim, key):
    path = sample(drift, num_steps, gamma, dim, key)
    final = path[-1]
    loss = jnp.sum(final**2)
    return loss

@eqx.filter_value_and_grad
def loss_mean(drift, num_steps, gamma, dim, key, batch_size):
    loss_vmapped = jax.vmap(loss, (None, None, None, None, 0), 0)
    key = jrandom.split(key, batch_size)
    return jnp.mean(loss_vmapped(drift, num_steps, gamma, dim, key))

if __name__=="__main__":

    key = jrandom.PRNGKey(0)

    init_drift_key, init_diff_key, train_key = jrandom.split(key, 3)

    dim = 500

    drift = eqx.nn.MLP(dim + 1, dim, 300, 2, key=init_drift_key)
    diff = eqx.nn.MLP(dim + 1, 1, 300, 2, key=init_diff_key)

    model = (drift, diff)

    optimizer = optax.adamw(1e-4)
    opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_jit
    def make_step(model, num_steps, gamma, dim, key, batch_size, opt_state):
        loss, grads = loss_mean(model, num_steps, gamma, dim, key, batch_size)
        updates, opt_state = optimizer.update(
            grads, opt_state, eqx.filter(model, eqx.is_inexact_array)
        )
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

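    for step in tqdm(range(100)):
        step_key = jrandom.fold_in(train_key, step)
        # Use a distinct name so the module-level `loss` function isn't shadowed
        # (loss_mean looks `loss` up by name whenever make_step is retraced).
        loss_value, model, opt_state = make_step(
            model, 20, 0.1, dim, step_key, 32, opt_state
        )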

On my machine (MacBook M1), integrating terms or diffusion_term takes around 41s, while integrating drift_term takes around 5s. What is the reason for this difference? Am I doing something wrong here? Note that the computation in the diffusion term is simply multiplying the Brownian motion by a scalar. Is VirtualBrownianTree the slow part here? I suspect that implementing the Euler solver in plain JAX would give a faster solution -- would that be wrong? Maybe it's worth adding some documentation about this.

patrick-kidger commented 2 years ago

How does it change if you switch to control = UnsafeBrownianPath? What about if you switch to diffeqsolve(..., stepsize_controller=ConstantStepSize(compile_steps=True))?

Putting both of those together should be essentially identical to just writing out the solver in plain JAX.

EDIT: actually, I think UnsafeBrownianPath doesn't support backpropagation here, which it probably should. Let me have a look at this.

patrick-kidger commented 2 years ago

Ech, looks like this runs afoul of #176 as well. Which version of JAX are you using? For the purposes of debugging this I'll downgrade. (And #176 will be fixed just as soon as https://github.com/google/jax/pull/13062 lands.)

ciupakabra commented 2 years ago

Hmm, my versions are:

jax                        0.3.14
jaxlib                     0.3.14

which seem to fix #176 but still give the behaviour described above. I've tried different versions of diffrax but not of jax. Is there a particular jax version you'd recommend I try?

EDIT: with stepsize_controller=ConstantStepSize(compile_steps=True) the speeds improve but the difference is still quite big: integrating terms and diffusion_term takes 26s, while drift_term takes 3s.

2nd EDIT:

Disabling this check

https://github.com/patrick-kidger/diffrax/blob/b8475527eacdba81328130edd7dadd08a0b34063/diffrax/integrate.py#L687-L690

and using UnsafeBrownianPath for the control still gives the same slowdown.

patrick-kidger commented 2 years ago

Okay, I think I've tracked this down. It's because without the diffusion term, your computation is actually unbatched: your only batched input is key, but this is unused. So JAX optimises the computation to only run the diffeq for a single batch element, then just broadcasts it out at the end.

But if you have the diffusion term, then you're actually solving 32 different equations at the same time, and the same optimisation can't be made.
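A minimal sketch of this effect, using a hypothetical toy drift and a plain `jax.lax.scan` loop rather than Diffrax: when the vmapped function never reads `key`, vmap has nothing to batch, so the solve runs once and the result is broadcast; once the noise depends on `key`, there really are 32 separate solves.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

NUM_STEPS = 20
DT = 1.0 / NUM_STEPS


def toy_drift(y):
    # Hypothetical stand-in for the drift network.
    return 1.0 - y


def solve(key, use_noise):
    """Fixed-step Euler(-Maruyama) solve over [0, 1] starting from y0 = 0."""
    y0 = jnp.zeros(3)
    step_keys = jrandom.split(key, NUM_STEPS)

    def body(y, k):
        y_next = y + toy_drift(y) * DT
        if use_noise:  # Python-level switch, resolved at trace time
            y_next = y_next + jnp.sqrt(DT) * jrandom.normal(k, y.shape)
        return y_next, None

    y1, _ = jax.lax.scan(body, y0, step_keys)
    return y1


batch_keys = jrandom.split(jrandom.PRNGKey(0), 32)
ode_batch = jax.vmap(lambda k: solve(k, use_noise=False))(batch_keys)
sde_batch = jax.vmap(lambda k: solve(k, use_noise=True))(batch_keys)

print(ode_batch.shape)                              # (32, 3)
print(bool(jnp.all(ode_batch == ode_batch[0])))     # True: one solve, broadcast
print(bool(jnp.allclose(sde_batch, sde_batch[0])))  # False: 32 distinct solves
```

Both calls produce a `(32, 3)` array, but only the noisy one does 32 times the work.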

ciupakabra commented 2 years ago

How come JAX manages to optimize this for a straightforward implementation of the solver? Changing the sample function to the one below runs in 5s. (This is with the same versions of jax, jaxlib and diffrax as before. Upgrading all three to the most recent versions gives quite a big performance boost in all scenarios (any ideas why?), but the differences are still similar: the diffrax ODE term takes 2s; the diffrax SDE with VirtualBrownianTree takes 15s with compile_steps=True / 24s without; the diffrax SDE with UnsafeBrownianPath takes 10s with compile_steps=True / 14s without; and the implementation below still takes 5s.)

def euler_maruyama(drift, diff, num_steps, t0, t1, y0, key):

    ts = jnp.linspace(t0, t1, num_steps + 1)[:-1]  # left endpoint of each step
    keys = jrandom.split(key, num_steps)           # one PRNG key per step
    dt = (t1 - t0) / num_steps

    def _body_fun(y, args):
        t, key = args
        w = jrandom.normal(key, y.shape)
        _y = y + dt * drift(t, y, None) + jnp.sqrt(dt) * diff(t, y, None) * w
        return _y, _y

    _, path = jax.lax.scan(_body_fun, y0, (ts, keys))

    path = jnp.concatenate([y0[None, ...], path], axis=0)

    return path

def sample(model, num_steps, gamma, dim, key):  # gamma is unused; kept so loss() can call this unchanged
    drift, diff = model

    def f(t, y, args):
        return drift(jnp.concatenate([t[None], y]))

    def g(t, y, args):
        return diff(jnp.concatenate([t[None], y]))**2

    y0 = jnp.zeros(dim)

    path = euler_maruyama(
        f,
        g,
        num_steps,
        0.0,
        1.0,
        y0,
        key,
    )

    return path
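For reference, the scan body above is the standard Euler-Maruyama update. Writing $N$ for `num_steps`, $f$ for the drift and $g$ for the diffusion (here a scalar, applied elementwise):

```latex
\Delta t = \frac{t_1 - t_0}{N}, \qquad t_n = t_0 + n\,\Delta t, \qquad
Y_{n+1} = Y_n + f(t_n, Y_n)\,\Delta t + g(t_n, Y_n)\,\Delta W_n, \qquad
\Delta W_n = \sqrt{\Delta t}\, Z_n, \quad Z_n \sim \mathcal{N}(0, I)
```

In the code, `w` plays the role of $Z_n$ and `jnp.sqrt(dt) * w` that of $\Delta W_n$.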

I was assuming that the problem was because sampling in VirtualBrownianTree is somewhat complicated, but the same happens with UnsafeBrownianPath as well (after disabling the check as in my previous comment). So why is JAX unable to optimize it as it does with the above implementation?

> Okay, I think I've tracked this down.

Did you do this by just inspecting jaxprs?

patrick-kidger commented 2 years ago

Upgrading giving a performance boost -- this is because I've been working to improve the efficiency of Diffrax :)

(And there's more stuff coming in just over the horizon: leave a comment over on https://github.com/google/jax/pull/13184 if you want.)

Yeah, the complexity of sampling the Brownian motion was my first thought as well. (Things like key-splitting and key-folding-in are also not that cheap, and were my second thought.) JAX manages to optimise your implementation for the same reason it's able to optimise the Diffrax version: you are only vmapping with respect to key. If you remove the diffusion term then you're not using key at all, so the computation is no longer vmapped. You go from solving a batch of diffeqs to just solving a single diffeq.

patrick-kidger commented 2 years ago

In terms of how I tracked this down: nope, no jaxprs. They're sadly not that helpful for debugging anything Diffrax-related. Differential equation solvers are large and complicated enough that they end up being too large to really interpret. (EDIT: although this should be simplified a lot once https://github.com/google/jax/pull/13062 lands.)

In this case I tracked things down by replacing VirtualBrownianTree with a simple dummy control (which produced the fast version):

    class Control:
        def evaluate(self, t0, t1, left=True):
            return t1 - t0

and then bisecting the differences between Control and VirtualBrownianTree until I found the issue.

If you want to induce a batch dependence without any of this, try doing y0 = y0 + 0.0001 * key[0].
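A sketch of why the `y0 = y0 + 0.0001 * key[0]` trick works, again with a hypothetical toy ODE in plain JAX rather than Diffrax: once `y0` depends on the key, vmap has to batch the whole solve, so the ODE-only run should then be timed like a genuinely batched computation. This assumes raw `uint32` keys as returned by `jrandom.PRNGKey`, so that `key[0]` is an integer.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def drift_only_solve(key, perturb):
    """Euler solve of the toy ODE dy/dt = -y over [0, 1] (no noise term)."""
    y0 = jnp.zeros(3)
    if perturb:
        # The suggested trick: a tiny key-dependent shift of y0. With raw
        # uint32 keys, key[0] is an integer that differs between split keys.
        y0 = y0 + 0.0001 * key[0]

    def body(y, _):
        return y - y * 0.05, None

    y1, _ = jax.lax.scan(body, y0, None, length=20)
    return y1


batch_keys = jrandom.split(jrandom.PRNGKey(0), 32)
plain = jax.vmap(lambda k: drift_only_solve(k, perturb=False))(batch_keys)
forced = jax.vmap(lambda k: drift_only_solve(k, perturb=True))(batch_keys)

print(bool(jnp.all(plain == plain[0])))       # True: key unused, nothing batched
print(bool(jnp.allclose(forced, forced[0])))  # False: y0 now depends on key
```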

ciupakabra commented 2 years ago

Thanks for the quick response!

Are you sure this difference comes only from vmap'ing? The custom implementation above does depend on the key (since we're solving the full SDE and sampling noise), and doing y0 = y0 + 0.0001 * key[0] does not change its speed. Yet it's still faster than using diffrax with either UnsafeBrownianPath or VirtualBrownianTree.

I agree that when using diffrax the difference between integrating the drift_term and both terms is partly because of the vmap. But even with y0 = y0 + 0.0001 * key[0] integrating drift_term is twice as fast as integrating both terms with VirtualBrownianTree.

ciupakabra commented 2 years ago

> But even with y0 = y0 + 0.0001 * key[0] integrating drift_term is twice as fast as integrating both terms with VirtualBrownianTree

I realized this could simply be because we are not evaluating the diff network. When we change g to

    def g(t, y, args):
        return 1.0

this makes little difference when using VirtualBrownianTree. But with this diffusion, integrating both terms with UnsafeBrownianPath takes the same time as integrating just drift_term with y0 = y0 + 0.0001 * key[0]. So the performance difference seems to come mainly from VirtualBrownianTree being expensive, plus the fact that integrating drift_term alone was not actually vmapped, because nothing depended on the key.
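To make "VirtualBrownianTree is expensive" concrete, here is a simplified sketch of the virtual-Brownian-tree idea. It is not Diffrax's implementation: `virtual_brownian_eval` and its `depth` parameter are hypothetical, and the final linear interpolation is a simplification. Each query bisects `[t0, t1]` repeatedly, sampling each midpoint from the Brownian bridge with a key derived from the path to that node, so the same `(key, t)` always returns the same value. With `tol = 1/(2*num_steps)` and `num_steps = 20`, reaching intervals no wider than `tol` takes 6 halvings (log2(40) is about 5.3). Each evaluation therefore does about six 3-way key splits and six normal draws, and an Euler step presumably needs the Brownian increment between two time points. The plain `euler_maruyama` above draws one normal per step. How much this matters relative to the networks depends on the workload.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def virtual_brownian_eval(key, t, t0=0.0, t1=1.0, depth=6):
    """Evaluate a scalar Brownian sample path W(t) on [t0, t1] with W(t0) = 0."""
    key, root_key = jrandom.split(key)
    w_end = jnp.sqrt(jnp.float32(t1 - t0)) * jrandom.normal(root_key)
    init = (jnp.float32(t0), jnp.float32(t1), jnp.float32(0.0), w_end, key)

    def level(_, state):
        s, u, ws, wu, node_key = state
        mid_key, left_key, right_key = jrandom.split(node_key, 3)
        m = 0.5 * (s + u)
        # Brownian bridge: W(m) | W(s), W(u) ~ N((ws + wu) / 2, (u - s) / 4).
        wm = 0.5 * (ws + wu) + 0.5 * jnp.sqrt(u - s) * jrandom.normal(mid_key)
        go_left = t < m
        # Descend into the half-interval containing t, keeping its endpoint values.
        return (
            jnp.where(go_left, s, m),
            jnp.where(go_left, m, u),
            jnp.where(go_left, ws, wm),
            jnp.where(go_left, wm, wu),
            jnp.where(go_left, left_key, right_key),
        )

    s, u, ws, wu, _ = jax.lax.fori_loop(0, depth, level, init)
    # Simplification: linear interpolation inside the final small interval.
    return ws + (t - s) / (u - s) * (wu - ws)


key = jrandom.PRNGKey(0)
ts = jnp.linspace(0.0, 1.0, 5)
path = jax.vmap(lambda t: virtual_brownian_eval(key, t))(ts)
print(path)  # W(0) = 0; repeated queries at the same t give the same value
```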

Perhaps it then makes sense to allow differentiating through UnsafeBrownianPath in certain cases? I'm happy to submit a PR for this, but I'm not sure why it's disabled in the first place. Is it because force_bitcast_convert_type is not very stable?

Regarding the difference between the custom implementation above and diffrax with UnsafeBrownianPath: could it be that jax.lax.scan is simply faster than bounded_while_loop? Would it make sense to add something like an exact_num_steps argument alongside max_steps, so that a scan is used when the number of steps is known exactly (as it is in this case)?
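A minimal sketch of why a known step count makes `lax.scan` attractive, using the hypothetical scalar test equation dy/dt = a * y. `lax.scan` supports reverse-mode differentiation, whereas `lax.while_loop` supports only forward mode. As far as I understand, this is why Diffrax needs its own `bounded_while_loop` when the number of steps is only bounded by `max_steps`. `exact_num_steps` is the commenter's proposal, not an existing Diffrax argument.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 20
DT = 1.0 / NUM_STEPS


def euler_scan(a):
    """Euler for dy/dt = a * y, y(0) = 1, with the step count fixed up front."""
    def body(y, _):
        return y + a * y * DT, None

    y0 = jnp.asarray(1.0, dtype=jnp.float32)
    y1, _ = jax.lax.scan(body, y0, None, length=NUM_STEPS)
    return y1


def euler_while(a):
    """Same solve as a while loop whose stopping rule is checked at run time."""
    def cond(state):
        i, _ = state
        return i < NUM_STEPS

    def body(state):
        i, y = state
        return i + 1, y + a * y * DT

    init = (jnp.int32(0), jnp.asarray(1.0, dtype=jnp.float32))
    _, y1 = jax.lax.while_loop(cond, body, init)
    return y1


a = jnp.float32(0.5)
print(jax.grad(euler_scan)(a))                             # reverse mode: fine
print(jax.jvp(euler_while, (a,), (jnp.float32(1.0),))[1])  # forward mode: fine
try:
    jax.grad(euler_while)(a)  # reverse mode through lax.while_loop
except Exception as err:
    print("reverse mode rejected:", type(err).__name__)
```

Both derivatives equal the analytic value N * dt * (1 + a * dt)^(N - 1).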

patrick-kidger commented 2 years ago

So yeah, there are some definite differences here depending on whether we're evaluating just drift, just diff, or both. In particular these two networks are of different sizes, and I noticed that this also produced a measurable speed difference in drift-vs-diffusion comparisons (controlling for everything else).

FWIW I didn't notice VirtualBrownianTree adding very much in the way of overhead (~5% relative to a dummy control, once again with every other source of variation fixed), but perhaps YMMV.

The difference between your implementation and bounded_while_loop: indeed, the extra complexity of bounded_while_loop (which is reduced but not completely eliminated when using compile_steps=True) is probably responsible for a fair bit of overhead here. This difference should be eliminated once https://github.com/google/jax/pull/13062 and https://github.com/google/jax/pull/13184 land -- and even better, these should eliminate the overhead even when the number of steps isn't known exactly! The best thing you can do is leave a comment on those PRs +1-ing them, to let the JAX maintainers know this is a priority for you.
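One ingredient of that extra complexity is checkpointing, which the next reply mentions: the backward pass recomputes forward intermediates instead of storing them all. Here is a minimal sketch with `jax.checkpoint` on a toy `lax.scan` solve. The vector field `a * tanh(y)` is hypothetical, and this is not how `bounded_while_loop` is structured internally. The gradient is the same either way. What changes is the trade-off between memory and recomputation.

```python
import jax
import jax.numpy as jnp

DT = 0.05


def step(y, a):
    # One explicit Euler step for the hypothetical vector field a * tanh(y).
    return y + a * jnp.tanh(y) * DT


def solve(a, y0, checkpoint):
    body = lambda y, _: (step(y, a), None)
    if checkpoint:
        # jax.checkpoint stores only the body's inputs and recomputes its
        # intermediates when the backward pass needs them.
        body = jax.checkpoint(body)
    y1, _ = jax.lax.scan(body, y0, None, length=20)
    return jnp.sum(y1 ** 2)


y0 = jnp.linspace(-1.0, 1.0, 8)
g_plain = jax.grad(solve)(0.5, y0, False)
g_remat = jax.grad(solve)(0.5, y0, True)
print(bool(jnp.allclose(g_plain, g_remat)))  # True
```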

Differentiating through UnsafeBrownianPath: this is disabled because of the checkpointing happening inside bounded_while_loop (a necessary part of its magic). The reconstruction from the checkpoints may not be perfect down to the least significant floating-point bit (e.g. GPU convolutions are nondeterministic), and then the use of the floating-point number as a PRNG key would turn that tiny discrepancy into arbitrarily large changes afterwards. Once again this is an oddity that should vanish once the above PRs land, and we will then be able to differentiate through UnsafeBrownianPath.
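To see why a last-bit discrepancy matters, here is a minimal sketch that assumes a floating-point value (hypothetically, a query time `t`) is bit-cast to an integer and folded into a PRNG key. `noise_from_float` is an illustration built from `jax.lax.bitcast_convert_type` and `jax.random.fold_in`. It is not the code UnsafeBrownianPath actually uses, which the question above says involves `force_bitcast_convert_type`. Moving `t` by a single ulp yields an unrelated sample, so a recomputed value that is off in the last bit gives completely different noise from the forward pass.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def noise_from_float(key, t):
    """Hypothetical: derive a PRNG key from the raw bits of a float32 value."""
    bits = jax.lax.bitcast_convert_type(jnp.float32(t), jnp.uint32)
    return jrandom.normal(jrandom.fold_in(key, bits))


key = jrandom.PRNGKey(0)
t = jnp.float32(0.3)
t_next = jnp.nextafter(t, jnp.float32(1.0))  # adjacent float32, one ulp above t

print(float(t_next - t))  # about 3e-08
print(float(noise_from_float(key, t)), float(noise_from_float(key, t_next)))
# The two samples are unrelated: a one-bit change in t gives a brand-new key.
```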

ciupakabra commented 2 years ago

Got it, makes sense! Will leave this open until those PRs roll out