ben-cassese closed this 3 weeks ago
Months ago I was stuck in a frustrating spot where the model (light curve and loglike) evaluated quickly, but the gradient with respect to the inputs took too long to compile to be practical. That was fine, since I was mostly using nested sampling routines that don't need gradients, so I didn't dig into it. But somewhere along the way something changed, and now `jax.jacfwd` compiles and evaluates quickly. `jax.jacrev`, however, still takes several minutes to trace for models with ~500 observations and often returns NaNs after that.
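For reference, here is a minimal sketch showing that for a scalar log-likelihood both modes should return the same gradient. The sinusoidal `model`, `loglike`, and the 500-point time grid are hypothetical stand-ins, not the actual light curve. Forward mode (`jax.jacfwd`) costs one JVP per input parameter, while reverse mode (`jax.jacrev`, which `jax.grad` also uses) costs one VJP per output. With a scalar output and only a handful of parameters, the forward-mode overhead stays small.

```python
import jax
import jax.numpy as jnp

# Use double precision so forward- and reverse-mode results agree tightly.
jax.config.update("jax_enable_x64", True)

# Hypothetical toy data: a sinusoid standing in for the real light curve.
t = jnp.linspace(0.0, 10.0, 500)
y_obs = 1.0 + 0.3 * jnp.sin(2.0 * t)
sigma = 0.01


def model(params):
    # params = (offset, amplitude, frequency); hypothetical parameterization
    offset, amp, freq = params
    return offset + amp * jnp.sin(freq * t)


def loglike(params):
    # Gaussian log-likelihood (up to a constant) with known noise sigma
    resid = (y_obs - model(params)) / sigma
    return -0.5 * jnp.sum(resid**2)


params = jnp.array([1.01, 0.29, 2.02])

# Forward mode: one JVP per parameter (3 here)
grad_fwd = jax.jit(jax.jacfwd(loglike))(params)
# Reverse mode: one VJP for the single scalar output
grad_rev = jax.jit(jax.jacrev(loglike))(params)
```

On a well-behaved model like this one, the two gradients should match to floating-point precision. A discrepancy, such as NaNs appearing only in reverse mode, points to an operation whose VJP misbehaves.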
Since an eventual goal is to use packages like BlackJax that rely on `jax.grad` (which uses reverse mode), I have now defined `@jax.custom_vjp` functions that just call `jax.jacfwd`. I feel like there's a better way to do this, and I'm still not sure what's making reverse mode so unhappy, but I'm at least excited to have access to the gradients now. It's weird enough, though, that I wanted to leave a note somewhere outside the comments.
Been putting this off...