jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add directional distributions #11737

Status: Open. carlosgmartin opened this issue 2 years ago.

carlosgmartin commented 2 years ago

Add the following directional distributions to jax.random (for sampling) and/or jax.scipy.stats (for probability density function evaluation):

huangziwei commented 1 year ago

jax.scipy.stats.vonmises currently (jax + jaxlib == 0.4.2) doesn't accept loc and scale arguments. Is this intentional?
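Until the API takes these arguments, SciPy-style loc and scale can be recovered with the standard location-scale transform f(x; kappa, loc, scale) = f((x - loc) / scale; kappa) / scale. The sketch below is a minimal workaround, assuming jax.scipy.stats.vonmises.logpdf has the (x, kappa) signature described above. The helpers vonmises_logpdf and vonmises_pdf are hypothetical wrappers written for illustration, not part of jax.

```python
import jax.numpy as jnp
from jax.scipy.stats import vonmises


# Hypothetical helper (not part of jax): adds SciPy-style `loc` and `scale`
# on top of jax.scipy.stats.vonmises.logpdf, which only takes (x, kappa).
def vonmises_logpdf(x, kappa, loc=0.0, scale=1.0):
    # Location-scale transform: log f((x - loc) / scale; kappa) - log(scale)
    z = (jnp.asarray(x) - loc) / scale
    return vonmises.logpdf(z, kappa) - jnp.log(scale)


# Hypothetical helper: density obtained by exponentiating the log-density.
def vonmises_pdf(x, kappa, loc=0.0, scale=1.0):
    return jnp.exp(vonmises_logpdf(x, kappa, loc, scale))
```

For example, vonmises_pdf(x, 2.0, loc=0.5) evaluates a von Mises density with concentration 2 centred at the mean direction 0.5. For circular data usually only loc is meaningful: a scale other than 1 breaks the 2π periodicity of the density.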