pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

Adjoint method to find the gradient of the Laplace approximation/mode #343

Open theorashid opened 4 months ago

theorashid commented 4 months ago

This is part of INLA roadmap #340.

From the Stan paper:

> One of the main bottlenecks is differentiating the estimated mode, $\theta^*$. In theory, it is straightforward to apply automatic differentiation, by brute-force propagating derivatives through $\theta^*$, that is, sequentially differentiating the iterations of a numerical optimizer. But this approach, termed the direct method, is prohibitively expensive. A much faster alternative is to use the implicit function theorem. Given any accurate numerical solver, we can always use the implicit function theorem to get derivatives. One side effect is that the numerical optimizer is treated as a black box. By contrast, Rasmussen and Williams [34] define a bespoke Newton method to compute $\theta^*$, meaning we can store relevant variables from the final Newton step when computing derivatives. In our experience, this leads to important computational savings. But overall this method is much less flexible, working well only when the number of hyperparameters is low dimensional and requiring the user to pass the tensor of derivatives.

I think the JAX implementation uses the tensor of derivatives, but I'm not 100% sure.
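
To make the quote concrete, here is a minimal sketch (not the Stan or PyMC implementation) of the implicit-function-theorem/adjoint idea in plain JAX with `jax.custom_vjp`. The objective `loss`, the inner solver `newton_solve`, and the wrapper `find_mode` are hypothetical names for illustration; the point is that the backward rule only uses the Hessian at the mode and a VJP of the inner gradient, not the unrolled optimizer iterations.

```python
import jax
import jax.numpy as jnp

def loss(theta, phi):
    """Toy inner objective (stand-in for the negative log conditional of the latent field)."""
    return jnp.sum((theta - phi) ** 2) + 0.1 * jnp.sum(theta ** 4)

def newton_solve(phi, theta0, n_steps=20):
    """Inner optimiser; the gradient rule below treats it as a black box."""
    def step(theta, _):
        g = jax.grad(loss, argnums=0)(theta, phi)
        H = jax.hessian(loss, argnums=0)(theta, phi)
        return theta - jnp.linalg.solve(H, g), None
    theta_star, _ = jax.lax.scan(step, theta0, None, length=n_steps)
    return theta_star

@jax.custom_vjp
def find_mode(phi, theta0):
    return newton_solve(phi, theta0)

def find_mode_fwd(phi, theta0):
    theta_star = newton_solve(phi, theta0)
    return theta_star, (theta_star, phi)

def find_mode_bwd(residuals, g):
    # Implicit function theorem: grad_theta loss(theta*, phi) = 0 at the mode, so
    # d theta*/d phi = -H^{-1} d^2 loss/(d theta d phi). The adjoint form solves
    # H lam = g once, then pushes -lam through the mixed second derivative as a VJP,
    # never materialising the Jacobian of theta* with respect to phi.
    theta_star, phi = residuals
    H = jax.hessian(loss, argnums=0)(theta_star, phi)
    lam = jnp.linalg.solve(H, g)
    _, grad_theta_vjp = jax.vjp(lambda p: jax.grad(loss, argnums=0)(theta_star, p), phi)
    (phi_bar,) = grad_theta_vjp(-lam)
    return phi_bar, jnp.zeros_like(theta_star)  # no gradient through the initial guess

find_mode.defvjp(find_mode_fwd, find_mode_bwd)

phi = jnp.array([0.5, -1.0])
print(jax.jacrev(find_mode)(phi, jnp.zeros(2)))  # d theta*/d phi via the implicit rule
```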

theorashid commented 2 months ago

The rewrites in optimistix might be helpful here to understand what is going on. Also see the docs.
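
A rough sketch of why that library is relevant here, assuming its public API as shown in the optimistix README (`optx.minimise`, `optx.BFGS`, `Solution.value`); the objective and hyperparameter are made up for illustration. The solve is differentiated implicitly, so the outer `jax.jacrev` does not unroll the BFGS iterations.

```python
import jax
import jax.numpy as jnp
import optimistix as optx

def objective(theta, phi):
    # toy inner objective; phi plays the role of the hyperparameters
    return jnp.sum((theta - phi) ** 2) + 0.1 * jnp.sum(theta ** 4)

def mode(phi):
    # assumed API: solver with rtol/atol, minimise(fn, solver, y0, args=...)
    solver = optx.BFGS(rtol=1e-8, atol=1e-8)
    sol = optx.minimise(objective, solver, jnp.zeros(2), args=phi)
    return sol.value  # theta*(phi)

phi = jnp.array([0.5, -1.0])
print(mode(phi))
print(jax.jacrev(mode)(phi))  # d theta*/d phi, obtained implicitly rather than by unrolling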

theorashid commented 2 months ago

Some notes on how they use this in Stan.

theorashid commented 2 months ago

An example of this for the fixed-point optimiser in JAX. A minimal sketch of the same pattern is below.
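
For reference, core JAX exposes this directly through `jax.lax.custom_root`, which applies the implicit function theorem to parameters closed over by the residual function. The fixed-point map (`cos(phi * x)`) and the helper names below are illustrative, not taken from the linked example; the scalar `tangent_solve` follows the pattern in the JAX docs.

```python
import jax
import jax.numpy as jnp

def fixed_point_iteration(T, x0, n_steps=100):
    """Black-box inner loop: repeatedly apply x <- T(x)."""
    return jax.lax.fori_loop(0, n_steps, lambda _, x: T(x), x0)

def solve_fixed_point(phi, x0):
    """Fixed point of x = cos(phi * x); gradients w.r.t. phi come from the IFT, not the loop."""
    T = lambda x: jnp.cos(phi * x)           # fixed-point map, closes over phi
    residual = lambda x: T(x) - x            # its root is the fixed point
    solve = lambda _res, x_init: fixed_point_iteration(T, x_init)
    tangent_solve = lambda g, y: y / g(1.0)  # scalar solve of the linearised system
    return jax.lax.custom_root(residual, x0, solve, tangent_solve)

x_star = solve_fixed_point(0.9, 1.0)
dx_dphi = jax.grad(solve_fixed_point)(0.9, 1.0)  # differentiates via the implicit rule
print(x_star, dx_dphi)
```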