blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

We should be able to distribute gradient computations #240

Open rlouf opened 2 years ago

rlouf commented 2 years ago

In particular, the implementation should be general enough that we can shard a large dataset across several machines, compute the partial gradient on each machine, and combine the partial gradients before making a leapfrog step, as in https://arxiv.org/pdf/2104.14421.pdf
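To make the request concrete, here is a minimal sketch of the idea in JAX. It is not BlackJAX's API: the linear-Gaussian model and the names `log_likelihood`, `log_prior`, `sharded_logdensity_grad` and `leapfrog_step` are all hypothetical, chosen only for illustration. The shards are simulated on a single host with `jax.vmap`; on several devices the same pattern could presumably be expressed with `jax.pmap` and a `jax.lax.psum` over the device axis.

```python
import jax
import jax.numpy as jnp


def log_likelihood(theta, x, y):
    # Hypothetical model: Gaussian log-likelihood (up to a constant) of a linear regression.
    residuals = y - x @ theta
    return -0.5 * jnp.sum(residuals**2)


def log_prior(theta):
    # Hypothetical standard normal prior (up to a constant).
    return -0.5 * jnp.sum(theta**2)


def sharded_logdensity_grad(theta, x_shards, y_shards):
    # Partial gradient of the log-likelihood on each shard (leading axis = shard).
    partial_grads = jax.vmap(jax.grad(log_likelihood), in_axes=(None, 0, 0))(
        theta, x_shards, y_shards
    )
    # Combine the partial gradients; the prior is added once, not once per shard.
    return jnp.sum(partial_grads, axis=0) + jax.grad(log_prior)(theta)


def leapfrog_step(theta, momentum, step_size, grad_fn):
    # One leapfrog step with a unit mass matrix, using the combined gradient.
    momentum = momentum + 0.5 * step_size * grad_fn(theta)
    theta = theta + step_size * momentum
    momentum = momentum + 0.5 * step_size * grad_fn(theta)
    return theta, momentum


# Synthetic data, split into equally sized shards.
n_shards, shard_size, dim = 4, 25, 3
key_x, key_y, key_p = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (n_shards * shard_size, dim))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.1 * jax.random.normal(
    key_y, (n_shards * shard_size,)
)
x_shards = x.reshape(n_shards, shard_size, dim)
y_shards = y.reshape(n_shards, shard_size)

theta = jnp.zeros(dim)
momentum = jax.random.normal(key_p, (dim,))


def grad_fn(t):
    return sharded_logdensity_grad(t, x_shards, y_shards)


new_theta, new_momentum = leapfrog_step(theta, momentum, 0.01, grad_fn)

# Reference: gradient of the full log-density computed on the unsharded data.
full_grad = jax.grad(lambda t: log_likelihood(t, x, y) + log_prior(t))(theta)
```

Because the log-likelihood is a sum over data points, the summed partial gradients should match `full_grad` up to floating-point error, so the leapfrog step itself does not need to know how the data is distributed.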

ludgerpaehler commented 1 year ago

I am happy to pick this up!

junpenglao commented 1 year ago

Related: https://www.tensorflow.org/probability/examples/Distributed_Inference_with_JAX