blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0
806 stars, 105 forks

Improvements to `run_inference_algorithm` #670

Closed: reubenharry closed this issue 4 months ago

reubenharry commented 4 months ago

Current behavior

  1. `run_inference_algorithm` can take either an initial position or an initial state. The try-except handler that tells the two apart is somewhat unreliable (e.g. the except branch does not behave as expected when other exceptions are raised). Moreover, it seems more modular to delegate the transformation from `initial_position` to `initial_state` to the caller of `run_inference_algorithm`.

  2. `run_inference_algorithm` returns all n samples. For high-dimensional problems, storing them all is memory-inefficient.

  3. `transform` currently applies only to the state and not to `Info`, so there is no way to discard part of the diagnostic info.
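To put a rough number on point 2, here is a back-of-the-envelope calculation. It assumes float32 positions and counts only the positions themselves; real sampler states usually carry more (e.g. log-density values, gradients), so the full-trace figure is a lower bound. The sizes n = 10,000 and d = 1,000,000 are illustrative choices, not taken from this issue, and the helper names are hypothetical.

```python
# Back-of-the-envelope memory cost, assuming float32 positions and that every
# sampled position is kept (as when all n samples are stacked into one array).
BYTES_PER_FLOAT32 = 4


def trace_bytes(num_samples: int, dim: int) -> int:
    """Memory for storing all num_samples positions of dimension dim."""
    return num_samples * dim * BYTES_PER_FLOAT32


def running_mean_bytes(dim: int) -> int:
    """Memory for a single running mean of the position (one vector)."""
    return dim * BYTES_PER_FLOAT32


n, d = 10_000, 1_000_000
print(trace_bytes(n, d) / 1e9, "GB")       # 40.0 GB
print(running_mean_bytes(d) / 1e6, "MB")   # 4.0 MB
```

The full trace grows linearly in n, while a running average stays at the size of a single state no matter how many steps are taken.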

Desired behavior

  1. Make `run_inference_algorithm` take only `initial_state`.

  2. Allow `run_inference_algorithm` to have a memory-efficient mode, in which it computes a running average of a desired expectation instead of storing every sample.

  3. Change `transform` to also take `Info` as an argument. This will be a breaking change (more so than (1)).
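A minimal pure-JAX sketch of the three desired behaviors, assuming a toy random-walk Metropolis kernel with signature `kernel(rng_key, state) -> (state, info)`. `RWState`, `RWInfo`, `rw_init`, `make_rw_kernel`, `run_sketch` and `run_expectation_sketch` are hypothetical names for illustration, not BlackJAX's API. The caller builds the initial state, `transform` receives both the state and the info, and the expectation mode carries a running mean (updated as mean_t = mean_{t-1} + (x_t - mean_{t-1}) / t) through `jax.lax.scan` instead of stacking every state.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    """Toy sampler state (hypothetical, not BlackJAX's own type)."""
    position: jax.Array
    logdensity: jax.Array


class RWInfo(NamedTuple):
    """Toy per-step diagnostics (hypothetical)."""
    acceptance_rate: jax.Array
    is_accepted: jax.Array


def rw_init(position, logdensity_fn):
    # The *caller* turns a position into a state (desired behavior 1).
    return RWState(position, logdensity_fn(position))


def make_rw_kernel(logdensity_fn, step_size):
    """Gaussian random-walk Metropolis kernel: (rng_key, state) -> (state, info)."""
    def kernel(rng_key, state):
        key_prop, key_accept = jax.random.split(rng_key)
        proposal = state.position + step_size * jax.random.normal(
            key_prop, state.position.shape)
        proposal_logdensity = logdensity_fn(proposal)
        accept_prob = jnp.minimum(1.0, jnp.exp(proposal_logdensity - state.logdensity))
        accepted = jax.random.uniform(key_accept) < accept_prob
        new_state = RWState(
            jnp.where(accepted, proposal, state.position),
            jnp.where(accepted, proposal_logdensity, state.logdensity),
        )
        return new_state, RWInfo(accept_prob, accepted)
    return kernel


def run_sketch(rng_key, initial_state, kernel, num_steps: int,
               transform: Callable = lambda state, info: (state, info)):
    """Full-trace mode: stacks transform(state, info) for every step."""
    def one_step(state, key):
        state, info = kernel(key, state)
        return state, transform(state, info)  # transform also sees Info (behavior 3)
    keys = jax.random.split(rng_key, num_steps)
    return jax.lax.scan(one_step, initial_state, keys)


def run_expectation_sketch(rng_key, initial_state, kernel, num_steps: int,
                           expectation: Callable):
    """Memory-efficient mode: keeps only a running mean of expectation(state)."""
    def one_step(carry, xs):
        state, mean = carry
        key, i = xs
        state, _ = kernel(key, state)  # info is discarded; nothing is stacked
        mean = mean + (expectation(state) - mean) / (i + 1)
        return (state, mean), None
    keys = jax.random.split(rng_key, num_steps)
    init_mean = jnp.zeros_like(expectation(initial_state))
    (final_state, mean), _ = jax.lax.scan(
        one_step, (initial_state, init_mean), (keys, jnp.arange(num_steps)))
    return final_state, mean


# Usage on a 3-d standard normal target.
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
kernel = make_rw_kernel(logdensity_fn, step_size=0.5)
initial_state = rw_init(jnp.zeros(3), logdensity_fn)
key = jax.random.PRNGKey(0)

# Keep positions and acceptance rates only; is_accepted is dropped.
final_state, (positions, acceptance) = run_sketch(
    key, initial_state, kernel, 1_000,
    transform=lambda state, info: (state.position, info.acceptance_rate))

# Same chain, but only a length-3 running mean is kept in memory.
_, running_mean = run_expectation_sketch(
    key, initial_state, kernel, 1_000, expectation=lambda state: state.position)
```

Returning `None` from the scan body is what keeps memory constant in this sketch: `jax.lax.scan` only stacks per-step outputs, so with no output only the carry (one state plus one running mean) stays in memory, whereas `run_sketch` stacks whatever `transform` returns at every step.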