rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Added multivariate normal distribution #68

Closed sidravi1 closed 3 years ago

sidravi1 commented 3 years ago

Overview

Adds an MvNormal distribution. It uses jax.scipy.stats.multivariate_normal.logpdf for the log-density and jax.random.multivariate_normal for sampling.
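The overview describes a thin wrapper around two JAX functions. Below is a minimal sketch of such a wrapper. It assumes that the last axis of the mean is the event axis and that any leading axes are batch axes. The class name `MvNormalSketch` and its attribute names are hypothetical. This is not the code merged into mcx.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


class MvNormalSketch:
    """Hypothetical stand-in for mcx.distributions.MvNormal (not the mcx implementation)."""

    def __init__(self, mu, covariance_matrix):
        self.mu = jnp.asarray(mu, dtype=jnp.float32)
        self.cov = jnp.asarray(covariance_matrix, dtype=jnp.float32)
        # Assumption: last axis of mu is the event axis, leading axes are batch axes.
        self.event_shape = self.mu.shape[-1:]
        self.batch_shape = jnp.broadcast_shapes(self.mu.shape[:-1], self.cov.shape[:-2])

    def sample(self, rng_key, sample_shape=()):
        # jax.random.multivariate_normal's `shape` is everything except the event axis,
        # so the requested sample shape is prepended to the batch shape.
        shape = tuple(sample_shape) + tuple(self.batch_shape)
        return jax.random.multivariate_normal(rng_key, self.mu, self.cov, shape=shape)

    def logpdf(self, x):
        # Delegates to JAX; the result has the event axis reduced away.
        x = jnp.asarray(x, dtype=jnp.float32)
        return multivariate_normal.logpdf(x, self.mu, self.cov)


key = jax.random.PRNGKey(0)
cov = jnp.array([[1.0, 0.2], [0.2, 1.0]])
mvn = MvNormalSketch(jnp.array([1.0, 2.0]), cov)
print(mvn.sample(key, (3,)).shape)  # (3, 2)
```

Because the sample shape is prepended to the batch shape, a batched mean of shape (2, 2) sampled with `sample_shape=(3, 2)` gives draws of shape (3, 2, 2, 2).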

Quick tests

>>> import jax
>>> import jax.numpy as np
>>> import mcx
>>>
>>> rng_key = jax.random.PRNGKey(0)
>>> mvn = mcx.distributions.MvNormal(np.array([1, 2]), np.array([[1.0, 0.2], [0.2, 1.0]]))
>>> mvn.sample(rng_key)
DeviceArray([0.21523398, 2.6821878 ], dtype=float32)
>>> mvn.sample(rng_key, (3,))
DeviceArray([[1.1878438 , 0.78015494],
             [1.649418  , 3.3537068 ],
             [1.2444699 , 1.9338173 ]], dtype=float32)
>>> mvn.logpdf(np.array([1, 2]))
DeviceArray(-1.817466, dtype=float32)
>>> mvn.logpdf(np.array([[1, 2], [2, 1]]))
DeviceArray([-1.817466, -3.067466], dtype=float32)
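The printed log-densities can be checked by hand.

- The covariance has determinant 1 - 0.2² = 0.96.
- At x = μ, the log-density is -log(2π) - ½·log(0.96) ≈ -1.837877 + 0.020411 = -1.817466.
- At x = (2, 1), the Mahalanobis term is 2.4 / 0.96 = 2.5. Subtracting ½·2.5 = 1.25 gives -3.067466.

The helper `mvn_logpdf_closed_form` below is hypothetical and not part of mcx. It writes out the same formula and compares it against JAX's implementation.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def mvn_logpdf_closed_form(x, mu, cov):
    """Hypothetical helper: unbatched Gaussian log-density written out by hand."""
    x = jnp.asarray(x, dtype=jnp.float32)
    mu = jnp.asarray(mu, dtype=jnp.float32)
    cov = jnp.asarray(cov, dtype=jnp.float32)
    k = mu.shape[-1]
    diff = x - mu
    # log|cov| via slogdet (the sign is +1 for a valid covariance matrix).
    _, log_det = jnp.linalg.slogdet(cov)
    # Mahalanobis term (x - mu)^T cov^{-1} (x - mu), using a solve instead of an inverse.
    maha = diff @ jnp.linalg.solve(cov, diff)
    return -0.5 * (k * jnp.log(2.0 * jnp.pi) + log_det + maha)


mu = jnp.array([1.0, 2.0])
cov = jnp.array([[1.0, 0.2], [0.2, 1.0]])
for x in (jnp.array([1.0, 2.0]), jnp.array([2.0, 1.0])):
    print(mvn_logpdf_closed_form(x, mu, cov), multivariate_normal.logpdf(x, mu, cov))
```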

I'm not sure batching broadcasts correctly, though:

>>> mvn2 = mcx.distributions.MvNormal(
...     np.array([[1, 2], [3, 2]]), np.array([[1.0, 0.2], [0.2, 1.0]])
... )
>>> mvn2.sample(rng_key)
DeviceArray([[2.8160858, 1.6235838],
             [3.339889 , 1.5439482]], dtype=float32)
>>> mvn2.sample(rng_key, (3, 2))
DeviceArray([[[[ 1.3675394 ,  1.183653  ],
               [ 0.99355793,  1.7560301 ]],

              [[ 1.1323344 ,  0.7474071 ],
               [ 2.5944324 ,  0.161587  ]]],

             [[[-0.35665548,  2.5218964 ],
               [ 2.6202202 ,  2.0067666 ]],

              [[-0.89568627,  1.4151701 ],
               [ 3.20252   ,  3.3841357 ]]],

             [[[ 0.3996758 ,  0.86409783],
               [ 4.5410695 ,  2.3596075 ]],

              [[ 1.0302643 ,  3.2970448 ],
               [ 3.6156623 ,  3.787726  ]]]], dtype=float32)
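For `mvn2` above, calling `logpdf` on a single point should broadcast that point against both means and return one log-density per batch member, with shape (2,).

The sketch below calls `jax.scipy.stats.multivariate_normal.logpdf` directly rather than going through mcx. It compares the broadcast result with a loop of unbatched calls. The thread reports infs from batched calls on the JAX version used at the time, so treat agreement as something to verify on the installed version.

The second value can also be worked out by hand. The point (1, 2) lies at Mahalanobis distance 4 / 0.96 ≈ 4.1667 from the mean (3, 2), so its log-density is -1.817466 - 2.083333 ≈ -3.900799.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Batched parameters from the mvn2 example: two means sharing one covariance.
means = jnp.array([[1.0, 2.0], [3.0, 2.0]])  # batch_shape (2,), event_shape (2,)
cov = jnp.array([[1.0, 0.2], [0.2, 1.0]])
x = jnp.array([1.0, 2.0])

# Reference: evaluate each batch member separately with unbatched calls.
reference = jnp.stack([multivariate_normal.logpdf(x, m, cov) for m in means])

# Broadcast call: x of shape (2,) against means of shape (2, 2).
batched = multivariate_normal.logpdf(x, means, cov)

print(reference)      # approx. [-1.8175, -3.9008]
print(batched.shape)  # (2,)
```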

Pending

- Proper test cases
- Fixing shape issues

sidravi1 commented 3 years ago

I think I'm mostly there with shape. Will run a bunch of tests tomorrow and if it's all kosher, I'll put it up for another review.

sidravi1 commented 3 years ago

Added some test cases. I think shape stuff is looking ok now.
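Shape tests for this kind of distribution usually check two rules:

- Draws have shape sample_shape + batch_shape + event_shape.
- `logpdf` removes only the event axis.

The sketch below exercises those rules against the JAX primitives directly, not against mcx's MvNormal. The case list and the test name are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Illustrative cases: (mean, cov, sample_shape, expected draw shape).
CASES = [
    (jnp.zeros(2), jnp.eye(2), (), (2,)),
    (jnp.zeros(2), jnp.eye(2), (3,), (3, 2)),
    (jnp.zeros((2, 2)), jnp.eye(2), (), (2, 2)),
    (jnp.zeros((2, 2)), jnp.eye(2), (3, 2), (3, 2, 2, 2)),
    (jnp.zeros(3), jnp.broadcast_to(jnp.eye(3), (4, 3, 3)), (5,), (5, 4, 3)),
]


def test_sample_and_logpdf_shapes():
    key = jax.random.PRNGKey(0)
    for mean, cov, sample_shape, expected in CASES:
        # Batch shape comes from the leading axes of mean and cov.
        batch_shape = jnp.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
        draws = jax.random.multivariate_normal(
            key, mean, cov, shape=tuple(sample_shape) + tuple(batch_shape)
        )
        assert draws.shape == expected, (draws.shape, expected)
        # logpdf reduces the event axis and keeps sample and batch axes.
        logp = multivariate_normal.logpdf(draws, mean, cov)
        assert logp.shape == expected[:-1], (logp.shape, expected[:-1])
```

Under pytest the function is collected automatically. It can also be called directly.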

Batching with jax.scipy.stats.multivariate_normal.logpdf gives infs. Not sure if I'm doing something stupid. I raised this.
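One way to avoid depending on a library function's own broadcasting is to keep the unbatched call and let `jax.vmap` add the batch axis. This is a sketch of that general technique, not the patch applied in this PR. The thread does not show what the fix was, and the helper name `batched_logpdf` is hypothetical.

The finiteness assertion at the end is the kind of value-level check that a shape-only test would not perform.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def batched_logpdf(x, means, cov):
    """Hypothetical helper: map the unbatched logpdf over a leading batch axis.

    x:     (batch, k) points, one per batch member
    means: (batch, k) means
    cov:   (k, k) covariance shared by every batch member
    """
    # in_axes=(0, 0, None): map over rows of x and means, reuse cov unchanged.
    return jax.vmap(multivariate_normal.logpdf, in_axes=(0, 0, None))(x, means, cov)


means = jnp.array([[1.0, 2.0], [3.0, 2.0]])
cov = jnp.array([[1.0, 0.2], [0.2, 1.0]])
x = jnp.array([[1.0, 2.0], [2.0, 1.0]])

logp = batched_logpdf(x, means, cov)
assert logp.shape == (2,)
# Catches infs or NaNs that a shape-only test would let through.
assert bool(jnp.all(jnp.isfinite(logp)))
```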

rlouf commented 3 years ago

Patched, now the tests pass!

sidravi1 commented 3 years ago

Made the fixes!

Re: the jax bug. The tests are not affected since they are only checking shape. The jax bug leads to inf being returned but the shape is correct.

Btw... Do you want me to squash all my merges so the commit history doesn't get messy?

rlouf commented 3 years ago

> Re: the jax bug. The tests are not affected since they are only checking shape. The jax bug leads to inf being returned but the shape is correct.

Ok, hopefully it will be fixed before the first release.

> Btw... Do you want me to squash all my merges so the commit history doesn't get messy?

Yes please!

rlouf commented 3 years ago

Looks good to me, merging. Great job!

rlouf commented 3 years ago

Addressed one item in #65