blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Improve SamplingAlgorithm design for init_fn and step_fn #619

Closed: junpenglao closed this issue 9 months ago

junpenglao commented 10 months ago

The high-level user API is defined as:

https://github.com/blackjax-devs/blackjax/blob/08e0d7521b2c06ba29b94a1c186bdaf8c08e6310/blackjax/base.py#L88

Initialization and update have the following function signatures:
https://github.com/blackjax-devs/blackjax/blob/08e0d7521b2c06ba29b94a1c186bdaf8c08e6310/blackjax/base.py#L37
https://github.com/blackjax-devs/blackjax/blob/08e0d7521b2c06ba29b94a1c186bdaf8c08e6310/blackjax/base.py#L67
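For readers without the linked source at hand, here is a minimal, framework-free sketch of the contract those links describe, assuming `init(position) -> state` and `step(rng_key, state) -> (state, info)`. It uses only the Python standard library: a plain integer seed stands in for a JAX PRNG key, and the names `InitFn`, `UpdateFn`, `RWState`, `RWInfo` and `random_walk` are hypothetical illustrations, not BlackJAX's actual definitions.

```python
import math
import random
from typing import Any, Callable, NamedTuple, Protocol, Tuple


class InitFn(Protocol):
    # Assumed contract: position -> initial sampler state.
    def __call__(self, position: float) -> Any: ...


class UpdateFn(Protocol):
    # Assumed contract: (rng_key, state) -> (new_state, info).
    def __call__(self, rng_key: int, state: Any) -> Tuple[Any, Any]: ...


class SamplingAlgorithm(NamedTuple):
    init: InitFn
    step: UpdateFn


class RWState(NamedTuple):
    position: float
    logdensity: float


class RWInfo(NamedTuple):
    acceptance_rate: float
    is_accepted: bool


def random_walk(logdensity_fn: Callable[[float], float],
                sigma: float) -> SamplingAlgorithm:
    """Toy 1-D random-walk Metropolis that follows the assumed contract."""

    def init(position: float) -> RWState:
        return RWState(position, logdensity_fn(position))

    def step(rng_key: int, state: RWState) -> Tuple[RWState, RWInfo]:
        # An integer seed stands in for a JAX PRNG key in this sketch.
        rng = random.Random(rng_key)
        proposal = state.position + sigma * rng.gauss(0.0, 1.0)
        proposal_logdensity = logdensity_fn(proposal)
        # Metropolis acceptance probability min(1, p(proposal) / p(current)).
        acceptance_rate = math.exp(min(0.0, proposal_logdensity - state.logdensity))
        if rng.random() < acceptance_rate:
            return RWState(proposal, proposal_logdensity), RWInfo(acceptance_rate, True)
        return state, RWInfo(acceptance_rate, False)

    return SamplingAlgorithm(init, step)


# Usage: a standard normal target (unnormalized log-density).
algorithm = random_walk(lambda x: -0.5 * x * x, sigma=1.0)
state = algorithm.init(0.0)
new_state, info = algorithm.step(42, state)
```

Because every sampler built this way exposes the same two callables, generic code can drive any of them without knowing which algorithm it holds.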

However, a few samplers do not follow this pattern. We should either extend the function signature or make sure the existing signature is applied consistently.
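One possible way to make sure the signature is applied, for a sampler whose initializer needs more than a position, is to bind the extra argument when the algorithm is built, so callers still see `init(position)`. The sketch below assumes a hypothetical `nonconforming_init(position, rng_key)` that draws an auxiliary variable; it is illustrative only and not taken from any specific BlackJAX sampler.

```python
import functools
import random
from typing import Callable, NamedTuple


class AuxState(NamedTuple):
    position: float
    auxiliary: float  # hypothetical extra quantity drawn at initialization


def nonconforming_init(position: float, rng_key: int) -> AuxState:
    # Needs rng_key as well as position, so it breaks the init(position) pattern.
    rng = random.Random(rng_key)
    return AuxState(position, rng.gauss(0.0, 1.0))


def make_init(rng_key: int) -> Callable[[float], AuxState]:
    # Bind the extra argument up front; the returned callable takes only position.
    return functools.partial(nonconforming_init, rng_key=rng_key)


init = make_init(rng_key=0)
state = init(1.5)
```

The trade-off is that the bound argument is fixed per algorithm instance; extending the signature instead keeps it per call, but then every generic driver has to know about it.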

cc @reubenharry

reubenharry commented 10 months ago

In general, it would be nicer not to change the function signature (not least because utilities such as run_inference_loop can then rely on it). As far as I can see, the only way to do this is to give the init function default values for any extra arguments when needed. This seems fine to me, at least in some cases.
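The default-value approach can be sketched as follows. This is a framework-free illustration assuming a hypothetical random-walk sampler whose `init` takes an optional `step_size`, plus a hypothetical driver `run_loop` that plays the role of utilities like run_inference_loop (it is not BlackJAX's implementation). Because the extra argument has a default, the driver can keep calling `init(position)`.

```python
import math
import random
from typing import Callable, List, NamedTuple, Optional, Tuple


class State(NamedTuple):
    position: float
    logdensity: float
    step_size: float


class Info(NamedTuple):
    is_accepted: bool


class SamplingAlgorithm(NamedTuple):
    init: Callable[..., State]
    step: Callable[[int, State], Tuple[State, Info]]


def rw_metropolis(logdensity_fn: Callable[[float], float],
                  default_step_size: float = 0.5) -> SamplingAlgorithm:
    def init(position: float, step_size: Optional[float] = None) -> State:
        # The extra argument is optional, so init(position) still satisfies
        # the common signature; callers that need control can still pass it.
        if step_size is None:
            step_size = default_step_size
        return State(position, logdensity_fn(position), step_size)

    def step(rng_key: int, state: State) -> Tuple[State, Info]:
        rng = random.Random(rng_key)  # integer seed stands in for a PRNG key
        proposal = state.position + state.step_size * rng.gauss(0.0, 1.0)
        proposal_logdensity = logdensity_fn(proposal)
        accept_prob = math.exp(min(0.0, proposal_logdensity - state.logdensity))
        if rng.random() < accept_prob:
            return State(proposal, proposal_logdensity, state.step_size), Info(True)
        return state, Info(False)

    return SamplingAlgorithm(init, step)


def run_loop(rng_key: int, algorithm: SamplingAlgorithm,
             position: float, num_steps: int) -> List[State]:
    # Hypothetical driver: it only calls init(position) and
    # step(rng_key, state), so any sampler honoring that contract fits.
    key_stream = random.Random(rng_key)
    state = algorithm.init(position)
    states = []
    for _ in range(num_steps):
        state, _ = algorithm.step(key_stream.getrandbits(32), state)
        states.append(state)
    return states


sampler = rw_metropolis(lambda x: -0.5 * x * x)
samples = run_loop(0, sampler, position=0.0, num_steps=100)
# Explicit override, for callers that do need the extra argument:
custom = sampler.init(0.0, step_size=0.1)
```

The generic driver never sees `step_size`, which is the property the comment argues for; the cost is that the default has to be sensible for most users.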