blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0
853 stars 107 forks

Doc: add Stan PPL integration? #718

Open: gil2rok opened this issue 3 months ago

gil2rok commented 3 months ago

Add a Stan PPL integration so that Stan models can be used with BlackJAX inference algorithms.

With the BridgeStan library, we can efficiently evaluate the log density and its gradient for Stan models. Following the BlackJAX documentation on custom gradients, we can then wrap BridgeStan in a JAX callback with a custom gradient. Note that Stan, not JAX, would compute the gradients, which are then used by the BlackJAX algorithms.
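A minimal sketch of this pattern, assuming only `jax` and `numpy` (it is not taken from the linked example repos). `jax.pure_callback` lets traced JAX code call back into ordinary Python, and `jax.custom_vjp` tells JAX to use the gradient returned by Stan instead of trying to differentiate through the callback, which it cannot do. `StandardNormalModel` is a hypothetical stand-in that only mimics the `param_unc_num` and `log_density_gradient` methods of `bridgestan.StanModel`, so the snippet runs without compiling a Stan model; `make_logdensity_fn` is likewise a hypothetical helper name.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Stan works in double precision, so enable 64-bit floats in JAX.
jax.config.update("jax_enable_x64", True)


class StandardNormalModel:
    """Hypothetical stand-in for bridgestan.StanModel (standard normal target).

    It only mimics the two BridgeStan methods used below, so this sketch runs
    without compiling a Stan model.
    """

    def __init__(self, dim):
        self.dim = dim

    def param_unc_num(self):
        return self.dim

    def log_density_gradient(self, theta_unc, propto=True, jacobian=True):
        # propto/jacobian are accepted only to mirror BridgeStan; ignored here.
        theta = np.asarray(theta_unc, dtype=np.float64)
        return -0.5 * float(theta @ theta), -theta


def make_logdensity_fn(model):
    """Wrap a BridgeStan-style model as a JAX-differentiable log density."""

    def _host_lp_and_grad(theta):
        # Runs as plain Python/NumPy on the host, outside JAX tracing.
        theta = np.asarray(theta)
        lp, grad = model.log_density_gradient(theta.astype(np.float64))
        # Cast back to the caller's dtype so it matches the declared result shapes.
        return np.asarray(lp, dtype=theta.dtype), np.asarray(grad, dtype=theta.dtype)

    def _lp_and_grad(theta):
        # Declare the shapes/dtypes the host callback will return.
        result_shapes = (
            jax.ShapeDtypeStruct((), theta.dtype),
            jax.ShapeDtypeStruct(theta.shape, theta.dtype),
        )
        return jax.pure_callback(_host_lp_and_grad, result_shapes, theta)

    @jax.custom_vjp
    def logdensity_fn(theta):
        lp, _ = _lp_and_grad(theta)
        return lp

    def _fwd(theta):
        lp, grad = _lp_and_grad(theta)
        return lp, grad  # keep Stan's gradient as the residual

    def _bwd(grad, cotangent):
        # Chain rule: scale Stan's gradient by the incoming scalar cotangent.
        return (cotangent * grad,)

    logdensity_fn.defvjp(_fwd, _bwd)
    return logdensity_fn


model = StandardNormalModel(dim=3)
logdensity_fn = make_logdensity_fn(model)
theta = jnp.array([0.5, -1.0, 2.0])
lp_value = jax.jit(logdensity_fn)(theta)
lp_grad = jax.jit(jax.grad(logdensity_fn))(theta)
print(lp_value, lp_grad)
```

With a real model, the stand-in would be replaced by something like `bridgestan.StanModel("model.stan", data="model.data.json")` (file names hypothetical; recent BridgeStan versions can compile a `.stan` file directly). The resulting `logdensity_fn` should then be usable wherever BlackJAX expects a log density, for example `blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)`. Draws live on Stan's unconstrained scale, and BridgeStan's `param_constrain` maps them back. The design calls `log_density_gradient` in both passes so that a gradient evaluation costs a single Stan call; each call is still a host round-trip that JAX cannot fuse, and batching with `jax.vmap` is not handled in this sketch.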

For a somewhat complete example, see my small example repo and another example here. I would be open to attempting a PR if there is interest.

junpenglao commented 3 months ago

Oh, that's a great idea! +1 to appending a more complicated example to the custom gradients doc.

ciguaran commented 3 months ago

Looking forward to seeing this!

reubenharry commented 2 months ago

I'd also be interested in using this!