pnkraemer / probdiffeq

Probabilistic solvers for differential equations in JAX. Adaptive ODE solvers with calibration, state-space model factorisations, and custom information operators. Compatible with the broader JAX scientific computing ecosystem.
https://pnkraemer.github.io/probdiffeq/
MIT License

Speed of initial JITing #456

Closed adam-hartshorne closed 1 year ago

adam-hartshorne commented 1 year ago

I realise this is a bit hand-wavy, but I have noticed that the initial JITing stage can be quite slow compared to exactly the same setup with Diffrax as the ODE solver, e.g. 20s for a simple NODE-type setup. I realise this is a new library (while Diffrax has been optimised over the course of several years), and once the JIT staging is complete, each iteration in probdiffeq runs very quickly / optimises models in a few seconds, which is awesome.

I just thought I would share my experience of using JAX over the past few years: slow compilation can often be the symptom of issues I have never quite been able to put my finger on. However, I have found that very innocuous changes, such as exactly where you vmap something or changing the order of a calculation, can sometimes result in significant speed-ups. In one incident, literally just changing at what "level" I used a vmap halved the JIT staging time.

I also know that Diffrax has used tricks to speed up JITing, e.g. there is a "scan_stage" flag which runs a particular mode to substantially improve compilation speed.

pnkraemer commented 1 year ago

Thanks for sharing your experience! This kind of feedback is very welcome :)

It is possible that the difference in compilation time is not necessarily due to the maturity of libraries (though it may well be!) but more likely because a single step-attempt with a probabilistic solver involves quite a lot of machinery (preconditioning, extrapolation, linearisation, calibration, correction) even without the acceptance-rejection nature of adaptive step-size selection. So perhaps, the comparably slow compilation is a price one has to pay for using a probabilistic solver instead of a (say) Runge-Kutta method.

(Unfortunately, the scan_stage trick does not apply to what we're doing here.)

That said, I am sure there is at least some potential for improving the compilation time.

Are you aware of any tricks beyond avoiding native-python control flow, i.e., using lax.while_loop instead of a native python loop? You mention something about selecting when to apply vmap. I would love to hear more!
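(For anyone reading along, by the lax.while_loop trick I mean something like the following generic sketch — nothing probdiffeq-specific, and the loop body is just a placeholder fixed-point iteration:)

```python
import jax
import jax.numpy as jnp


def step(x):
    # Placeholder "solver step": one fixed-point iteration for sqrt(2).
    return 0.5 * (x + 2.0 / x)


# Native Python loop: the body is unrolled 100 times into the traced program,
# so the jaxpr (and compile time) grows with the number of iterations.
@jax.jit
def solve_unrolled(x0):
    x = x0
    for _ in range(100):
        x = step(x)
    return x


# lax.while_loop: the body is traced exactly once, independent of the count.
@jax.jit
def solve_while(x0):
    def cond(carry):
        i, _ = carry
        return i < 100

    def body(carry):
        i, x = carry
        return i + 1, step(x)

    _, x = jax.lax.while_loop(cond, body, (0, x0))
    return x


x0 = jnp.asarray(1.0)
print(solve_unrolled(x0), solve_while(x0))
```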

pnkraemer commented 1 year ago

There is also #140, which might relate to your problem (depending on the number of derivatives in your prior and on how complicated your vector field is).

adam-hartshorne commented 1 year ago

> You mention something about selecting when to apply vmap. I would love to hear more!

I would love to say I have some hard-and-fast rules on this, but I don't. It is very much trial and error: if there is some flexibility in the ordering of function calls or in the level at which a vmap can be placed, and JITing is slow, I resort to trying the different options.

For example, I have found that from time to time you have a choice: you can either vmap an outer function, or pass a whole tensor X into the function so that JAX primitives act on the full array and only vmap a call to an inner function — and the two can behave differently. I haven't found a hard-and-fast rule for when; I think it mostly occurs when you are taking gradients / jvps / vjps in the inner function. I have also found from time to time that although vmap is supposed to "auto-magically" handle efficient vectorisation, vectorising manually can result in faster / more efficient code.
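Something like the following rough sketch of the two placements (purely illustrative; `inner` and `outer` are made-up stand-ins for the real computation):

```python
import jax
import jax.numpy as jnp


def inner(x):
    # Per-sample computation involving a gradient, where (per the comment
    # above) the placement of vmap seems to matter most.
    return jax.grad(lambda y: jnp.sum(jnp.sin(y) ** 2))(x)


def outer(x):
    return inner(x) + x


# Variant A: vmap the whole outer function over the batch axis.
batched_a = jax.vmap(outer)


# Variant B: pass the full batch in and vmap only the inner call;
# the surrounding primitives then act on the whole array.
def batched_b(xs):
    return jax.vmap(inner)(xs) + xs


xs = jnp.ones((64, 3))
# Both compute the same result; compile time (and sometimes runtime) can differ.
a = jax.jit(batched_a)(xs)
b = jax.jit(batched_b)(xs)
```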

My guess is that in some cases, as described above, the JIT compiler fails to see some efficiency and then takes a long time to build a more complicated representation, perhaps making unnecessary copies of elements that aren't required. The problem is that the more complex your model gets, the harder it is to look at the JIT build and really work out what is going on in there.

pnkraemer commented 1 year ago

I see; thanks for elaborating! Sounds like there are some (potentially hard-to-find) performance gains up for grabs :)

Let's take a closer look: I assume you're computing some value_and_grad of a fixed-step solver like in the parameter estimation notebooks. Is that correct? Which solver are you using? Knowing where to look will simplify investigating potential compilation-time improvements.

A minimal example would be fantastic, but knowing the solver configuration would already be a decent starting point (so I/we can check jaxpr etc.).
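For concreteness, the kind of setup I have in mind looks roughly like this (the solve here is a stand-in scan, not probdiffeq's actual API):

```python
import jax
import jax.numpy as jnp


def solve_fixed_step(params, y0):
    # Stand-in for a fixed-step solve; in practice this would be the solution
    # routine from the parameter-estimation notebooks.
    def body(y, _):
        y = y + 0.01 * params * y  # placeholder "step"
        return y, y

    y_final, _ = jax.lax.scan(body, y0, xs=None, length=100)
    return y_final


def loss(params, y0, target):
    return jnp.sum((solve_fixed_step(params, y0) - target) ** 2)


# Jitted value-and-gradient of the loss with respect to the parameters.
loss_and_grad = jax.jit(jax.value_and_grad(loss))
value, grad = loss_and_grad(jnp.array(0.5), jnp.ones(2), jnp.zeros(2))
```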

adam-hartshorne commented 1 year ago

> There is also https://github.com/pnkraemer/probdiffeq/issues/140, which might relate to your problem (depending on the number of derivatives in your prior and on how complicated your vector field is).

This appears to be the big factor. Every additional derivative adds significant time to the JITing process, such that going from 1 derivative to 4 triples the time to "compile".

I will do a bit of further experimentation, set up a minimal example (or two) with some comments on differences in performance, and email them directly to you in a couple of days' time.
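For what it's worth, one way to time just the compile step in isolation (generic JAX, nothing probdiffeq-specific; `fn` is a placeholder for the jitted solve) is via ahead-of-time lowering:

```python
import time

import jax
import jax.numpy as jnp


def fn(x):
    return jnp.sin(x) ** 2  # placeholder for the function being jitted


x = jnp.ones(3)

t0 = time.perf_counter()
compiled = jax.jit(fn).lower(x).compile()  # lowering + compilation, no execution
t1 = time.perf_counter()
print(f"compile time: {t1 - t0:.2f}s")

# Executing the compiled function afterwards measures pure runtime.
compiled(x).block_until_ready()
```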

pnkraemer commented 1 year ago

Sounds great! Looking forward to your email.

In the meantime, if you need <= 5 derivatives and the issues in #140 are a blocker, try the Runge-Kutta starter (in probdiffeq.taylor) instead of taylor_mode_fn. This initialisation scheme does not suffer from the problem in #140. The choice between such functions is an argument to the solution routines, and for a small number of derivatives, the Runge-Kutta starter works just as well.

adam-hartshorne commented 1 year ago

Wondering if you could make use of a similar trick to the one described here.

Improve compilation speed with scan-over-layers https://docs.kidger.site/equinox/tricks/

Obviously, the Taylor series grows in length each time around the loop, and jax.lax.scan doesn't allow returns of differing sizes. Perhaps padding could be used to handle this? Something like the sketch below.
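Very roughly, the kind of padding I have in mind (purely illustrative; `update_fn` stands in for the actual Taylor-coefficient recursion):

```python
import jax
import jax.numpy as jnp


def padded_scan_sketch(update_fn, first_coeff, num):
    # Fixed-shape buffer of num + 1 coefficients, zero-padded up front,
    # so the carry never changes shape between scan iterations.
    coeffs = jnp.zeros((num + 1,) + first_coeff.shape).at[0].set(first_coeff)

    def body(coeffs, i):
        # Compute the next coefficient from the ones filled in so far and
        # write it into its slot.
        new = update_fn(coeffs, i)
        return coeffs.at[i + 1].set(new), None

    coeffs, _ = jax.lax.scan(body, coeffs, jnp.arange(num))
    return coeffs


# Toy usage: each "coefficient" is just the previous one plus one.
out = padded_scan_sketch(lambda cs, i: cs[i] + 1.0, jnp.zeros(()), num=4)
```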

pnkraemer commented 1 year ago

Unfortunately, the trick discussed here does not apply (the problem is a different one, as you write).

But it turns out that there is a way of padding the Taylor coefficients cleverly such that the for loop can be avoided. This has (already) been implemented by #474.

Could you please install the latest version from GitHub and check whether the compilation time still bugs you?

adam-hartshorne commented 1 year ago

Winner.....compilation time is less than half what it was previously on a simple example, using number of derivatives = 4.

pnkraemer commented 1 year ago

Fantastic! Should we close this issue then? Do you consider your problem to be resolved?

adam-hartshorne commented 1 year ago

Just a thought for a minor improvement in taylor_mode_fn in taylor.py. For low-order expansions (order 0 or 1), it could handle them and return before getting to the scan at all. Something like this.

primals = vf(*initial_values)
if order == 0:
    return [primals]
# First-order coefficient: differentiate the vector field along the zeroth one.
u_1 = jax.jvp(vf, initial_values, (primals,))[1]
# In the case it is first order, just stop here
if order == 1:
    return [primals, u_1]
....
pnkraemer commented 1 year ago

Yes, this could be done (I had this in a prior code version but removed it at some point). But order == 0 (I suppose you mean num == 0) is really nothing to optimise for because, in this case, the function does nothing. order == 1 is indeed a potential early exit, but I am unsure whether it makes any difference because the function scans over a length-0 array (so the scan doesn't do anything). Ultimately, I decided against early exits because leaving them out reduces the number of branches the function can take (each of which would need separate test cases).

But I am open to suggestions. Which kind of improvement do you have in mind with this early exit?

adam-hartshorne commented 1 year ago

No, that's fine. I was thinking that avoiding the need to define the body_fn at all (or even managing to avoid the Partial on the vector_field) might save a tiny little bit when order == 1.

Now that I think about it some more, I am actually not 100% sure that if you "early exit", all of that "future" code is excluded; the JIT compiler has to know you will never get to it. It might be excluded, given the partial decorator of the function including the static-arg definition for the order of expansion the way you have it, but I'm not positive.