patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Matching the performances of `jax.lax.scan` in adjoint calculation #274

Closed astanziola closed 1 year ago

astanziola commented 1 year ago

Hello again!

I am still on my quest to add proper integration + adjoint calculation (and checkpointing) to my wave simulator :smile:

I appreciate that diffrax offers a range of methods for calculating adjoints, each with its own trade-off between computational complexity and memory requirements.

However, for smaller simulations, it might be beneficial to maximize checkpoint usage and potentially save the entire ODE trajectory for reverse-mode AD. This approach takes full advantage of the GPU memory, thereby reducing computation time.

My understanding is that this can be achieved by using RecursiveCheckpointAdjoint with a large value for checkpoints, potentially as high as the number of steps taken by the forward integrator.

I've attempted to implement this without much success. To be precise, while I am obtaining the correct numerical results, the computation times are far longer than expected.

Here is an MWE:

```python
from diffrax import (
    RecursiveCheckpointAdjoint, SemiImplicitEuler, ODETerm,
    diffeqsolve, SaveAt, ConstantStepSize
)

from jax import numpy as jnp
from jax.lax import scan
from timeit import repeat
import jax

def _speedtest(integrator, grad_fn, fields, t, name):
    # Compile ahead of time so that compilation is excluded from the timings.
    integration_compiled = jax.jit(integrator).lower(fields, t).compile()
    # Each entry is the total time for 5 calls; the best of 20 repeats is reported.
    integration_times = repeat(lambda: integration_compiled(fields, t)[0].block_until_ready(), number=5, repeat=20)
    print(f"{name}: {min(integration_times)}")

    grad_fn_compiled = jax.jit(grad_fn).lower(fields, t).compile()
    grad_times = repeat(lambda: grad_fn_compiled(fields, t)[0].block_until_ready(), number=5, repeat=20)
    print(f"{name} AD: {min(grad_times)}")

##### SETUP #####
N = 256
N_steps = 2000
t = jnp.linspace(0, 1, N_steps)

u0, v0 = jnp.zeros((N, N)), jnp.zeros((N, N)).at[32, 32].set(1.0)
fields = (u0, v0)

# Integration terms
du = lambda t, v, args: -(v**2)
dv = lambda t, u, args: -jnp.fft.irfft(jnp.sin(jnp.fft.rfft(u)))
sample = lambda t, y, args: y[0][64, 64]  # Some arbitrary sampling function

##### INTEGRATE WITH scan #####
@jax.checkpoint
def scan_fun(carry, t):
    # Semi-implicit Euler step: update u using v, then v using the new u.
    u, v, dt = carry
    u = u + du(t, v, None)*dt
    v = v + dv(t, u, None)*dt
    return (u, v, dt), sample(t, (u, v), None)

def integrator(fields, t):
    dt = t[1] - t[0]
    carry = (fields[0], fields[1], dt)
    _, values = scan(scan_fun, carry, t)
    return values

@jax.grad
def grad_fn(fields, t):
    return jnp.mean(integrator(fields, t)**2)

# Timing
_speedtest(integrator, grad_fn, fields, t, "scan")

##### INTEGRATE WITH SemiImplicitEuler #####
# The first term updates u from v; the second updates v from the new u.
terms = ODETerm(du), ODETerm(dv)

def integrator(fields, t):
    return diffeqsolve(
        terms,
        solver=SemiImplicitEuler(),
        t0=t[0],
        t1=t[-1],
        dt0=t[1]-t[0],
        y0=fields,
        args=None,
        saveat=SaveAt(steps=True, fn=sample, dense=False),
        stepsize_controller=ConstantStepSize(),
        adjoint=RecursiveCheckpointAdjoint(checkpoints=N_steps),  # as many checkpoints as steps
        max_steps=N_steps
    ).ys

@jax.grad
def grad_fn(fields, t):
    return jnp.mean(integrator(fields, t)**2)

# Timing
_speedtest(integrator, grad_fn, fields, t, "SemiImplicitEuler")
```

With this, I get the following timings on an RTX 4000 (seconds for 5 calls, best of 20 repeats):

scan: 0.20132129988633096
scan AD: 0.5393328319769353
SemiImplicitEuler: 0.426237307023257
SemiImplicitEuler AD: 3.552313446998596

As expected, for scan, the AD calculation takes a bit over twice as long as the forward pass (about 2.7x here). This can be made almost exactly 2x if the jax.checkpoint decorator is removed.
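For reference, this trade-off can be reproduced on a toy recurrence. Wrapping the scan body in `jax.checkpoint` makes JAX save only each step's inputs (the carry) and recompute the body's intermediates during the backward pass; without it, those intermediates are stored, which uses more memory but avoids the recomputation. The sketch below uses a hypothetical `step` function and `make_loss` helper (not taken from the benchmark above) and only checks that both variants produce the same gradient:

```python
import jax
import jax.numpy as jnp
from jax.lax import scan

def step(carry, t):
    # Toy semi-implicit update (illustrative only): u uses v, then v uses the new u.
    u, v, dt = carry
    u = u - (v ** 2) * dt
    v = v - jnp.sin(u) * dt
    return (u, v, dt), u[0]

def make_loss(remat):
    # remat=True: store only the per-step carry, recompute the body on the backward pass.
    # remat=False: store every intermediate of the body (more memory, less recomputation).
    body = jax.checkpoint(step) if remat else step

    def loss(u0, v0):
        ts = jnp.linspace(0.0, 1.0, 100)
        dt = ts[1] - ts[0]
        _, samples = scan(body, (u0, v0, dt), ts)
        return jnp.mean(samples ** 2)

    return loss

u0 = jnp.linspace(0.0, 1.0, 8)
v0 = jnp.ones(8)
# Gradients with respect to u0 for both variants.
g_remat = jax.jit(jax.grad(make_loss(True)))(u0, v0)
g_plain = jax.jit(jax.grad(make_loss(False)))(u0, v0)
print(jnp.max(jnp.abs(g_remat - g_plain)))
```

Only the memory/compute profile differs between the two; the numerical result is the same up to floating-point error.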

For the forward pass of SemiImplicitEuler, the timings I get are approximately twice those of scan. This could easily be attributed to the more sophisticated implementation of the diffrax integrator, so overall that's completely fine.

However, the timings for performing AD are about 7x those required by the scan method. In a more complex example within my simulator, it can reach up to 30x the time required by the equivalent scan integrator.

Am I missing something about the correct approach to calculating the adjoint?

Also, I'm not sure whether RecursiveCheckpointAdjoint uses the same solver as the forward integrator (based on my understanding of the documentation, it doesn't), and I can't find a way to pass a specific solver to it. Is it necessary to define a new class derived from AbstractAdjoint with a custom loop method to achieve this?

Thanks a lot!

patrick-kidger commented 1 year ago

First of all, thank you for providing such a careful benchmark. Second, sorry for taking so long to get back to you -- tackling this turned out to be an interesting problem, and it took longer than I expected! (This also turned up two XLA bugs along the way: https://github.com/google/jax/issues/16663, https://github.com/google/jax/issues/16661.)

Anyway, I'm happy to say that as of #276, the performance is now much closer. On my V100:

scan: 0.042259128065779805 
scan AD: 0.10428679105825722
SemiImplicitEuler: 0.04464115505106747 
SemiImplicitEuler AD: 0.15305515099316835
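To compare the two runs, it helps to look at ratios rather than raw times. The sketch below just redoes that arithmetic on the numbers reported in this thread (the `overheads` helper is hypothetical). Note that the runs used different GPUs (RTX 4000 vs V100) and the second used StepTo, so only ratios within a run are meaningful:

```python
# Reported timings (seconds for 5 calls, best of 20 repeats), copied from the two runs.
before = {  # original benchmark: RTX 4000, ConstantStepSize
    "scan": 0.20132129988633096,
    "scan AD": 0.5393328319769353,
    "SemiImplicitEuler": 0.426237307023257,
    "SemiImplicitEuler AD": 3.552313446998596,
}
after = {  # after #276: V100, StepTo
    "scan": 0.042259128065779805,
    "scan AD": 0.10428679105825722,
    "SemiImplicitEuler": 0.04464115505106747,
    "SemiImplicitEuler AD": 0.15305515099316835,
}

def overheads(timings):
    """Backward/forward ratio for each method, and Diffrax/scan ratio for each pass."""
    return {
        "scan AD / scan": timings["scan AD"] / timings["scan"],
        "Diffrax AD / Diffrax": timings["SemiImplicitEuler AD"] / timings["SemiImplicitEuler"],
        "Diffrax / scan (forward)": timings["SemiImplicitEuler"] / timings["scan"],
        "Diffrax / scan (backward)": timings["SemiImplicitEuler AD"] / timings["scan AD"],
    }

for label, timings in (("before", before), ("after", after)):
    for name, ratio in overheads(timings).items():
        print(f"{label:6s} {name:26s} {ratio:.2f}")
```

On these numbers, Diffrax's backward pass goes from roughly 6.6x that of scan to roughly 1.5x.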

Note that I did make one change to the benchmark to ensure a fair comparison: I switched Diffrax's stepsize_controller to diffrax.StepTo, as this is the appropriate analogue of lax.scan. (In contrast, note how your current program depends only on t[0], t[1], t[-1], but not on any of the other values of t. This means that you could change just those values and get a different number of steps -- but you wouldn't have to recompile! This extra flexibility at runtime is responsible for part of the overhead you're measuring.)

As for the changes that I made: most of the overhead turned out to be due to the extra complexity of the recursive checkpointing (as opposed to simply checkpointing on every step). The relevant changes are in https://github.com/patrick-kidger/equinox/pull/415.
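To illustrate the difference between checkpointing on every step and a scheme with fewer checkpoints, here is a hypothetical two-level scheme built from nested `lax.scan` calls. This is not the recursive algorithm Diffrax uses; it only shows the general idea of saving the state at segment boundaries and replaying each segment during the backward pass. The `step`, `segment` and sizes below are illustrative assumptions. Both versions compute the same gradient, but the two-level one keeps roughly `N_SEGMENTS` carries plus one segment's intermediates alive, instead of one carry per step:

```python
import jax
import jax.numpy as jnp
from jax.lax import scan

def step(y, _):
    # Toy nonlinear update (illustrative only).
    y = y + 0.01 * jnp.sin(y)
    return y, None

N_SEGMENTS, SEGMENT_LEN = 10, 20  # 200 steps in total (hypothetical sizes)

def every_step(y0):
    # Checkpoint every step: scan saves one carry per step and recomputes
    # each step's internals during the backward pass.
    y, _ = scan(jax.checkpoint(step), y0, None, length=N_SEGMENTS * SEGMENT_LEN)
    return jnp.sum(y ** 2)

@jax.checkpoint
def segment(y):
    # Only the segment's input is saved; the whole segment is replayed
    # (storing its intermediates temporarily) during the backward pass.
    y, _ = scan(step, y, None, length=SEGMENT_LEN)
    return y

def two_level(y0):
    def outer(y, _):
        return segment(y), None

    y, _ = scan(outer, y0, None, length=N_SEGMENTS)
    return jnp.sum(y ** 2)

y0 = jnp.linspace(-1.0, 1.0, 16)
g_every = jax.jit(jax.grad(every_step))(y0)
g_two = jax.jit(jax.grad(two_level))(y0)
```

Deeper recursive schemes reduce memory further but add recomputation and bookkeeping, which is the kind of overhead being discussed here.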

This improvement will appear in the next release of Diffrax. And there's clearly still a small discrepancy on the backward pass -- it looks like that still needs some more careful profiling. Let me know what this looks like for your actual use-case!

astanziola commented 1 year ago

Thanks a lot, that looks great! I'm on holidays until the end of next week, but as soon as I'm back I'll give it a try on the actual simulator.

astanziola commented 1 year ago

Sorry this took a while, but I just tested it in the simulator and it works amazingly :smile: I get very good performance (roughly 2.2x for the backward pass). Interestingly enough, I get a very minor performance boost using ConstantStepSize instead of StepTo, but it's really negligible (though consistent).

I do have a couple of warnings at startup due to deprecation of equinox.filter_custom_vjp.defvjp, but that's a different story.

Also... Did I just get a runtime XLA error based on array values?? What is this new amazing sorcery :heart_eyes:??

Thanks again for fixing this!

patrick-kidger commented 1 year ago

Marvellous! The slight difference between ConstantStepSize and StepTo -- I have no explanation for this :D

> I do have a couple of warnings at startup due to deprecation of equinox.filter_custom_vjp.defvjp, but that's a different story.

Yup, next release of Diffrax (in the next few days) will avoid that code path.

> Also... Did I just get a runtime XLA error based on array values?? What is this new amazing sorcery 😍??

Yes you did! Equinox recently added public support for these. Documentation available here. Indeed "sorcery" is the appropriate word, since this is something new under the JAX sun.

patrick-kidger commented 1 year ago

Just released the new version of Diffrax. I think everything discussed here should now work, fast, without warnings.

As such I'm closing this, but please do re-open it if you think this isn't fixed.