jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

random.gnormal - Generalized normal distribution #10875

Status: Closed (carlosgmartin closed 2 years ago)

carlosgmartin commented 2 years ago

Add a function that samples from the generalized normal distribution. (See scipy.stats.gennorm.) Example implementation:

from jax import random, numpy as jnp
from jax.lax import lgamma
from jax.scipy.special import gammainc
from scipy.stats import kstest

def gnormal(key, p, shape=(), dtype=None):
  """Sample from the generalized normal distribution.

  Args:
    key: a PRNG key used as the random key.
    p: a float representing the shape parameter.
    shape: a sequence of nonnegative integers representing the result shape.
    dtype: a float dtype representing the result dtype.

  Returns:
    A random array with the specified shape and dtype.
  """
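  # Use independent keys for the magnitude and the sign.
  keys = random.split(key)
  # If G ~ Gamma(1/p, 1), then G ** (1/p) has density proportional to
  # exp(-x ** p) on x > 0, i.e. the positive half of the target density.
  g = random.gamma(keys[0], 1/p, shape, dtype)
  # A random sign in {-1, +1} symmetrizes the distribution around zero.
  r = random.rademacher(keys[1], shape, dtype)
  return r * g ** (1 / p)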

def gamma(x):
  return jnp.exp(lgamma(x))

def gnormal_pdf(x, p):
  return p / (2 * gamma(1 / p)) * jnp.exp(-jnp.abs(x) ** p)

def gnormal_cdf(x, p):
  return .5 + jnp.sign(x) * gammainc(1 / p, jnp.abs(x) ** p) / 2

def test_gnormal():
  key = random.PRNGKey(0)
  for p in [.5, 1., 1.5, 2., 2.5]:
    for shape in [(), (5,), (10, 5)]:
      for dtype in ['float16', 'float32']:
        rvs = gnormal(key, p, shape, dtype)
        assert rvs.shape == shape
        assert rvs.dtype == dtype

    # Goodness of fit: Kolmogorov-Smirnov test of 10**5 samples against the analytic CDF.
    rvs = gnormal(key, p, [10**5])
    result = kstest(rvs, gnormal_cdf, (p,))
    assert result.pvalue > .01
    print(result)

if __name__ == '__main__':
  test_gnormal()

Output:

KstestResult(statistic=0.0033879876, pvalue=0.20072515586074036)
KstestResult(statistic=0.0032510757, pvalue=0.24058773176583947)
KstestResult(statistic=0.0031290352, pvalue=0.28085765527744644)
KstestResult(statistic=0.003481686, pvalue=0.17652514767527894)
KstestResult(statistic=0.0032874048, pvalue=0.22947315482103625)
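
As a complementary sanity check on the sampling trick above, the sample variance can be compared with the closed-form value Gamma(3/p) / Gamma(1/p), which follows from E[X^2] = E[G^(2/p)] for G ~ Gamma(1/p, 1). This is only a rough sketch: the names `sample_gnormal` and `gnormal_variance` and the sample size are illustrative assumptions, not part of the proposal.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.special import gammaln


def sample_gnormal(key, p, shape=()):
    # Hypothetical helper mirroring the gnormal sketch above.
    key_mag, key_sign = random.split(key)
    # Magnitude: G ** (1/p) with G ~ Gamma(1/p, 1).
    magnitude = random.gamma(key_mag, 1.0 / p, shape) ** (1.0 / p)
    # Sign: +1 or -1 with equal probability (integer dtype by default).
    sign = random.rademacher(key_sign, shape)
    return sign * magnitude


def gnormal_variance(p):
    # Var[X] = Gamma(3/p) / Gamma(1/p) for the density
    # p / (2 * Gamma(1/p)) * exp(-|x| ** p).
    return jnp.exp(gammaln(3.0 / p) - gammaln(1.0 / p))


key = random.PRNGKey(0)
for p in [1.0, 2.0]:
    samples = sample_gnormal(key, p, (100_000,))
    # Print the empirical variance next to the analytic one.
    print(p, float(jnp.var(samples)), float(gnormal_variance(p)))
```

For p = 2 the density is proportional to exp(-x^2), so the analytic variance is 1/2; for p = 1 it is the Laplace distribution with variance 2. The empirical values should land close to these, up to sampling noise.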

jakevdp commented 2 years ago

Thanks - I think we would be happy to accept a contribution of this function to jax.random. Are you interested in putting together a pull request?

carlosgmartin commented 2 years ago

@jakevdp Sure. Since #10876 depends on this one, should I submit a PR for this one first, or one PR for both simultaneously?

jakevdp commented 2 years ago

You can do a single PR for both, if that's easier.

carlosgmartin commented 2 years ago

@jakevdp Do you think jax.scipy.stats.gennorm.pdf/logpdf/cdf should be added as well?

jakevdp commented 2 years ago

Sure, I think that would be a useful contribution - thanks!
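
For reference, the pdf/logpdf/cdf discussed here could look roughly like the following, using the unit-scale parametrization from the example implementation at the top of the thread. This is a minimal sketch, not the code that was merged: the function names `gennorm_logpdf`, `gennorm_pdf`, and `gennorm_cdf` are assumptions, and location/scale parameters are omitted.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, gammainc


def gennorm_logpdf(x, p):
    # log of p / (2 * Gamma(1/p)) * exp(-|x| ** p), computed in log space
    # to avoid overflow in Gamma(1/p) for small p.
    return jnp.log(p) - jnp.log(2.0) - gammaln(1.0 / p) - jnp.abs(x) ** p


def gennorm_pdf(x, p):
    return jnp.exp(gennorm_logpdf(x, p))


def gennorm_cdf(x, p):
    # gammainc is the regularized lower incomplete gamma function.
    return 0.5 + 0.5 * jnp.sign(x) * gammainc(1.0 / p, jnp.abs(x) ** p)


x = jnp.linspace(-3.0, 3.0, 7)
print(gennorm_pdf(x, 2.0))   # density proportional to exp(-x ** 2)
print(gennorm_cdf(x, 1.0))   # Laplace CDF with unit scale
```

Working in log space and exponentiating only for the pdf keeps the logpdf usable as a loss term, which is the usual reason for exposing both.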