eelregit / pmwd

Differentiable Cosmological Forward Model
BSD 3-Clause "New" or "Revised" License

Using a lax.scan to run the solver #11

Open EiffL opened 1 year ago

EiffL commented 1 year ago

This draft PR responds to #9 and presents a prototype implementation of the leapfrog solver that uses a lax.scan instead of a Python for loop in the nbody function.
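To make the comparison concrete, here is a minimal sketch of a kick-drift-kick leapfrog driven both by a Python for loop and by lax.scan. It is not pmwd's solver. It assumes a toy 1D harmonic-oscillator force F = -x in place of the particle-mesh force and plain time steps in place of the scale-factor integration. The names leapfrog_step, nbody_loop and nbody_scan are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def force(x):
    # Toy force for a unit-frequency harmonic oscillator (stand-in for the PM force).
    return -x


def leapfrog_step(state, a_pair):
    # One kick-drift-kick step from time a_pair[0] to time a_pair[1].
    x, v = state
    dt = a_pair[1] - a_pair[0]
    v = v + 0.5 * dt * force(x)   # half kick
    x = x + dt * v                # drift
    v = v + 0.5 * dt * force(x)   # half kick
    return x, v


def nbody_loop(x, v, a):
    # Python for loop: under jit this is unrolled into one copy of the step per iteration.
    state = (x, v)
    for i in range(len(a) - 1):
        state = leapfrog_step(state, (a[i], a[i + 1]))
    return state


@jax.jit
def nbody_scan(x, v, a):
    # lax.scan: the step is traced once and compiled into a single XLA loop.
    a_pairs = jnp.stack([a[:-1], a[1:]], axis=1)   # shape (n_steps, 2)

    def body(state, a_pair):
        return leapfrog_step(state, a_pair), None  # no per-step output

    state, _ = lax.scan(body, (x, v), a_pairs)
    return state


x0 = jnp.array([1.0, 0.5])
v0 = jnp.zeros(2)
a = jnp.linspace(0.0, 1.0, 65)
xs, vs = nbody_scan(x0, v0, a)
xl, vl = nbody_loop(x0, v0, a)
```

Both versions compute the same trajectory. The difference is in compilation: a jitted for loop grows the traced program linearly with the number of steps, while the scan version traces the step body only once.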

Here are the results on the baseline default configuration.

[Figure: current master]

[Figure: this PR (using scan; I also removed all lower-level jit)]

And here is the notebook to reproduce this test (working off my fork): https://gist.github.com/EiffL/aa6a651141f694ca257fb5ff83e829d6

So I would advocate using lax.scan.

In this draft implementation, I chose not to output intermediate ptcl and obsvl, exactly as is done on master. But if you want to export intermediate snapshots, it's easy: you can return them as the output of the scan function :-)
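A sketch of exporting every intermediate state, using the same kind of hypothetical toy leapfrog (force F = -x): whatever the scan body returns as its second output is stacked by lax.scan along a new leading axis. The name nbody_with_history is illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax


def leapfrog_step(state, a_pair):
    # Kick-drift-kick with a toy force F = -x (hypothetical stand-in for the PM force).
    x, v = state
    dt = a_pair[1] - a_pair[0]
    v = v - 0.5 * dt * x
    x = x + dt * v
    v = v - 0.5 * dt * x
    return x, v


@jax.jit
def nbody_with_history(x, v, a):
    a_pairs = jnp.stack([a[:-1], a[1:]], axis=1)

    def body(state, a_pair):
        new_state = leapfrog_step(state, a_pair)
        # Returning the new state as the scan output makes lax.scan
        # stack it along a leading axis, one entry per step.
        return new_state, new_state

    final, (x_hist, v_hist) = lax.scan(body, (x, v), a_pairs)
    return final, x_hist, v_hist


x0 = jnp.array([1.0, 0.5])
v0 = jnp.zeros(2)
a = jnp.linspace(0.0, 1.0, 9)   # 8 steps
(x_end, v_end), x_hist, v_hist = nbody_with_history(x0, v0, a)
```

This keeps one copy of the state per step in memory, which is why exporting only selected steps matters for full-size simulations.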

If you look at the implementation of odeint in JAX, you can also use slightly more complicated logic that exports the state of the system only at some predefined steps rather than at every time step: https://github.com/google/jax/blob/518fe6656ca2aab66dcfc8cd7866c10f476a17b1/jax/experimental/ode.py#L189
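A sketch of that odeint-style pattern, again with the hypothetical toy force F = -x: an outer lax.scan runs over the requested step indices, and an inner lax.fori_loop advances the system up to each one, so only the requested states are stacked. With traced loop bounds, fori_loop lowers to a while_loop, which mirrors the while-inside-scan structure of odeint. The function name nbody_save_at and its save_steps argument are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax import lax


def leapfrog_step(x, v, dt):
    # Kick-drift-kick with a toy force F = -x (hypothetical stand-in).
    v = v - 0.5 * dt * x
    x = x + dt * v
    v = v - 0.5 * dt * x
    return x, v


@jax.jit
def nbody_save_at(x, v, a, save_steps):
    """Integrate over the time grid `a` and return (x, v) only after the
    step indices in `save_steps` (increasing, each in 1..len(a)-1)."""

    def advance(i, state):
        # Advance one step, from a[i] to a[i + 1].
        x, v = state
        return leapfrog_step(x, v, a[i + 1] - a[i])

    def outer(carry, target):
        step, x, v = carry
        # Inner loop: run from the current step up to the requested one.
        # The bounds are traced, so fori_loop becomes a while_loop.
        x, v = lax.fori_loop(step, target, advance, (x, v))
        return (target, x, v), (x, v)

    init = (jnp.zeros((), save_steps.dtype), x, v)
    _, (xs, vs) = lax.scan(outer, init, save_steps)
    return xs, vs


x0 = jnp.array([1.0, 0.5])
v0 = jnp.zeros(2)
a = jnp.linspace(0.0, 1.0, 9)   # 8 steps
xs, vs = nbody_save_at(x0, v0, a, jnp.array([2, 5, 8]))
```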

And finally, if you want to save the sims to disk, then nothing prevents you from using the nbody step function directly/manually in a for loop.
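A sketch of the manual route: a jitted step called from a plain Python loop, writing each state to disk between steps. It assumes the same hypothetical toy force F = -x, writes NumPy .npz files into a temporary directory instead of HDF5 to avoid an extra dependency, and uses illustrative names (run_and_save, snap_NNN.npz).

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def leapfrog_step(x, v, dt):
    # Kick-drift-kick with a toy force F = -x (hypothetical stand-in).
    v = v - 0.5 * dt * x
    x = x + dt * v
    v = v - 0.5 * dt * x
    return x, v


def run_and_save(x, v, a, out_dir):
    # Plain Python loop around a jitted step: the step compiles once
    # (same shapes every call) and the host can do arbitrary IO in between.
    paths = []
    for i in range(len(a) - 1):
        x, v = leapfrog_step(x, v, a[i + 1] - a[i])
        path = os.path.join(out_dir, f"snap_{i + 1:03d}.npz")
        np.savez(path, x=np.asarray(x), v=np.asarray(v))
        paths.append(path)
    return x, v, paths


out_dir = tempfile.mkdtemp()
a = jnp.linspace(0.0, 1.0, 5)   # 4 steps
x, v, paths = run_and_save(jnp.array([1.0, 0.5]), jnp.zeros(2), a, out_dir)
```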

EiffL commented 1 year ago

And actually ^^ it's generally not a good idea ^^', but if we want to, we can definitely write a custom CPU op that dumps the simulation to HDF5 from within jitted code, even from inside the lax.scan.

In this particular instance, I think it would be pretty cool.

EiffL commented 1 year ago

So yeah I don't see any drawbacks of using a scan :-)

eelregit commented 1 year ago

Thanks! This was very much how it was done here. Also here for the adjoint.

So it's good to know that XLA or JAX has gotten better at this.

"but if you want to export intermediate snapshots, it's easy you can export them as the output of the scan fn :-)"

I guess you meant nested scans. We want interpolation between two steps. It looks like odeint is extrapolating from the last step? But interpolation should also work with nested scans.

EiffL commented 1 year ago

Yes, in odeint they use a while loop inside the scan function.
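A sketch of the nested pattern applied to snapshots: an outer lax.scan over the requested output times, and an inner lax.while_loop that steps forward until the time grid reaches the target while remembering the previous state, followed by interpolation between the two bracketing steps. It assumes the hypothetical toy force F = -x and uses plain linear interpolation for simplicity. odeint instead evaluates a fitted interpolating polynomial, and pmwd's snapshot observables may interpolate differently. The names nbody_snapshots and t_out are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax


def leapfrog_step(x, v, dt):
    # Kick-drift-kick with a toy force F = -x (hypothetical stand-in).
    v = v - 0.5 * dt * x
    x = x + dt * v
    v = v - 0.5 * dt * x
    return x, v


@jax.jit
def nbody_snapshots(x, v, a, t_out):
    """Return states linearly interpolated at times t_out.

    Assumes `a` is increasing and each t_out lies in (a[0], a[-1]], increasing.
    """
    n_steps = a.shape[0] - 1

    def outer(carry, t):
        def cond(c):
            # Keep stepping until the grid time a[step] reaches the target t.
            step = c[0]
            return (step < n_steps) & (a[step] < t)

        def body(c):
            step, xc, vc, _, _ = c
            x_new, v_new = leapfrog_step(xc, vc, a[step + 1] - a[step])
            # Keep the state at a[step] so we can interpolate afterwards.
            return step + 1, x_new, v_new, xc, vc

        step, xc, vc, x_prev, v_prev = lax.while_loop(cond, body, carry)
        # Linear interpolation between the states at a[step - 1] and a[step].
        w = (t - a[step - 1]) / (a[step] - a[step - 1])
        snap = (x_prev + w * (xc - x_prev), v_prev + w * (vc - v_prev))
        return (step, xc, vc, x_prev, v_prev), snap

    init = (jnp.zeros((), jnp.int32), x, v, x, v)
    _, (x_snap, v_snap) = lax.scan(outer, init, t_out)
    return x_snap, v_snap


x0 = jnp.array([1.0, 0.5])
v0 = jnp.zeros(2)
a = jnp.linspace(0.0, 1.0, 11)   # steps of 0.1
x_snap, v_snap = nbody_snapshots(x0, v0, a, jnp.array([0.25, 0.5, 1.0]))
```

When a requested time coincides with a grid point, the weight w is 1 and the snapshot reduces to an exact copy of that step, which is the checkpoint case.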

Would you be OK with an API that takes an argument, the array a to use in the solver, and maybe another optional array save_at containing the indices of the snapshots to export, defaulting to [-1]? If so, I'm happy to implement it :-)
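A sketch of such an API under a few assumptions: the toy force F = -x again, a is the time grid, and save_at holds Python-style grid indices (0 is the initial state, -1 the final one) passed as a static tuple. Instead of stacking every step, the scan carries a fixed-size buffer and copies the state into every slot whose index matches the current step. The signature nbody(x, v, a, save_at=(-1,)) is hypothetical, not pmwd's actual nbody.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def leapfrog_step(x, v, dt):
    # Kick-drift-kick with a toy force F = -x (hypothetical stand-in).
    v = v - 0.5 * dt * x
    x = x + dt * v
    v = v - 0.5 * dt * x
    return x, v


@partial(jax.jit, static_argnames="save_at")
def nbody(x, v, a, save_at=(-1,)):
    """Hypothetical API: integrate over the time grid `a` and return the
    states at the grid indices in `save_at` (Python-style, -1 = last)."""
    n = a.shape[0]
    idx = jnp.array([k % n for k in save_at])      # normalize negative indices
    bcast = (len(save_at),) + (1,) * x.ndim        # mask shape for jnp.where

    def record(buf, i, state):
        # Copy `state` into every buffer slot whose requested index equals i.
        hit = (idx == i).reshape(bcast)
        return tuple(jnp.where(hit, s[None], b) for b, s in zip(buf, state))

    buf = (jnp.zeros((len(save_at),) + x.shape, x.dtype),
           jnp.zeros((len(save_at),) + v.shape, v.dtype))
    buf = record(buf, 0, (x, v))                    # index 0 is the initial state

    def body(carry, i):
        x, v, buf = carry
        x, v = leapfrog_step(x, v, a[i + 1] - a[i])
        return (x, v, record(buf, i + 1, (x, v))), None

    (_, _, (xs, vs)), _ = lax.scan(body, (x, v, buf), jnp.arange(n - 1))
    return xs, vs


x0 = jnp.array([1.0, 0.5])
v0 = jnp.zeros(2)
a = jnp.linspace(0.0, 1.0, 9)
x_last, _ = nbody(x0, v0, a)                      # default: final state only
x_sel, _ = nbody(x0, v0, a, save_at=(0, 4, -1))   # initial, middle, final
```

Making save_at static means each distinct tuple triggers a recompile, but in exchange the output shape is known at trace time and memory stays proportional to the number of requested snapshots.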

And then, I think it would be very cool to be able to do IO directly from jitted code :-) I think I know how to do it, but that's probably for a different PR.

eelregit commented 1 year ago

Let's try switching to scan the odeint way once the checkpoint (exactly at a time step, directly copying disp and vel) and snapshot (interpolated between two steps) observables are implemented. @Yucheng-Zhang is working on those observables.

Yes, it'd be super cool to have a custom IO op ^^

eelregit commented 1 year ago

id_tap seems to be useful for writing snapshots.
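id_tap lives in jax.experimental.host_callback, which newer JAX releases have deprecated in favour of the callback APIs in jax.debug and jax.experimental (for example io_callback). Here is a minimal sketch of the same idea, writing snapshots from inside a jitted lax.scan with jax.debug.callback. It assumes the toy force F = -x and writes .npz files to a temporary directory as a stand-in for HDF5. write_snapshot and nbody_writing are hypothetical names.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

OUT_DIR = tempfile.mkdtemp()


def write_snapshot(step, x, v):
    # Runs on the host with concrete arrays; a real version might write HDF5.
    np.savez(os.path.join(OUT_DIR, f"snap_{int(step):03d}.npz"),
             x=np.asarray(x), v=np.asarray(v))


def leapfrog_step(x, v, dt):
    # Kick-drift-kick with a toy force F = -x (hypothetical stand-in).
    v = v - 0.5 * dt * x
    x = x + dt * v
    v = v - 0.5 * dt * x
    return x, v


@jax.jit
def nbody_writing(x, v, a):
    def body(carry, i):
        x, v = leapfrog_step(*carry, a[i + 1] - a[i])
        # Host callback from inside the compiled loop; ordered=True keeps
        # the writes in step order.
        jax.debug.callback(write_snapshot, i + 1, x, v, ordered=True)
        return (x, v), None

    (x, v), _ = lax.scan(body, (x, v), jnp.arange(a.shape[0] - 1))
    return x, v


a = jnp.linspace(0.0, 1.0, 5)   # 4 steps
x, v = nbody_writing(jnp.array([1.0, 0.5]), jnp.zeros(2), a)
jax.effects_barrier()  # wait until all pending host-side writes have finished
```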