blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Integration with Inference Gym #634

Closed reubenharry closed 4 months ago

reubenharry commented 8 months ago

Since Blackjax is a repository of inference algorithms with a fairly uniform interface, it would be nice if there were a fairly automated procedure for benchmarking and comparing the different algorithms, using the models from Inference Gym.

What I'm envisioning is a helper function that takes a set of SamplingAlgorithms, runs each on all of the Inference Gym problems, and then reports some useful metrics, e.g. bias vs. wallclock time (or number of gradient calls), plotted in a graph, so that it's as easy as possible to assess the performance of a given sampling algorithm.
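A minimal sketch of what such a helper could look like (the names `run_benchmark`, the dict-based problem spec, and the max-absolute-error bias metric are all illustrative assumptions, not an existing API):

```python
import time
from typing import Callable, Dict

import numpy as np

def run_benchmark(
    algorithms: Dict[str, Callable],
    problems: Dict[str, dict],
) -> dict:
    """For each (algorithm, problem) pair, draw samples and record a bias
    metric alongside wallclock time. Each algorithm is a callable taking a
    target log-density and returning an array of samples; each problem
    supplies a log-density and a ground-truth mean."""
    results = {}
    for alg_name, sample_fn in algorithms.items():
        for prob_name, problem in problems.items():
            start = time.perf_counter()
            samples = sample_fn(problem["logdensity"])
            elapsed = time.perf_counter() - start
            # Bias: worst-case absolute error of the sample mean.
            bias = float(np.abs(samples.mean(axis=0) - problem["true_mean"]).max())
            results[(alg_name, prob_name)] = {"bias": bias, "seconds": elapsed}
    return results

# Toy usage: a "sampler" that ignores the target and draws iid normals,
# benchmarked against a standard-normal problem.
rng = np.random.default_rng(0)
toy_algorithms = {"iid": lambda logdensity: rng.normal(size=(1000, 2))}
toy_problems = {
    "std_normal": {
        "logdensity": lambda x: -0.5 * np.sum(x**2),
        "true_mean": np.zeros(2),
    }
}
results = run_benchmark(toy_algorithms, toy_problems)
```

The real version would swap the toy sampler for BlackJAX SamplingAlgorithms and the toy problem for Inference Gym targets, and could track gradient calls instead of wallclock time.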

Does something like this sound of interest? Has it already been done? If yes to the former, and no to the latter, it's something I'd be interested in contributing (a similar bit of code is in https://github.com/JakobRobnik/MicroCanonicalHMC/tree/master and could be ported over).

Why have this in Blackjax

Follow-up questions

Is Inference Gym available in JAX? https://github.com/JakobRobnik/MicroCanonicalHMC manually ports it, but that port wouldn't stay up to date if inference-gym changed.

ColCarroll commented 8 months ago

Note you should be able to use

from inference_gym import using_jax as gym

via this test

junpenglao commented 8 months ago

+1 to the suggestion. I think we should set up a new repository for that.

reubenharry commented 8 months ago

OK, I'll look into that. For now I might proceed by working in this repo, and then we can discuss splitting it out down the road based on the PR.

Basically what I'm thinking of is porting @JakobRobnik's code (https://github.com/JakobRobnik/MicroCanonicalHMC/blob/master/benchmarks/error.py) and writing a function `assess_sampling_algorithm` that reports ESS per sample across the different models. Nothing too fancy.

One reason to do this is that currently in Blackjax, it's not totally clear where to look to find the optimal strategy for both tuning and running a given algorithm. So for example, it would be nice for this to serve as a place to see how to run each algorithm with the best possible tuning.

ColCarroll commented 8 months ago

So for example, it would be nice for this to serve as a place to see how to run each algorithm with the best possible tuning

For what it is worth, I'm very interested in this for https://github.com/jax-ml/bayeux, which follows some heuristics, but would welcome specifics.

reubenharry commented 8 months ago

OK, having investigated a little more, some notes (mainly for myself). What I want is a function of type InferenceGymModel -> ArrayOfSamples, i.e. a function that takes an Inference Gym model and produces samples (of the appropriate dimensions, although Python's type system is too puny to enforce this), with the recommended adaptation, preconditioning, etc.

Then I want each of the sampling algorithms that is going to be benchmarked in Blackjax to have a corresponding function of this type, in an example_usage directory or similar.

Then I'll have a function (InferenceGymModel, ArrayOfSamples) -> ESS_Estimate (already written), and it will be straightforward to compute ESS for each pair of inference method and model, which I will plot in a pretty graph (ideally run by CI).
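The ESS step can be sketched with the standard autocorrelation-based estimator, n / (1 + 2·Σρₖ), truncating the sum at the first non-positive autocorrelation (a simplified stand-in; the already-written code in error.py may compute this differently):

```python
import numpy as np

def ess_per_sample(chain: np.ndarray) -> float:
    """Rough effective sample size per sample for a 1-D chain."""
    x = chain - chain.mean()
    n = len(x)
    # Empirical autocorrelation at lags 0..n-1 (acf[0] == 1).
    acf = np.correlate(x, x, mode="full")[n - 1:] / (np.arange(n, 0, -1) * x.var())
    # Integrated autocorrelation time, truncated at the first non-positive lag.
    tau = 1.0
    for rho in acf[1:]:
        if rho <= 0:
            break
        tau += 2.0 * rho
    return min(1.0, 1.0 / tau)

rng = np.random.default_rng(0)
iid = rng.normal(size=5000)  # iid draws: ESS per sample near 1
# Strongly autocorrelated AR(1) chain: ESS per sample far below 1.
ar = np.empty(5000)
ar[0] = 0.0
for t in range(1, 5000):
    ar[t] = 0.9 * ar[t - 1] + rng.normal()
```

For multivariate samples one would apply this per dimension (or use a library estimator such as ArviZ's `ess`), and the benchmark would then tabulate the result for each (method, model) pair.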

It's possible this overlaps a little with Bayeux, in which case maybe there's some common ground to exploit.

reubenharry commented 8 months ago

Some todos as Divij and I work on this: