rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0
324 stars 17 forks

Diagnostics with many chains #26

Closed rlouf closed 3 years ago

rlouf commented 4 years ago

I am opening this issue, after seeing this tweet by Junpeng Lao, to start a discussion about how to display diagnostics, especially divergences, when we sample many chains. With only a few chains we can report numbers for each chain and let users make sense of them. But what about when there are 1,000 chains? Is there a way to extract information that is:

  1. Understandable
  2. Actionable

Junpeng's visualization gives a nice quick overview: if there is too much yellow (I assume divergences are shown in yellow), something is wrong. But how wrong is wrong? Is there an acceptable threshold for the number of divergences? And what can the user do about it?
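To make the question concrete, here is a minimal sketch of how per-chain divergence flags could be condensed into a few numbers when there are too many chains to list. It assumes the sampler exposes a boolean array of shape `(n_chains, n_samples)`. The function name `summarize_divergences` and the `max_rate` cutoff are hypothetical, chosen for illustration only, since whether a meaningful threshold exists is exactly what is being asked here.

```python
import numpy as np


def summarize_divergences(divergent, max_rate=0.01):
    """Condense a boolean (n_chains, n_samples) array of divergence flags.

    `max_rate` is an arbitrary illustrative cutoff, not a recommended value.
    """
    divergent = np.asarray(divergent, dtype=bool)
    # Fraction of divergent transitions in each chain.
    rate_per_chain = divergent.mean(axis=1)
    return {
        "overall_rate": float(divergent.mean()),
        "chains_with_divergences": int((rate_per_chain > 0).sum()),
        "chains_above_max_rate": int((rate_per_chain > max_rate).sum()),
        # Median, 90th and 99th percentile of the per-chain rates.
        "rate_quantiles": np.quantile(rate_per_chain, [0.5, 0.9, 0.99]),
    }
```

A summary like this is understandable, but it is only actionable if we can also say where in parameter space the divergent transitions occur.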

PS: We could plot distributions in the terminal using, e.g., gnuplotlib.
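As a rough illustration of what a terminal plot could look like, the sketch below renders a histogram with plain text bars. It deliberately does not use gnuplotlib, and the helper name `text_histogram` is hypothetical.

```python
import numpy as np


def text_histogram(values, n_bins=10, width=40):
    """Return a plain-text histogram, one line per bin."""
    values = np.asarray(values, dtype=float)
    counts, edges = np.histogram(values, bins=n_bins)
    peak = max(counts.max(), 1)  # avoid dividing by zero for empty input bins
    lines = []
    for count, left, right in zip(counts, edges[:-1], edges[1:]):
        # Scale each bar relative to the most populated bin.
        bar = "#" * int(round(width * count / peak))
        lines.append(f"{left:8.3f} .. {right:8.3f} {bar} {count}")
    return "\n".join(lines)
```

Calling `print(text_histogram(rate_per_chain))` on, say, per-chain divergence rates would give a quick view of their distribution without leaving the terminal.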

rlouf commented 4 years ago

When there are many chains we can probably use rank statistics: pool the draws from all chains, rank them, split the chains into two groups, and test whether the ranks in each group are uniformly distributed.
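One way this idea could be sketched, for a single scalar parameter stored as an array of shape `(n_chains, n_samples)`, is shown below. The function name `rank_uniformity_statistic` and the choice of a binned chi-square statistic are assumptions made for illustration, not something MCX implements.

```python
import numpy as np


def rank_uniformity_statistic(draws, n_bins=20):
    """Chi-square statistic for uniformity of pooled ranks in half the chains.

    draws: array of shape (n_chains, n_samples) for one scalar parameter.
    """
    draws = np.asarray(draws, dtype=float)
    n_chains = draws.shape[0]
    # Rank every draw within the pooled sample (0 .. N-1, ties broken by position).
    ranks = np.argsort(np.argsort(draws, axis=None)).reshape(draws.shape)
    # Keep the ranks belonging to the first half of the chains.
    group = ranks[: n_chains // 2]
    edges = np.linspace(0, draws.size, n_bins + 1)
    counts, _ = np.histogram(group, bins=edges)
    # If the chains are exchangeable, every rank bin is equally likely.
    expected = group.size / n_bins
    return float(((counts - expected) ** 2 / expected).sum())
```

If all chains target the same distribution the statistic should stay small, on the order of the number of bins. Large values suggest the two groups of chains explore different regions. Autocorrelation within chains inflates the statistic, so this is a rough screen rather than a calibrated test.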

rlouf commented 3 years ago

See https://twitter.com/remilouf/status/1330424234733080577

Divergences and posterior predictive checks are our two main tools, possibly in addition to the rank statistics I mentioned above.

rlouf commented 3 years ago

Moved to Discussions.