patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Saving metrics during solve #113

Closed joglekara closed 1 year ago

joglekara commented 2 years ago

Hi, thanks again for the library.

We have a numerical integration that we perform using lax.scan, and we'd like to port it over to Diffrax. I wanted to get your thoughts on the shortest path to this implementation.

From https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html

def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

We rely quite a bit on the implicit append and stack calls that happen inside lax.scan. I think the core of my question is: is there a way to avoid the pre-allocation you suggest in https://github.com/patrick-kidger/diffrax/issues/60#issuecomment-1034768386? My guess is the answer is no, but I figured I'd see if you had any ideas.

TY in advance!
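For what it's worth, the append-and-stack behaviour being relied on is equivalent to filling a preallocated buffer with in-place updates. A minimal pure-JAX sketch (with a toy step function, not the actual integration) comparing the two styles:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    carry = carry + x
    return carry, carry  # scan appends this second output and stacks it

xs = jnp.arange(5.0)

# Style 1: let lax.scan do the append-and-stack.
final, ys_scan = jax.lax.scan(step, jnp.array(0.0), xs)

# Style 2: preallocate a buffer and fill it with "in-place" updates,
# which is what scan's stacking lowers to under jit anyway.
def body(i, state):
    carry, buf = state
    carry = carry + xs[i]
    return carry, buf.at[i].set(carry)

final2, ys_buf = jax.lax.fori_loop(
    0, xs.shape[0], body, (jnp.array(0.0), jnp.zeros_like(xs))
)
```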

patrick-kidger commented 2 years ago

Do I understand that:

  1. As in #60, you're interested in logging metrics during the solve, and want to save something at every step?
  2. As you're currently using a lax.scan, you're probably using a fixed step size rather than an adaptive step size?

If this is the case then this is totally doable. Follow the approach in #60 to create a wrapper solver that stores your metrics. metric_state should be a buffer of the appropriate size (equal to the number of steps you'll be taking), which you update using "in-place updates" (buffer = buffer.at[step].set(value)). Now, the analogous behaviour to a lax.scan in Diffrax is to pass stepsize_controller=ConstantStepSize(compile_steps=True), and if you do this then you should find that your "in-place updates" really do get compiled down to efficient in-place updates!
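As a pure-JAX sketch of that buffer pattern (the halving step below is a stand-in for a real solver step, and all names are illustrative, not Diffrax internals):

```python
import jax
import jax.numpy as jnp

num_steps = 4
# Preallocate metric_state with one slot per step; with a fixed step
# size the number of steps is known up front.
metric_state0 = jnp.zeros(num_steps)

def body(step, state):
    y, metrics = state
    y = 0.5 * y  # stand-in for one solver step
    # "In-place update": under jit this compiles to a real in-place write.
    metrics = metrics.at[step].set(jnp.sum(y))
    return y, metrics

y_final, metric_state = jax.lax.fori_loop(
    0, num_steps, body, (jnp.ones(3), metric_state0)
)
```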

patrick-kidger commented 2 years ago

Can you give me some examples of the sort of thing you're trying to save, by the way? That this has come up twice now suggests that it may be worth creating an easy way to save user-specified metrics. (And in particular without needing to worry about this business of the efficiency of in-place updates.)

joglekara commented 2 years ago

> Do I understand that:
>
>   1. As in Logging metrics during an ODE solve #60, you're interested in logging metrics during the solve, and want to save something at every step?
>   2. As you're currently using a lax.scan, you're probably using a fixed step size rather than an adaptive step size?
>
> If this is the case then this is totally doable. Follow the approach in #60 to create a wrapper solver that stores your metrics. metric_state should be a buffer of the appropriate size (equal to the number of steps you'll be taking), which you update using "in-place updates" (buffer = buffer.at[step].set(value)). Now, the analogous behaviour to a lax.scan in Diffrax is to pass stepsize_controller=ConstantStepSize(compile_steps=True), and if you do this then you should find that your "in-place updates" really do get compiled down to efficient in-place updates!

Yes to both!

Okay, sounds good. I was just hoping that I wouldn't have to preallocate metric_state because, well, I don't do it now and therefore it's more work :P

> Can you give me some examples of the sort of thing you're trying to save, by the way? That this has come up twice now suggests that it may be worth creating an easy way to save user-specified metrics. (And in particular without needing to worry about this business of the efficiency of in-place updates.)

Yes, of course.

I am solving a plasma physics PDE where it's impractical, if not impossible (due to speed and storage), to store the full state at every timestep. So we calculate some moments, interpolations, averages, etc., and track those over time. See figs. 5 and 6 for examples of reduced versions of the full state over time.

This is definitely the case for other PDEs in plasma physics. I am assuming the same happens in PDE solves in other domains.

patrick-kidger commented 2 years ago

So given the state y (at each timestep), then at the moment, with SaveAt(steps=True), we simply save y. Perhaps what could be done is to allow inserting a user-specified function (t, y) -> anything here that determines what to save, defaulting to

def just_state(t: Scalar, y: Array):
    return y

Perhaps it could/should also be a function of solver_state etc. etc.
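For instance, a user-specified function of that shape might compute a few summary statistics instead of returning the state (a hypothetical example, not an existing Diffrax API):

```python
import jax.numpy as jnp

def save_moments(t, y):
    # Reduce the full state to a handful of scalar summaries.
    return {
        "mean": jnp.mean(y),
        "second_moment": jnp.mean(y**2),
        "max": jnp.max(y),
    }

stats = save_moments(0.0, jnp.array([1.0, 2.0, 3.0]))
```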

joglekara commented 2 years ago

That's exactly what I do right now in the scan (in a Haiku module, I know, I know, transform magic; I'm certainly thinking about Equinox):

def one_step(y, current_params):
  y = evolve_PDE(y)
  temp_storage = self.storage_step(y)
  return y, temp_storage

self.storage_step is the postprocessing step.
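Put together with lax.scan, that pattern looks something like the following (evolve_PDE and storage_step here are toy stand-ins, not the actual PDE code):

```python
import jax
import jax.numpy as jnp

def evolve_PDE(y):
    # Toy stand-in for one PDE timestep.
    return 0.9 * y

def storage_step(y):
    # Postprocessing: reduce the full state to a scalar metric.
    return jnp.sum(y**2)

def one_step(y, _):
    y = evolve_PDE(y)
    return y, storage_step(y)  # scan stacks the metrics over steps

y0 = jnp.ones(4)
y_final, metrics = jax.lax.scan(one_step, y0, None, length=3)
```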

patrick-kidger commented 2 years ago

Sounds good! I've adjusted the title of this issue to make this a feature request. (Which, if you ever feel like digging into the guts of Diffrax, I'd be happy to shepherd a PR on.)

joglekara commented 2 years ago

I'm in! Is this where we'll be working?

https://github.com/patrick-kidger/diffrax/blob/cec091c5e4cc4311f64ae3aa09a371db5fe766ee/diffrax/integrate.py#L246

patrick-kidger commented 2 years ago

Yep! It's this big block of code. The idea will be to change from state.ys -- which is where we save our output -- to state.out.

Now the bad news is that this part of the code is pretty complicated to read, as it needs to carefully make sure it gets good performance. Fortunately the good news is that you shouldn't need to change any of that -- just carefully figure out what's what, and replace saving ys with saving save_fn(ys).

I'd suggest that the saving-function should be specified as part of the diffeqsolve(..., saveat=...) argument, and that when called it should take state as an argument.

As a nice bonus, I've realised that this makes the solver_state, controller_state, made_jump arguments to SaveAt superfluous: these can be obtained simply by using the appropriate metric-saving function.

joglekara commented 1 year ago

Looking at this again, do I want to trace back the ys=ys line to be more of an out=out line?

        new_state = _State(
            y=y,
            tprev=tprev,
            tnext=tnext,
            made_jump=made_jump,
            solver_state=solver_state,
            controller_state=controller_state,
            result=result,
            num_steps=num_steps,
            num_accepted_steps=num_accepted_steps,
            num_rejected_steps=num_rejected_steps,
            saveat_ts_index=saveat_ts_index,
            ts=ts,
            ys=ys,
            save_index=save_index,
            dense_ts=dense_ts,
            dense_infos=dense_infos,
            dense_save_index=dense_save_index,
        )

patrick-kidger commented 1 year ago

So given that the idea here is to save these summary statistics instead of y: we could always just use the ys buffer itself.

Then, at all the places where we currently do ys.at[index].set(y), we instead do ys.at[index].set(save(t, y, args)), where save is a function passed to diffeqsolve, which defaults to save = lambda t, y, args: y.
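In pure-JAX terms the change is small; a sketch with illustrative names (save_step is not the actual Diffrax internal):

```python
import jax.numpy as jnp

# Default: save the state itself, matching the current behaviour.
default_save = lambda t, y, args: y

# User-supplied alternative: save a summary statistic instead.
summary_save = lambda t, y, args: jnp.mean(y)

def save_step(ys, index, t, y, args, save=default_save):
    # Where the solve loop currently does ys.at[index].set(y), it
    # would instead do:
    return ys.at[index].set(save(t, y, args))

y = jnp.array([1.0, 3.0])
ys = save_step(jnp.zeros(2), 0, 0.0, y, None, save=summary_save)
```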

patrick-kidger commented 1 year ago

Closing as this has been completed in the about-to-be-released v0.3.0.

For the original PDE use-case, you may find the new SubSaveAt functionality useful, in that it allows you to save the terminal value in its entirety, and just the statistics of the evolving solution. See the examples on that page.

I'd also welcome any feedback on the new nonlinear heat PDE example.

joglekara commented 1 year ago

Thanks for helping me get this through. SubSaveAt is a nice feature add too!

I think I'm on to a bug, but I don't have a minimum repro for this yet, so I thought I'd see if something comes to your mind right away -- I'm getting NaNs when trying to take a gradient of one of these reduced quantities. The gradient comes through fine when using fn=None with a ts=. I'm only using jnp.interp in my fn, so I'd like to say it should be differentiable.

Any thoughts? I'll try to get a min repro going in the meantime...

joglekara commented 1 year ago

I think I have this traced down to storing the abs and angle of complex values. This happens even if I take the gradient of a different quantity that never undergoes a complex transformation. So it's probably related to the complex support issue in https://github.com/patrick-kidger/diffrax/issues/96.