blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0
821 stars, 106 forks

Constraints via the integrator #742

Open reubenharry opened 1 month ago

reubenharry commented 1 month ago

Current behavior

BlackJAX doesn't currently have any built-in support for constrained parameters.

Stan (and NumPyro) handle constraints by transformation, mapping each constrained parameter to an unconstrained space and sampling there. One could extract the NumPyro transformations and reuse them.
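For reference, the transformation approach amounts to sampling an unconstrained variable and adding the log-Jacobian of the map back into the log density. Below is a minimal JAX sketch for a single positive parameter. It assumes an Exponential(1) target and the log transform that Stan uses for parameters bounded below by zero. The function names are illustrative only and are not BlackJAX or NumPyro API:

```python
import jax.numpy as jnp

# Illustrative target on the constrained space: Exponential(1), sigma > 0.
def logdensity_constrained(sigma):
    return -sigma

# Sample u in R instead, with sigma = exp(u). The change of variables
# adds log|d sigma / du| = u to the log density.
def logdensity_unconstrained(u):
    sigma = jnp.exp(u)
    return logdensity_constrained(sigma) + u
```

Any sampler can then run on `u`, and draws are mapped back with `jnp.exp(u)`. This works because a single smooth chart covers the positive reals, which is exactly what fails for the circle.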

However, for periodic or reflective boundary conditions, transformations are not always useful (for example, no single chart covers the whole circle), whereas it is relatively simple to enforce periodicity directly in the integrator.

Desired behavior

I would update the integrators to accept a hand-chosen addition function for the position update, defaulting to the usual (+) but also allowing modular addition.
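Here is a rough sketch of what such a hook could look like. It is written as a standalone leapfrog step, not against BlackJAX's actual integrator code. `position_update_fn` and `periodic_add` are hypothetical names. The sketch assumes an identity mass matrix and a target that lives on the circle (an unnormalized von Mises `kappa * cos(theta)`), so the gradient is itself periodic:

```python
import jax
import jax.numpy as jnp

def periodic_add(x, dx, low=-jnp.pi, high=jnp.pi):
    """Hypothetical modular addition: wraps x + dx back into [low, high)."""
    return low + jnp.mod(x + dx - low, high - low)

def leapfrog_step(position, momentum, logdensity_fn, step_size,
                  position_update_fn=jnp.add):
    """One leapfrog step with an identity mass matrix (illustrative sketch).

    `position_update_fn` is the hypothetical hook. It defaults to plain
    addition, so behavior is unchanged unless the caller overrides it.
    """
    grad_fn = jax.grad(logdensity_fn)
    # Half step for the momentum.
    momentum = momentum + 0.5 * step_size * grad_fn(position)
    # Full step for the position, using the chosen addition.
    position = position_update_fn(position, step_size * momentum)
    # Second half step for the momentum at the new position.
    momentum = momentum + 0.5 * step_size * grad_fn(position)
    return position, momentum

kappa = 2.0

def logdensity_fn(theta):
    # Unnormalized von Mises log density centered at 0.
    return jnp.sum(kappa * jnp.cos(theta))

theta, p = jnp.array(3.0), jnp.array(5.0)
theta_flat, p_flat = leapfrog_step(theta, p, logdensity_fn, 0.1)
theta_circ, p_circ = leapfrog_step(theta, p, logdensity_fn, 0.1, periodic_add)
```

Starting near `pi`, the default step moves the position past `pi`. The modular version lands on the equivalent angle inside `[-pi, pi)`. Because the target's gradient is periodic, both runs end with the same momentum.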

junpenglao commented 1 month ago

Hmm, introducing this might be a pretty big breaking change, especially since you need an API to input the constraint (as a pytree of functions that matches the structure of the position). We did that in TFP using a wrapper class (https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/mcmc/transformed_kernel.py), which is a bit difficult to use due to excessive nesting. Handling it through the integrator could potentially be nicer.
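To make the pytree idea concrete, the sketch below assigns one addition function to each leaf of the position and applies them with `jax.tree_util.tree_map`. Plain Python functions are pytree leaves, so a dict of functions can be mapped in lockstep with a dict of arrays. `make_periodic_add`, `default_add_fns` and `tree_position_update` are hypothetical helpers, not existing BlackJAX API:

```python
import operator

import jax
import jax.numpy as jnp

def make_periodic_add(low, high):
    """Hypothetical helper: an addition that wraps results into [low, high)."""
    def periodic_add(x, dx):
        return low + jnp.mod(x + dx - low, high - low)
    return periodic_add

def default_add_fns(position):
    """Plain (+) for every leaf, mirroring the current behavior."""
    return jax.tree_util.tree_map(lambda _: operator.add, position)

def tree_position_update(add_fns, position, delta):
    """Apply each leaf's addition function to the matching position leaf."""
    return jax.tree_util.tree_map(
        lambda add, x, dx: add(x, dx), add_fns, position, delta
    )

position = {"angle": jnp.array(3.0), "scale": jnp.array(1.0)}
delta = {"angle": jnp.array(1.0), "scale": jnp.array(1.0)}
add_fns = {"angle": make_periodic_add(-jnp.pi, jnp.pi), "scale": operator.add}

new_position = tree_position_update(add_fns, position, delta)
plain_position = tree_position_update(default_add_fns(position), position, delta)
```

With `default_add_fns`, the update reduces to ordinary `x + dx` on every leaf. This is one way users who never pass a constraint could see no change in behavior.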

Does it work with all integrators and upstream samplers (e.g., Leapfrog and NUTS)?

reubenharry commented 1 month ago

I think the goal would be to work with all integrators, yes, for the sake of uniformity. The standard behavior would remain the default, so no breaking change would be introduced. I'll post a prototype here if/when I have one.