google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Implement scipy.stats.multivariate_normal.cdf #10562

Open Mv77 opened 2 years ago

Mv77 commented 2 years ago

Hi,

Some of the SciPy distributions that you have implemented (laplace, logistic, norm) support the .cdf method for evaluating the cumulative distribution function. One that does not is jax.scipy.stats.multivariate_normal. It would be great if you could add support for the multivariate normal CDF. Here is the doc for SciPy's multivariate_normal.


jakevdp commented 2 years ago

Thanks for the request! We would definitely welcome a contribution of this functionality.

sharadmv commented 2 years ago

FWIW a general MVN CDF is pretty difficult to compute (see this TFP issue for reference). A Monte Carlo approximation (as suggested in that issue) could be a reasonable stopgap if you need something right now.
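As a rough illustration of the Monte Carlo idea, here is a minimal sketch: draw samples from the MVN and count the fraction that fall at or below `x` in every coordinate. The function name `mvn_cdf_mc` and the `num_samples` default are hypothetical and not part of JAX; the only library call it relies on is `jax.random.multivariate_normal`. Note that the indicator makes this estimate non-differentiable with respect to `x`.

```python
import jax
import jax.numpy as jnp


def mvn_cdf_mc(key, x, mean, cov, num_samples=100_000):
    """Monte Carlo estimate of P(X <= x) for X ~ N(mean, cov).

    Hypothetical helper for illustration only; not a JAX API.
    """
    # Draw num_samples vectors from the MVN; result has shape (num_samples, d).
    samples = jax.random.multivariate_normal(key, mean, cov, shape=(num_samples,))
    # A sample counts only if every coordinate is at or below x.
    inside = jnp.all(samples <= x, axis=-1)
    # The fraction of samples inside the orthant estimates the CDF.
    return jnp.mean(inside.astype(jnp.float32))


key = jax.random.PRNGKey(0)
mean = jnp.zeros(2)
cov = jnp.array([[1.0, 0.5], [0.5, 1.0]])
estimate = mvn_cdf_mc(key, jnp.zeros(2), mean, cov)
```

The standard error shrinks like 1/sqrt(num_samples), so small tail probabilities need many more samples for a usable relative accuracy.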

Alternatively if you know your MVN has a diagonal covariance, you can use a product of independent normal CDFs instead.
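For the diagonal case, a minimal sketch could look like the following. The name `diag_mvn_cdf` and its arguments are hypothetical; it assumes the covariance is passed as a vector of per-coordinate variances and uses `jax.scipy.stats.norm.logcdf`, summing log-CDFs instead of multiplying CDFs for numerical stability.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def diag_mvn_cdf(x, mean, var_diag):
    """CDF of an MVN with diagonal covariance, evaluated at x.

    Hypothetical helper for illustration only; var_diag holds the
    diagonal entries (variances) of the covariance matrix.
    """
    std = jnp.sqrt(var_diag)
    # Independent coordinates: the joint CDF is the product of the marginal
    # normal CDFs, computed here as exp(sum of log-CDFs).
    return jnp.exp(jnp.sum(norm.logcdf(x, loc=mean, scale=std), axis=-1))


x = jnp.zeros(3)
mean = jnp.zeros(3)
var_diag = jnp.ones(3)
p = diag_mvn_cdf(x, mean, var_diag)

# Unlike a sampling-based estimate, this is differentiable in x.
grad_x = jax.grad(diag_mvn_cdf)(x, mean, var_diag)
```

With a zero mean and unit variances, each marginal CDF at 0 is 0.5, so `p` should come out close to 0.5 ** 3 = 0.125.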

riven314 commented 2 years ago

@sharadmv @jakevdp I am interested in this ticket but I would like to have more clarity on how I should implement this.

From the SciPy source code, I see that scipy.stats.multivariate_normal.cdf internally calls helper functions from scipy.stats.mvndst.f (Fortran, not Python) for Monte Carlo sampling, and that script (mvndst.f) is an implementation of this paper.

Should I re-implement the script in JAX?

flaviovdf commented 1 year ago

Hi,

I have implemented this for the bivariate case. It is based on the recent paper "A simple approximation for the bivariate normal integral" by Wen-Jen Tsay & Peng-Hsuan Ke: https://www.tandfonline.com/doi/abs/10.1080/03610918.2021.1884718

https://colab.research.google.com/drive/1w2tI1-1LWzPSdG_jE0FXzdrs6VAsJwhv?usp=sharing

Posting it here in case someone else needs this CDF. It is jax.grad friendly. With this code I have replicated the results from the authors' paper, so I believe it's accurate enough.