Open Mv77 opened 2 years ago
Thanks for the request! We would definitely welcome a contribution of this functionality.
FWIW, a general MVN CDF is pretty difficult to compute (see this TFP issue for reference). A Monte Carlo estimate (as suggested in that issue) could be good enough if you need something right now.
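For anyone who wants to try the Monte Carlo route, here is a minimal sketch. It assumes a plain "draw samples and count" estimator, which may not be what the TFP issue proposes. The function name mvn_cdf_mc and the num_samples default are made up for illustration.

```python
import jax
import jax.numpy as jnp


def mvn_cdf_mc(key, x, mean, cov, num_samples=100_000):
    """Monte Carlo estimate of P(X_1 <= x_1, ..., X_d <= x_d) for X ~ N(mean, cov).

    Hypothetical helper, not part of JAX.
    """
    # Draw num_samples points from the MVN; result has shape (num_samples, d).
    samples = jax.random.multivariate_normal(key, mean, cov, shape=(num_samples,))
    # A sample counts only if every coordinate is below the corresponding bound.
    inside = jnp.all(samples <= x, axis=-1)
    # The fraction of samples inside the orthant estimates the CDF.
    return jnp.mean(inside.astype(jnp.float32))


key = jax.random.PRNGKey(0)
mean = jnp.zeros(2)
cov = jnp.eye(2)
estimate = mvn_cdf_mc(key, jnp.zeros(2), mean, cov)
```

For a standard bivariate normal evaluated at the origin, the exact value is 0.25, so the estimate should land close to that. Note that the indicator is piecewise constant in x, so jax.grad of this estimator with respect to x is zero and not useful; the error also shrinks only like 1/sqrt(num_samples).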
Alternatively, if you know your MVN has a diagonal covariance, you can use a product of independent normal CDFs instead.
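The diagonal case can be sketched with jax.scipy.stats.norm. The helper names diag_mvn_cdf and diag_mvn_logcdf are made up for illustration, and the code assumes the covariance is passed as a vector of per-coordinate variances.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def diag_mvn_cdf(x, mean, var):
    """CDF of N(mean, diag(var)) at x as a product of 1-D normal CDFs.

    Hypothetical helper, not part of JAX.
    """
    return jnp.prod(norm.cdf(x, loc=mean, scale=jnp.sqrt(var)))


def diag_mvn_logcdf(x, mean, var):
    """Log of the same CDF; summing log-CDFs is more stable in high dimensions."""
    return jnp.sum(norm.logcdf(x, loc=mean, scale=jnp.sqrt(var)))


x = jnp.zeros(3)
mean = jnp.zeros(3)
var = jnp.ones(3)
p = diag_mvn_cdf(x, mean, var)                # 0.5 ** 3 for a standard normal at 0
g = jax.grad(diag_mvn_cdf)(x, mean, var)      # gradient with respect to x
```

Unlike a sampling-based estimate, this is exact up to floating point and works with jax.grad, but it is only valid when the coordinates are independent.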
@sharadmv @jakevdp I am interested in this ticket but I would like to have more clarity on how I should implement this.
From the scipy source code, I see that jax.random.multivariate_normal internally triggers helper functions from scipy.stats.mvndst.f (it's not written in Python) for Monte Carlo sampling, and the script (mvndst.f) is an implementation of this paper.
Should I re-implement the script in JAX?
Hi,
I have implemented this for the bivariate case. It is based on the recent paper "A simple approximation for the bivariate normal integral" by Wen-Jen Tsay & Peng-Hsuan Ke: https://www.tandfonline.com/doi/abs/10.1080/03610918.2021.1884718
https://colab.research.google.com/drive/1w2tI1-1LWzPSdG_jE0FXzdrs6VAsJwhv?usp=sharing
Posting it here in case someone else needs this CDF. It is jax.grad friendly. With this code I have replicated the results in the authors' paper, so I believe it's accurate enough.
Hi,
Some of the distributions from scipy that you have implemented (laplace, logistic, norm) support the .cdf method for evaluating the cumulative distribution function. One that does not is jax.random.multivariate_normal. It would be great if you could add support for the multivariate normal CDF. Here is the doc of scipy's multivariate_normal.