Open junpenglao opened 1 year ago
In #392 we define a VIAlgorithm, and here we would need to define a new base type, ParametrizedVIAlgorithm.
@rlouf I think I can take charge of this, but just to be sure: we assume the log_prob_fn (i.e. log p(x, z)) in BlackJax takes a real-valued flattened array (rather than a dict or something on a constrained space), right?
Great! No, there is no such assumption in the library (or at least there shouldn't be); we try to support PyTree states as much as we can.
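To make the PyTree point concrete, here is a small sketch in plain JAX. The toy model, the key names `mu` and `log_sigma`, and `flat_logdensity_fn` are made up for illustration and are not part of Blackjax. It shows a log-density that takes a dict, and how an algorithm that wants a flat vector internally could flatten on its own side with `ravel_pytree` instead of asking the user to do it.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


# Hypothetical toy model: the position is a dict PyTree, not a flat array.
def logdensity_fn(position):
    mu = position["mu"]
    log_sigma = position["log_sigma"]
    # Standard normal log-density (up to a constant) on every leaf.
    return -0.5 * jnp.sum(mu**2) - 0.5 * jnp.sum(log_sigma**2)


position = {"mu": jnp.array([1.0, -2.0]), "log_sigma": jnp.array(0.5)}

# Gradients come back with the same PyTree structure as the input.
grads = jax.grad(logdensity_fn)(position)

# An algorithm that needs a flat vector can do the flattening itself.
flat_position, unravel_fn = ravel_pytree(position)


def flat_logdensity_fn(flat):
    # Rebuild the dict from the flat vector before calling the user's function.
    return logdensity_fn(unravel_fn(flat))
```

With this pattern the user-facing function keeps its natural structure, and the flat view stays an implementation detail.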
@rlouf To follow the design principles of Blackjax, I believe VI should also have an API of the form below?
new_state, info = kernel(rng_key, state)
which would perform one optimization step for the ELBO.
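For illustration, a kernel of that shape could look roughly like the sketch below. Everything here is an assumption rather than Blackjax code: `MFVIState`, `MFVIInfo`, `build_kernel`, the toy target, and the choice of plain gradient descent on a reparameterized mean-field Gaussian ELBO.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MFVIState(NamedTuple):  # hypothetical state container
    mu: jnp.ndarray
    rho: jnp.ndarray  # log of the standard deviation


class MFVIInfo(NamedTuple):  # hypothetical info container
    elbo: jnp.ndarray


def logdensity_fn(x):
    # Toy target: independent normals with mean 3 and standard deviation 0.5.
    return jnp.sum(-0.5 * ((x - 3.0) / 0.5) ** 2)


def build_kernel(logdensity_fn, num_samples=16, step_size=0.05):
    def negative_elbo(params, rng_key):
        mu, rho = params
        eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
        z = mu + jnp.exp(rho) * eps  # reparameterization trick
        expected_logp = jnp.mean(jax.vmap(logdensity_fn)(z))
        entropy = jnp.sum(rho)  # Gaussian entropy up to an additive constant
        return -(expected_logp + entropy)

    def kernel(rng_key, state):
        # One stochastic gradient step on the negative ELBO.
        loss, grads = jax.value_and_grad(negative_elbo)((state.mu, state.rho), rng_key)
        new_state = MFVIState(
            mu=state.mu - step_size * grads[0],
            rho=state.rho - step_size * grads[1],
        )
        return new_state, MFVIInfo(elbo=-loss)

    return kernel


kernel = jax.jit(build_kernel(logdensity_fn))
state = MFVIState(mu=jnp.zeros(2), rho=jnp.zeros(2))
rng_key = jax.random.PRNGKey(0)
for _ in range(500):
    rng_key, step_key = jax.random.split(rng_key)
    state, info = kernel(step_key, state)
```

The ELBO reported in `info` omits the constant part of the Gaussian entropy, so it is only meaningful for monitoring convergence, not as an absolute value.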
As you can see with the Pathfinder implementation, Blackjax treats VI differently from MCMC algorithms.
The idea is that you first fit an approximation to the target density, and then sample from this approximation with something like (in pseudo-code):
approx, info = approximate(rng_key, position)
samples = sample(sample_key, approx, num_samples)
I think at the higher level the API will always be something more or less like this. We can consider a kernel-like lower-level interface for some algorithms if it makes sense. But again, I am no VI expert and open to suggestions.
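A runnable toy version of this two-stage pattern might look like the following. This is a sketch under stated assumptions, not the Blackjax or Pathfinder API: the approximation is a mean-field Gaussian fitted by gradient descent inside `jax.lax.scan`, `approximate` takes the log-density as an extra argument, and the returned "info" is just the trace of ELBO estimates.

```python
import jax
import jax.numpy as jnp


def approximate(rng_key, logdensity_fn, position, num_steps=1000, num_mc=16, step_size=0.02):
    """Fit a mean-field Gaussian to the target; returns ((mu, rho), elbo_trace)."""

    def negative_elbo(params, key):
        mu, rho = params
        eps = jax.random.normal(key, (num_mc,) + mu.shape)
        z = mu + jnp.exp(rho) * eps  # reparameterized draws from the approximation
        # Monte Carlo ELBO, with the Gaussian entropy up to an additive constant.
        return -(jnp.mean(jax.vmap(logdensity_fn)(z)) + jnp.sum(rho))

    def one_step(params, key):
        loss, grads = jax.value_and_grad(negative_elbo)(params, key)
        params = jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)
        return params, -loss

    init = (position, jnp.zeros_like(position))  # start with unit standard deviations
    keys = jax.random.split(rng_key, num_steps)
    return jax.lax.scan(one_step, init, keys)


def sample(rng_key, approx, num_samples):
    """Draw from the fitted mean-field Gaussian."""
    mu, rho = approx
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    return mu + jnp.exp(rho) * eps


def logdensity_fn(x):
    # Toy target: independent normals with means (1, -1) and standard deviations (0.5, 2).
    return jnp.sum(-0.5 * ((x - jnp.array([1.0, -1.0])) / jnp.array([0.5, 2.0])) ** 2)


rng_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
approx, elbo_trace = approximate(rng_key, logdensity_fn, jnp.zeros(2))
samples = sample(sample_key, approx, 1000)
```

Keeping `approximate` and `sample` separate means the fitted approximation can be reused to draw any number of samples without re-running the optimization.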
Can someone then give a minimal working example for mean-field VI? This would also be helpful for the refactoring of the Pathfinder API in #465 and the implementation of the full-rank approach, I believe.
MFVI is implemented here and full-rank is being implemented in https://github.com/blackjax-devs/blackjax/pull/479. The refactoring of Pathfinder is a bit involved, but up for grabs :)
I understand. Although I would argue that it would be helpful to get a foot in the door if one wants to help develop VI further. It could be as easy as having a multivariate normal and evaluating mean-field and full-rank. For the current implementation, I don't immediately see how the pseudo-code you provided is implemented in the library.
@LarsKarbach I understand your point and I really wish there could be a template for implementing VI variants (e.g. as simple as providing a log_q function and a sampling function), but the APIs are still at a very early stage. At the moment, there is still a lot of boilerplate code in the implementation... Probably after the full-rank VI PR is merged, we could start working on simplifying the VI implementation process.
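To give an idea of what such a template could look like, here is a rough sketch. It is purely hypothetical and not an existing Blackjax interface: `make_elbo`, `mf_sample`, and `mf_log_q` are names invented here. A variant supplies a reparameterized sampling function and a log_q function, and a generic Monte Carlo ELBO is built from them.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def make_elbo(logdensity_fn, sample_fn, log_q_fn, num_mc=32):
    """Build a Monte Carlo ELBO estimator from a family's sampler and log-density."""

    def elbo(params, rng_key):
        z = sample_fn(rng_key, params, num_mc)  # shape (num_mc, dim)
        log_p = jax.vmap(logdensity_fn)(z)
        log_q = jax.vmap(lambda zi: log_q_fn(params, zi))(z)
        return jnp.mean(log_p - log_q)

    return elbo


# A mean-field Gaussian family expressed through the two required functions.
def mf_sample(rng_key, params, num_samples):
    mu, rho = params
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    return mu + jnp.exp(rho) * eps  # reparameterized, so gradients flow to params


def mf_log_q(params, z):
    mu, rho = params
    return jnp.sum(norm.logpdf(z, mu, jnp.exp(rho)))


def target_logdensity(x):
    # Toy target: a normalized standard normal.
    return jnp.sum(norm.logpdf(x))


elbo = make_elbo(target_logdensity, mf_sample, mf_log_q, num_mc=256)
key = jax.random.PRNGKey(0)
elbo_exact = elbo((jnp.zeros(2), jnp.zeros(2)), key)  # q equals the target
elbo_shifted = elbo((jnp.ones(2), jnp.zeros(2)), key)  # q has the wrong mean
grads = jax.grad(elbo)((jnp.ones(2), jnp.zeros(2)), key)
```

When q matches the normalized target the estimate is zero for every draw, and a mismatched q gives a negative value, which is a handy sanity check. A full-rank family would only need its own pair of sampling and log_q functions.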
Copying over from https://github.com/blackjax-devs/blackjax/pull/392#discussion_r1020745315
After #392, we should add the two most basic VI algorithms: mean-field and full-rank ADVI [1]. Below is a working example of mean-field ADVI:
Fitting a model looks like:
[1] https://arxiv.org/abs/1603.00788