eelregit / pmwd

Differentiable Cosmological Forward Model
BSD 3-Clause "New" or "Revised" License

Outer JIT compilation time could be optimized #9

Open EiffL opened 1 year ago

EiffL commented 1 year ago

In this example: https://gist.github.com/EiffL/8e46d261e5d52cd28ca81e233fef9b04

It takes 3 minutes for the first evaluation of the model to run, but only a few seconds the second time.

@modichirag has also confirmed that the compilation time grows with the number of time steps. This suggests that the code builds an overly large computational graph that explicitly includes every N-body step.
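One way to check this without timing a compilation is to count the operations in the traced program. The sketch below uses a toy `step` function (a hypothetical stand-in for one N-body step, not pmwd's code) and `jax.make_jaxpr` to show that an unrolled Python loop produces a program whose size grows linearly with the number of steps.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Hypothetical stand-in for one N-body step: a few elementwise ops.
    return x + 0.1 * jnp.sin(x)


def unrolled(x, num_steps):
    # This Python loop runs while JAX traces the function, so every
    # iteration is recorded as separate operations in the program.
    for _ in range(num_steps):
        x = step(x)
    return x


def num_ops(num_steps):
    # Count the primitive operations in the traced program (jaxpr).
    x = jnp.ones(8)
    closed = jax.make_jaxpr(lambda x: unrolled(x, num_steps))(x)
    return len(closed.jaxpr.eqns)


for n in (10, 20, 40):
    print(n, num_ops(n))
```

Since XLA has to compile every one of these operations, a program that grows with `num_steps` is consistent with compile time growing with `num_steps`.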

I suspect this is due to using a Python for loop in the nbody function. Things would probably improve a lot if it were replaced with a lax.scan.
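A minimal sketch of that replacement, again with a hypothetical toy `step` in place of pmwd's actual N-body step: `lax.scan` traces the loop body once, so the traced program stays the same size regardless of the number of steps, while computing the same result as the Python loop.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(x):
    # Hypothetical stand-in for one N-body step.
    return x + 0.1 * jnp.sin(x)


def nbody_loop(x, num_steps):
    # Python loop: unrolled at trace time, one copy of `step` per iteration.
    for _ in range(num_steps):
        x = step(x)
    return x


def nbody_scan(x, num_steps):
    # lax.scan traces `body` once and emits a single loop primitive.
    def body(carry, _):
        return step(carry), None

    x, _ = lax.scan(body, x, xs=None, length=num_steps)
    return x


x0 = jnp.linspace(0.0, 1.0, 8)
loop_out = jax.jit(nbody_loop, static_argnums=1)(x0, 50)
scan_out = jax.jit(nbody_scan, static_argnums=1)(x0, 50)
print(jnp.allclose(loop_out, scan_out, atol=1e-5))
```

`lax.scan` also accepts an `unroll` argument that places several iterations inside each loop step; it might be one way to trade some compile time for runtime speed, though we have not measured that here.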

eelregit commented 1 year ago

You can remove the jit in the gist.

We have discussed this. This is part of the reason that pmwd already does jit for you.

It was a scan before, but that was slower than a for loop. Maybe you can try again and see if it is faster now. With the current implementation, we can easily save a snapshot between two jitted steps. I/O is a side effect, and I'm not sure how it can be done with jit and scan.
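A rough sketch of why the loop-of-jitted-steps structure makes snapshots easy: between two jitted calls control returns to ordinary Python, where any side effect is allowed. The `nbody_step` here is a hypothetical toy (a harmonic force), not pmwd's, and snapshots are kept in a list; `np.save` could be called at the same point to write them to disk instead.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def nbody_step(state):
    # Hypothetical toy step (harmonic force), standing in for pmwd's nbody_step.
    pos, vel = state
    vel = vel - 0.1 * pos
    pos = pos + 0.1 * vel
    return pos, vel


def nbody(state, num_steps, save_every):
    snapshots = []
    for i in range(num_steps):
        state = nbody_step(state)  # each call runs compiled code
        if (i + 1) % save_every == 0:
            # Back in ordinary Python between jitted calls, so copying to
            # the host or writing a file here is fine.
            snapshots.append(np.asarray(state[0]))
    return state, snapshots


state0 = (jnp.ones(4), jnp.zeros(4))
final, snaps = nbody(state0, num_steps=10, save_every=5)
print(len(snaps))  # → 2
```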

The downside is that the backward passes of growth, modes, and lpt are not jitted in this approach. There's a workaround that splits those and the subsequent nbody into parts 1 and 2 of the model, but it's a bit cumbersome.
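A sketch of what such a split might look like, under loose assumptions: `part1` is a hypothetical stand-in for growth, modes, and LPT (jitted as one unit so its derivative is compiled too), `part2` is a Python loop over a jitted step, and the gradients are chained by hand with `jax.vjp`. None of these names are pmwd functions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def part1(params):
    # Hypothetical stand-in for growth + modes + LPT: builds the initial state.
    return 2.0 * jnp.tanh(params)


@jax.jit
def nbody_step(x):
    # Hypothetical stand-in for one N-body step.
    return x + 0.1 * jnp.sin(x)


def part2(x, num_steps=5):
    # Python loop over a separately jitted step (keeps snapshots easy).
    for _ in range(num_steps):
        x = nbody_step(x)
    return x


def loss_and_grad(params):
    # Forward pass through each part, keeping the VJP closures.
    x, vjp1 = jax.vjp(part1, params)
    y, vjp2 = jax.vjp(part2, x)
    loss = jnp.sum(y ** 2)
    # Backward pass: d(loss)/dy, then pull back through part2, then part1.
    (g_x,) = vjp2(2.0 * y)
    (g_params,) = vjp1(g_x)
    return loss, g_params


params = jnp.array([0.1, -0.3, 0.5])
loss, g = loss_and_grad(params)
print(loss, g)
```

The bookkeeping of passing cotangents between the parts by hand is presumably the cumbersome part.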

eelregit commented 1 year ago

The quickstart notebook has some timing.

EiffL commented 1 year ago

Oh right, right. Well, the problem is that ultimately there will need to be a jit outside of pmwd, for instance in an HMC sampler, or as part of a larger simulation model that also includes additional computational layers. And concretely, for me it's good to have my distributed code inside a jit (though that's not required).

OK, I'll see if I can think of how to do the export with a scan. But I can't imagine it being that much slower. Do you remember by how much it changed things?

eelregit commented 1 year ago

I don't see why HMC requires a top-level jit. Maybe it's more convenient that way, but why would it be mandatory?

And what do you have in mind for the larger model? If it's a really big model, like almost any NN, wouldn't top-level jitting hit this problem again even if pmwd were leaner?

I remember it being 20-30% slower. And I think one cannot export inside jit.

eelregit commented 1 year ago

I wonder if it's still slow if you jit nbody_step first and then the whole model function.
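The suggested setup can be sketched as follows, with a hypothetical toy `nbody_step` in place of pmwd's: the step is jitted on its own, and the whole model (a Python loop over that step) is jitted again on the outside. The number of steps has to be static because it controls the Python loop.

```python
import jax
import jax.numpy as jnp


@jax.jit
def nbody_step(x):
    # Hypothetical stand-in for pmwd's nbody_step, jitted on its own.
    return x + 0.1 * jnp.sin(x)


def model(x, num_steps):
    # Python loop over the separately jitted step.
    for _ in range(num_steps):
        x = nbody_step(x)
    return x


# Outer jit over the whole model; num_steps is static.
model_jit = jax.jit(model, static_argnums=1)

x0 = jnp.linspace(0.0, 1.0, 8)
print(jnp.allclose(model(x0, 20), model_jit(x0, 20), atol=1e-5))
```

To compare compile times, one could time the first and second calls of `model_jit` with `time.perf_counter()`, calling `.block_until_ready()` on the result so that asynchronous dispatch does not hide the work.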

EiffL commented 1 year ago

When you use an external sampler like NumPyro or TFP, it will usually compile all the logic of the HMC kernel, including the evaluation of the log-likelihood. It may be possible to disable that (and I agree it's not necessary in principle), but by default JAX users expect to be able to jit their code without knowing the underlying implementation.
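To illustrate why a sampler ends up jitting the forward model, here is a minimal hand-written HMC transition. It is a sketch, not the NumPyro or TFP API: `log_density` is a hypothetical toy likelihood, and in practice it would call the simulation. Because the whole transition is under `jax.jit`, everything inside `logp` and its gradient is traced and compiled with it.

```python
import functools

import jax
import jax.numpy as jnp


def log_density(theta):
    # Hypothetical toy log-likelihood (standard normal). In practice this
    # would run the forward model, e.g. a pmwd simulation.
    return -0.5 * jnp.sum(theta ** 2)


@functools.partial(jax.jit, static_argnums=(0, 4))
def hmc_step(logp, theta, key, step_size, num_leapfrog):
    # Everything below, including logp and its gradient, is compiled
    # together as one program.
    grad_logp = jax.grad(logp)
    key_mom, key_acc = jax.random.split(key)
    p0 = jax.random.normal(key_mom, theta.shape)

    def leapfrog(_, qp):
        # One velocity-Verlet step (recomputes the gradient for clarity).
        q, p = qp
        p = p + 0.5 * step_size * grad_logp(q)
        q = q + step_size * p
        p = p + 0.5 * step_size * grad_logp(q)
        return q, p

    q, p = jax.lax.fori_loop(0, num_leapfrog, leapfrog, (theta, p0))

    # Metropolis accept/reject on the change in the Hamiltonian.
    h0 = -logp(theta) + 0.5 * jnp.sum(p0 ** 2)
    h1 = -logp(q) + 0.5 * jnp.sum(p ** 2)
    accept = jnp.log(jax.random.uniform(key_acc)) < h0 - h1
    return jnp.where(accept, q, theta)


key = jax.random.PRNGKey(0)
theta = jnp.zeros(3)
for _ in range(5):
    key, sub = jax.random.split(key)
    theta = hmc_step(log_density, theta, sub, 0.1, 10)
print(theta)
```

If `log_density` contained an unrolled loop over simulation steps, the compile time of `hmc_step` would include all of them.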

20-30% does sound like a lot. I guess that was for small problems, but still, if it's that bad, I see the reason for this tradeoff.

Otherwise, yes, I agree that saving snapshots to disk from within a jitted function wouldn't be trivial. But if you are doing things on the fly, that's probably not very important. I can see, though, that you might want to avoid the memory cost of storing intermediate snapshots.
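One jit-friendly alternative to writing files is to return snapshots as the per-iteration output of `lax.scan`, which stacks them along a new leading axis. The sketch below (toy `step`, hypothetical) makes the memory cost explicit: every snapshot stays in device memory at once, so storage grows as `num_steps` times the state size.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


def step(x):
    # Hypothetical stand-in for one N-body step.
    return x + 0.1 * jnp.sin(x)


@functools.partial(jax.jit, static_argnums=1)
def nbody_with_snapshots(x, num_steps):
    def body(carry, _):
        new = step(carry)
        # Emitting `new` as the per-iteration output makes scan stack it.
        return new, new

    final, snapshots = lax.scan(body, x, xs=None, length=num_steps)
    return final, snapshots


x0 = jnp.linspace(0.0, 1.0, 192).reshape(64, 3)
final, snaps = nbody_with_snapshots(x0, 10)
print(snaps.shape)  # (10, 64, 3)
```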

eelregit commented 1 year ago

Some of those users are likely already used to slow compilation. I have heard complaints about JAX taking minutes to compile NNs.

I don't know why JAX cannot get cache hits on nbody_step after unrolling the for loop and compiling nbody_step the first time. Here's a related issue: https://github.com/google/jax/issues/284
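Inspecting the traced program of the outer function may help narrow this down. In the sketch below (toy `nbody_step`, hypothetical), `jax.make_jaxpr` shows that even when the step is jitted separately, the outer trace contains one nested call per loop iteration, so the program handed to XLA still grows with the number of steps. Whether XLA then reuses the compiled step across those call sites is the open question here; the sketch does not answer it.

```python
import jax
import jax.numpy as jnp


@jax.jit
def nbody_step(x):
    # Hypothetical stand-in for pmwd's separately jitted nbody_step.
    return x + 0.1 * jnp.sin(x)


def model(x, num_steps):
    for _ in range(num_steps):
        x = nbody_step(x)
    return x


def num_calls(num_steps):
    # Number of top-level operations in the outer traced program.
    closed = jax.make_jaxpr(lambda x: model(x, num_steps))(jnp.ones(8))
    return len(closed.jaxpr.eqns)


print(num_calls(5), num_calls(10))
print(jax.make_jaxpr(lambda x: model(x, 2))(jnp.ones(8)))
```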

eelregit commented 1 year ago

A discussion that mentions adding lower level jit helps with compilation time https://github.com/google/jax/discussions/10104

eelregit commented 1 year ago

I think saving snapshots is also important for normal use cases, like generating mocks.