ben-cassese / squishyplanet

Transits of non-spherical exoplanets
https://squishyplanet.readthedocs.io/en/latest/
MIT License

Speed up the gradients #24

Closed ben-cassese closed 3 weeks ago

ben-cassese commented 3 weeks ago

Been putting this off...

ben-cassese commented 3 weeks ago

Months ago I was stuck in a frustrating spot: the model (light curve and log-likelihood) evaluated quickly, but the gradient with respect to the inputs took too long to compile for practical use. That was fine, since I was mostly using nested sampling routines that don't need gradients, so I didn't dig into it. Somewhere along the way something changed, and jax.jacfwd now compiles and evaluates quickly. jax.jacrev, however, still takes several minutes to trace for models with ~500 observations and often returns NaNs after that.
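For anyone who wants to check the two modes against each other, here is a minimal sketch. The Gaussian-dip `toy_light_curve` and its `depth`/`width` parameters are hypothetical stand-ins, not the squishyplanet model, and on a function this simple both modes behave well. The point is only the shape of the comparison.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the real model: 500 observation times
# and a Gaussian-shaped dip instead of an actual transit.
times = jnp.linspace(-1.0, 1.0, 500)

def toy_light_curve(params):
    depth, width = params[0], params[1]
    return 1.0 - depth * jnp.exp(-0.5 * (times / width) ** 2)

params = jnp.array([0.01, 0.2])

# Forward mode: one pass per input parameter.
jac_fwd = jax.jacfwd(toy_light_curve)(params)
# Reverse mode: one pass per output (observation).
jac_rev = jax.jacrev(toy_light_curve)(params)

# Both Jacobians have shape (n_observations, n_parameters).
fwd_has_nan = bool(jnp.any(jnp.isnan(jac_fwd)))
rev_has_nan = bool(jnp.any(jnp.isnan(jac_rev)))
modes_agree = bool(jnp.allclose(jac_fwd, jac_rev, rtol=1e-4, atol=1e-6))
```

With only a handful of parameters and ~500 outputs, forward mode needs far fewer passes to build the full Jacobian than reverse mode does, though that by itself would not explain NaNs.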

Since an eventual goal is to use packages like BlackJax that rely on jax.grad (which uses reverse mode), I have now defined @jax.custom_vjp rules that just call jax.jacfwd. I feel like there's a better way to do this, and I'm still not sure what's making reverse mode so unhappy, but I'm at least excited to have access to the gradients now. It's weird enough, though, that I wanted to leave a note somewhere outside the comments.