jeremiecoullon / SGMCMCJax

Lightweight library of stochastic gradient MCMC algorithms written in JAX.
https://sgmcmcjax.readthedocs.io/en/latest/index.html
Apache License 2.0

compiled progress bar doesn't work on TPUs #24

Open jeremiecoullon opened 3 years ago

jeremiecoullon commented 3 years ago

This Colab runs some of the examples on TPU, but the progress bar triggers an odd bug related to the host_callback module.

As a workaround, I added a pbar argument that turns the progress bar off (it defaults to True):

my_sampler = build_sgld_sampler(dt, loglikelihood, logprior, (X_data,), batch_size, pbar=False)
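To show what such a flag can do inside a compiled loop, here is a minimal sketch. It is not SGMCMCJax's implementation: run_chain, step_fn and print_every are hypothetical names. It also uses jax.debug.callback instead of the host_callback module mentioned above, since host_callback was deprecated in later JAX releases. Whether this avoids the TPU bug has not been tested here. Because pbar is a plain Python bool marked static under jit, pbar=False means the callback is never traced into the compiled program.

```python
import jax
import jax.numpy as jnp


def run_chain(key, x0, num_iters, step_fn, pbar=True, print_every=100):
    """Hypothetical compiled sampling loop with an optional progress report."""

    def report(i):
        # Executed on the host, outside the compiled computation.
        print(f"iteration {int(i)}/{num_iters}")

    def body(carry, i):
        key, x = carry
        key, subkey = jax.random.split(key)
        x = step_fn(subkey, x)
        if pbar:  # Python-level branch: resolved at trace time, not at runtime
            # Only call back to the host every `print_every` iterations.
            jax.lax.cond(
                i % print_every == 0,
                lambda j: jax.debug.callback(report, j),
                lambda j: None,
                i,
            )
        return (key, x), x

    _, samples = jax.lax.scan(body, (key, x0), jnp.arange(num_iters))
    return samples


# Loop length, step function and the progress flag are compile-time constants.
run_chain_jit = jax.jit(
    run_chain, static_argnames=("num_iters", "step_fn", "pbar", "print_every")
)
```

With a toy step function such as `lambda k, x: x + 0.1 * jax.random.normal(k, x.shape)`, calling `run_chain_jit(key, x0, num_iters=1000, step_fn=step, pbar=False)` compiles a loop with no host callbacks at all. This is the same trade-off the pbar argument above makes.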

Can we get the progress bar to work on TPU?