
TFP JAX: The transition kernel drastically decreases speed. #1807

Open SebastianSosa opened 2 months ago

SebastianSosa commented 2 months ago

Dear all,

I am currently learning Bayesian analysis and using tensorflow_probability.substrates.jax, but I've encountered a performance issue. With jax.jit, NUTS alone is quite fast; however, when it is wrapped in a tfp.mcmc.TransformedTransitionKernel, the speed decreases drastically. Here's a sketch of the setup and a summary of the time taken:
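The kernel composition looks roughly like this (a minimal sketch with a stand-in model; the actual model is in the attached notebook, and `target_log_prob_fn` here is just a placeholder):

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Stand-in target: a location and a positive scale parameter.
def target_log_prob_fn(loc, scale):
    lp = tfd.Normal(0., 1.).log_prob(loc)
    lp += tfd.HalfNormal(1.).log_prob(scale)
    lp += jnp.sum(tfd.Normal(loc, scale).log_prob(jnp.zeros(100)))
    return lp

nuts = tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=0.1)

# Constrain `scale` to stay positive; `loc` is unconstrained.
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=nuts,
    bijector=[tfb.Identity(), tfb.Softplus()])

@jax.jit
def run_chain(seed):
    return tfp.mcmc.sample_chain(
        num_results=500,
        current_state=[jnp.zeros([]), jnp.ones([])],
        kernel=kernel,
        trace_fn=None,
        seed=seed)

samples = run_chain(jax.random.PRNGKey(0))
```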

I've conducted speed tests comparing with NumPyro, and essentially NumPyro with dual-averaging step-size adaptation and parameter constraints is as fast as tensorflow_probability's NUTS alone.
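The NumPyro side of the comparison is roughly the following (same stand-in model as above; NumPyro applies the positivity constraint and dual-averaging step-size adaptation automatically):

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Same stand-in model as in the TFP sketch above.
def model():
    loc = numpyro.sample("loc", dist.Normal(0., 1.))
    scale = numpyro.sample("scale", dist.HalfNormal(1.))
    numpyro.sample("obs", dist.Normal(loc, scale), obs=jnp.zeros(100))

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(jax.random.PRNGKey(0))
```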

Could there be something I've missed? Is there room for optimization in this process?

Please find the data and code enclosed for reproducibility (the .txt file needs to be renamed to .ipynb): data.csv, gitissue.txt, Google Colab

Please note that I'm only using the first 100 lines of the data.

Additionally, and possibly related to the cause: I observed a similar speed loss when using the LKJ distribution in other models. (I can post one of them if needed.)

Thank you in advance for your assistance.

Sebastian

ColCarroll commented 1 month ago

Hi - It looks like the colab is locked down, so I cannot access it.

SebastianSosa commented 1 month ago

Does this link allow access?

I made a simulation instead of using real data, as it lets us evaluate how the models perform as the data size increases. I can update it in the next few days.

ColCarroll commented 1 month ago

Note that the data is not saved with the colab, so I cannot run this, but it looks as though the problem is your use of tfp.bijectors.CorrelationCholesky(ni). CorrelationCholesky doesn't take any positional parameters like this, and ni is silently being accepted as the validate_args argument.
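Concretely (with `ni` as a stand-in for whatever dimension you passed):

```python
from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors
ni = 5  # stand-in dimension

# CorrelationCholesky's signature is (validate_args=False, name=...), so a
# positional argument is silently interpreted as validate_args:
suspect = tfb.CorrelationCholesky(ni)  # validate_args=5, not "dimension 5"

# The matrix dimension is inferred from the event shape of the unconstrained
# input vector, so the bijector takes no size argument at all:
fixed = tfb.CorrelationCholesky()
```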

Downstream, I think this will lead to some wild posterior, and so TFP NUTS is (correctly) exhausting its tree doublings and doing ~10x as much work.
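One way to check this is to trace `leapfrogs_taken` from the NUTS kernel results (a sketch with a stand-in target; if the kernel is wrapped in a TransformedTransitionKernel, the field lives at `pkr.inner_results.leapfrogs_taken` instead):

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Stand-in target; substitute your actual model.
def target_log_prob_fn(x):
    return tfd.Normal(0., 1.).log_prob(x)

kernel = tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=0.1)

# Record how many leapfrog steps NUTS takes at each iteration.
def trace_fn(_, pkr):
    return pkr.leapfrogs_taken

samples, leapfrogs = tfp.mcmc.sample_chain(
    num_results=200,
    current_state=jnp.zeros([]),
    kernel=kernel,
    trace_fn=trace_fn,
    seed=jax.random.PRNGKey(0))

# With the default max_tree_depth=10, counts near 2**10 - 1 = 1023 mean the
# sampler is exhausting its tree doublings on nearly every step.
```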