NNPDF / mcpdf


HMC #4

Open alecandido opened 2 years ago

alecandido commented 2 years ago

Here is the roadmap for the HMC implementation:

alecandido commented 2 years ago

This is just a note for myself (@AleCandido); the same considerations as in https://github.com/AleCandido/mcpdf/issues/2#issuecomment-1110670683 apply.

Gattocrucco commented 2 years ago

The problem with pymc3 is that we want to take second derivatives, and then pymc3 takes the gradient on top of that. pymc4 supports JAX as a backend, so maybe (I'm not sure) it can be done easily there; in pymc3 you would need to code all the derivatives manually, and redo them whenever you change the model (see slide 24 of my seminar on lsqfitgp).
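To make the issue concrete, here is a minimal sketch of what "coding the derivatives manually" means for a single kernel. It assumes, hypothetically, a squared-exponential kernel with length scale `ell`, and takes the mixed derivative ∂²k/∂x∂y as the "second derivative" of the model. A gradient-based sampler such as NUTS then also needs the derivative of that quantity with respect to the (log-transformed) hyperparameter, which has to be derived and coded by hand as well. The function names are illustrative, not taken from pymc3 or lsqfitgp.

```python
import math


def kernel(x, y, ell):
    """Hypothetical squared-exponential kernel k(x, y) = exp(-(x - y)^2 / (2 ell^2))."""
    r = x - y
    return math.exp(-r * r / (2 * ell * ell))


def kernel_dxdy(x, y, ell):
    """Mixed second derivative d^2 k / (dx dy), derived by hand."""
    r = x - y
    return (1 / ell**2 - r**2 / ell**4) * kernel(x, y, ell)


def kernel_dxdy_dlogell(x, y, ell):
    """Derivative of kernel_dxdy with respect to log(ell), also derived by hand.

    This is the extra derivative a gradient-based sampler needs on top of the
    second derivative already present in the model.
    """
    r = x - y
    k = kernel(x, y, ell)
    return k * (-2 / ell**2 + 5 * r**2 / ell**4 - r**4 / ell**6)


def central_diff(f, t, h=1e-5):
    """Central finite difference, used only to check the hand-written derivatives."""
    return (f(t + h) - f(t - h)) / (2 * h)
```

Every change of kernel means redoing both derivations (and their checks); with an autodiff backend such as JAX, nesting `jax.grad` calls produces these derivatives automatically.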

Currently, in the tests with lsqfitgp, I'm doing a Laplace approximation for the nonlinearities and the hyperparameters, after appropriately transforming the hyperparameters (log for positive ones, etc.). This often works well, in particular because we are not interested in the hyperparameters per se: we only use them to specify a flexible prior distribution for the PDFs, and we care about the predictive error, not about getting the tails of the hyperparameter posteriors right.
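A minimal sketch of a Laplace approximation in a transformed parameter, assuming a toy model rather than the actual PDF fit: data y_i ~ N(0, σ²) with a flat prior on θ = log σ. The mode of the log posterior ℓ(θ) is found with Newton's method using finite-difference derivatives, and the posterior is approximated by a Gaussian centred there with variance -1/ℓ''(θ*). For this toy model the exact answer is θ* = ½ log(Σy²/n) with variance 1/(2n), which makes the sketch easy to check. `laplace_1d` and `toy_log_posterior` are hypothetical helpers, not lsqfitgp functions.

```python
import math


def laplace_1d(log_post, theta0, h=1e-4, tol=1e-9, max_iter=100):
    """Laplace approximation of a 1D posterior in an unconstrained parameter.

    Returns (mode, variance), with variance = -1 / (second derivative at the mode).
    Derivatives are central finite differences with step h.
    """
    theta = theta0
    for _ in range(max_iter):
        f0, fp, fm = log_post(theta), log_post(theta + h), log_post(theta - h)
        grad = (fp - fm) / (2 * h)
        hess = (fp - 2 * f0 + fm) / h**2
        step = grad / hess  # Newton update theta <- theta - grad/hess (hess < 0 near a maximum)
        theta -= step
        if abs(step) < tol:
            break
    # Curvature at the final mode gives the Gaussian variance.
    f0, fp, fm = log_post(theta), log_post(theta + h), log_post(theta - h)
    hess = (fp - 2 * f0 + fm) / h**2
    return theta, -1.0 / hess


def toy_log_posterior(ys):
    """Hypothetical toy model: y_i ~ N(0, sigma^2), flat prior on theta = log(sigma)."""
    n, s = len(ys), sum(v * v for v in ys)
    return lambda theta: -n * theta - 0.5 * s * math.exp(-2 * theta)
```

Since θ = log σ is Gaussian under the approximation, σ itself comes out log-normal and automatically positive, which is the point of transforming the hyperparameters first.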

Moreover, I've looked at the errors on the currently fitted PDFs and they are small, so overall I think the fit would work without MCMC.

If we end up really needing it, we could first test the fit with lsqfitgp, where it's easy to change the model, and then hardcode everything in pymc3 once we are sure which kernels we want to use. Another option is to use JAX-based beta software (numpyro, pymc4, tinygp) and do some parts on our own without computing all the derivatives by hand; but considering that lsqfitgp is written with autograd, I could just as well port lsqfitgp to JAX and then plug its marginal likelihood into any NUTS implementation.
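A sketch of the kind of function that would be handed to a NUTS implementation: the log marginal likelihood of a zero-mean GP as a function of log-transformed hyperparameters. It assumes a squared-exponential kernel plus white noise, and `log_marginal_likelihood` is a hypothetical name, not the lsqfitgp API. It is written with NumPy for readability; replacing `numpy` with `jax.numpy` should make it differentiable with `jax.grad`, which is the gradient the sampler needs.

```python
import numpy as np


def log_marginal_likelihood(log_params, x, y):
    """log p(y | hyperparameters) for a zero-mean GP.

    log_params = (log amplitude, log length scale, log noise sd); working in
    log space makes every hyperparameter unconstrained, as a sampler prefers.
    """
    log_amp, log_ell, log_noise = log_params
    amp, ell, noise = np.exp(log_amp), np.exp(log_ell), np.exp(log_noise)
    r = x[:, None] - x[None, :]
    # Squared-exponential covariance plus white noise on the diagonal.
    K = amp**2 * np.exp(-0.5 * (r / ell) ** 2) + noise**2 * np.eye(len(x))
    L = np.linalg.cholesky(K)  # K = L L^T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # alpha = K^{-1} y
    # -1/2 y^T K^{-1} y - 1/2 log|K| - n/2 log(2 pi), with log|K| = 2 sum(log diag L)
    return (-0.5 * y @ alpha
            - np.sum(np.log(np.diag(L)))
            - 0.5 * len(x) * np.log(2 * np.pi))
```

Adding a log prior on the hyperparameters gives the unnormalized log posterior that the sampler would target.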

alecandido commented 2 years ago

OK, the idea is to have a NUTS-based implementation; I don't have any preference about the library providing it.

I guess the proof of concept is worth the effort, even if we end up choosing a pure lsqfitgp implementation (meaning kriging, not HMC). I had a look at numpyro, as you know, but it looked bloated to me. Likewise, I'm aware that pymc3 can be bloated too, as your slides point out.
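For comparison, "kriging" here amounts to conditioning the GP on the data at fixed hyperparameters (e.g. the ones found by the Laplace approximation) instead of sampling them. A minimal sketch of the resulting posterior mean and covariance, assuming a hypothetical squared-exponential kernel and illustrative default values for `amp`, `ell` and `noise`:

```python
import numpy as np


def se_kernel(a, b, amp, ell):
    """Hypothetical squared-exponential kernel matrix between point sets a and b."""
    r = a[:, None] - b[None, :]
    return amp**2 * np.exp(-0.5 * (r / ell) ** 2)


def krige(x_train, y_train, x_test, amp=1.0, ell=0.3, noise=1e-2):
    """GP posterior mean and covariance at x_test, hyperparameters held fixed."""
    K = se_kernel(x_train, x_train, amp, ell) + noise**2 * np.eye(len(x_train))
    Ks = se_kernel(x_test, x_train, amp, ell)
    Kss = se_kernel(x_test, x_test, amp, ell)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))  # K^{-1} y
    mean = Ks @ alpha
    V = np.linalg.solve(L, Ks.T)  # so that V^T V = Ks K^{-1} Ks^T
    cov = Kss - V.T @ V
    return mean, cov
```

The cost is a single Cholesky factorization per set of hyperparameters, versus one per leapfrog step in HMC, which is why the Laplace-plus-kriging route is so much cheaper when it is accurate enough.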

I'll take on the burden of providing one implementation or the other, but if you can help me, I'd gladly accept your insights (and practical help too).

Furthermore, I had a look at the implementation of NUTS in pymc3, and it's rather concise:

Thus, even an implementation from scratch should not be hard (although I'd first study it and the abstract algorithm, and then cook up my own implementation in my own optimized/favorite way). If we can plug it into lsqfitgp, so much the better.
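As a possible starting point for a from-scratch implementation, here is a minimal sketch of plain HMC with a leapfrog integrator and a Metropolis accept step. It is not pymc3's code and it is not NUTS: NUTS builds on the same leapfrog step but chooses the trajectory length adaptively (stopping when the trajectory makes a U-turn) and usually adapts the step size as well, whereas here `step` and `n_steps` are fixed, hand-picked assumptions.

```python
import numpy as np


def leapfrog(q, p, grad_logp, step, n_steps):
    """Integrate Hamilton's equations for H = -log p(q) + |p|^2 / 2."""
    p = p + 0.5 * step * grad_logp(q)  # initial half step in momentum
    for i in range(n_steps):
        q = q + step * p  # full step in position
        if i < n_steps - 1:
            p = p + step * grad_logp(q)  # full momentum steps in between
    p = p + 0.5 * step * grad_logp(q)  # final half step in momentum
    return q, p


def hmc(logp, grad_logp, q0, n_samples, step=0.1, n_steps=20, seed=0):
    """Plain HMC with fixed step size and trajectory length.

    Returns (samples with shape (n_samples, dim), acceptance rate).
    """
    rng = np.random.default_rng(seed)
    q = np.asarray(q0, dtype=float)
    samples = np.empty((n_samples, q.size))
    accepted = 0
    for s in range(n_samples):
        p0 = rng.standard_normal(q.size)  # fresh Gaussian momentum
        q_new, p_new = leapfrog(q, p0, grad_logp, step, n_steps)
        h_old = -logp(q) + 0.5 * p0 @ p0
        h_new = -logp(q_new) + 0.5 * p_new @ p_new
        # Metropolis correction for the integrator's energy error.
        if np.log(rng.uniform()) < h_old - h_new:
            q = q_new
            accepted += 1
        samples[s] = q
    return samples, accepted / n_samples
```

For the GP case, `logp` would be the log posterior of the (log-transformed) hyperparameters and `grad_logp` its gradient, e.g. obtained with `jax.grad`.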