reubenharry opened 1 month ago
Hmm, introducing it might be a pretty big breaking change, especially considering you need an API to input the constraint (as a pytree of functions matching the structure of the position). We had done that in TFP using a wrapper class (https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/mcmc/transformed_kernel.py), which is a bit difficult to use due to excessive nesting. Handling it through the integrator could potentially be nicer.
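To make "a pytree of functions matching the structure of the position" concrete, here is a minimal sketch in JAX. It assumes the position is a dict of arrays and that each leaf gets its own constraining function. The names `position`, `constraints`, and `apply_constraints` are hypothetical and are not part of BlackJAX or TFP.

```python
import jax
import jax.numpy as jnp

# Hypothetical position pytree: a dict of arrays, as a sampler might carry it.
position = {"scale": jnp.array(-0.5), "angle": jnp.array(7.0)}

# Hypothetical constraint pytree: one function per leaf, same structure as `position`.
constraints = {
    "scale": lambda x: jnp.exp(x),              # map R onto (0, inf)
    "angle": lambda x: jnp.mod(x, 2 * jnp.pi),  # wrap onto [0, 2*pi)
}


def apply_constraints(constraints, position):
    # Functions are pytree leaves, so tree_map pairs each function with the
    # matching leaf of `position`; the two structures must agree.
    return jax.tree_util.tree_map(lambda f, x: f(x), constraints, position)


constrained = apply_constraints(constraints, position)
```

Every sampler entry point would then need to accept and thread through this second pytree, which is part of why the change could be disruptive.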
Does it work with all integrators and upstream samplers (e.g., Leapfrog and NUTS)?
I think the goal would be to work with all integrators, yes, for the sake of uniformity. The standard behavior would remain the default, so no breaking change would be introduced. I'll post a prototype here if/when I have one.
Current behavior
Blackjax doesn't have any built-in handling of constrained parameters.
Stan and NumPyro handle constraints by transforming to an unconstrained space, and one could extract the NumPyro transformations and reuse them.
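For comparison, the transformation approach can be sketched in plain JAX: the sampler moves an unconstrained variable `u`, a bijection maps it to the constrained space, and the log absolute Jacobian determinant is added to the log density. This is a minimal illustration assuming a single positive parameter with an Exponential(1) prior; it does not use NumPyro's transform classes, and `logdensity_constrained` / `logdensity_unconstrained` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def logdensity_constrained(sigma):
    # Exponential(1) log density on sigma > 0 (normalizing constant is 0).
    return -sigma


def logdensity_unconstrained(u):
    # Bijection from R onto (0, inf): sigma = exp(u).
    sigma = jnp.exp(u)
    # log |d sigma / d u| = log(exp(u)) = u.
    log_det_jacobian = u
    return logdensity_constrained(sigma) + log_det_jacobian


# The gradient an integrator would use lives entirely in unconstrained space.
grad_u = jax.grad(logdensity_unconstrained)
```

Because the Jacobian term is included, the density in `u` still integrates to 1, so any integrator can move `u` freely over R without knowing about the constraint.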
However, for periodic or reflective boundary conditions, transformations are not always suitable (for example, no single chart covers the circle, so there is no smooth bijection from R onto it), whereas it is relatively simple to enforce periodicity directly in the integrator.
Desired behavior
I would update the integrators to accept a user-chosen addition function for the position update, defaulting to ordinary addition (+) but also allowing modular addition.
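A minimal sketch of the proposal, assuming a leapfrog (velocity Verlet) step written from scratch for a single array position rather than BlackJAX's actual integrator code. The position update calls an `add` function that defaults to `jnp.add`, and a periodic variant wraps the result onto [0, 2π). `leapfrog_step`, `periodic_add`, `von_mises_logdensity`, and the `add` parameter are hypothetical names for illustration.

```python
import jax
import jax.numpy as jnp


def periodic_add(x, dx, period=2 * jnp.pi):
    # Modular addition: the updated position is wrapped back onto [0, period).
    return jnp.mod(x + dx, period)


def leapfrog_step(logdensity_fn, position, momentum, step_size, add=jnp.add):
    grad_fn = jax.grad(logdensity_fn)
    # Half step in momentum at the current position.
    momentum = momentum + 0.5 * step_size * grad_fn(position)
    # Full step in position; `add` replaces the hard-coded `+`.
    position = add(position, step_size * momentum)
    # Second half step in momentum at the new position.
    momentum = momentum + 0.5 * step_size * grad_fn(position)
    return position, momentum


def von_mises_logdensity(x, mu=0.0, kappa=2.0):
    # Unnormalized von Mises log density; periodic in x, so its gradient
    # agrees on both sides of the wrap.
    return kappa * jnp.cos(x - mu)


x, p = jnp.array(6.2), jnp.array(1.0)
x_new, p_new = leapfrog_step(von_mises_logdensity, x, p, 0.5, add=periodic_add)
```

Calling `leapfrog_step` without `add` reproduces the usual update, which is what keeps the change non-breaking. The wrap is only valid when the log density is itself periodic with the same period; for a pytree position, `add` would need to be applied leaf-wise.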