aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Possible small error in `GenGammaRV`: `"rng_state"` instead of `"jax_state"` #1472

Closed PaulScemama closed 1 year ago

PaulScemama commented 1 year ago

My knowledge of JAX internals is less than stellar at the moment, so I haven't dug deep, but this seems like a transparent bug.

I noticed that in `aesara/link/jax/dispatch/random.py` the `GenGammaRV` implementation is:

```python
@jax_sample_fn.register(aer.GenGammaRV)
def jax_sample_fn_gengamma(op):
    r"""Provide a JAX implementation of `GenGammaRV`.

    Samples are obtained from inverse sampling using the following:

    .. math::

        F^{-1}(q; a, d, p) = a \left( G^{-1}(q) \right)^{1/p}

    where :math:`G` is the CDF of a gamma distribution with
    :math:`\alpha = d/p` and :math:`\beta = 1`.

    .. note::

        Here we use the parametrization :math:`\alpha = d/p`.

    """

    def sample_fn(rng, size, dtype, *parameters):
        rng_key = rng["jax_state"]
        rng_key, sampling_key = jax.random.split(rng_key, 2)

        alpha, lam, p = parameters
        d = alpha / p
        samples = jax.random.gamma(sampling_key, d, size, dtype)
        samples = lam * samples ** (1.0 / p)

        rng["rng_state"] = rng_key
        return (rng, samples)

    return sample_fn
```
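The docstring's inverse-sampling argument can be sanity-checked numerically. Because `y -> lam * y ** (1 / p)` is increasing, applying it to `Y ~ Gamma(alpha / p, 1)` gives the generalized gamma, and `E[Y**s] = Gamma(k + s) / Gamma(k)` for `Y ~ Gamma(k, 1)` yields the mean `lam * Gamma((alpha + 1) / p) / Gamma(alpha / p)`. The sketch below assumes JAX is installed and that the code's `alpha` and `lam` play the roles of the docstring's `d` and `a`; `gengamma_via_gamma` and `gengamma_mean` are hypothetical helpers, not Aesara API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def gengamma_via_gamma(key, alpha, lam, p, shape):
    # Same transform as the quoted `sample_fn` (hypothetical helper):
    # draw Y ~ Gamma(alpha / p, 1), then map it through y -> lam * y ** (1 / p).
    y = jax.random.gamma(key, alpha / p, shape)
    return lam * y ** (1.0 / p)


def gengamma_mean(alpha, lam, p):
    # Closed-form mean: lam * Gamma((alpha + 1) / p) / Gamma(alpha / p),
    # computed in log space for numerical stability.
    return lam * jnp.exp(gammaln((alpha + 1.0) / p) - gammaln(alpha / p))


alpha, lam, p = 2.0, 1.5, 3.0
samples = gengamma_via_gamma(jax.random.PRNGKey(42), alpha, lam, p, (200_000,))
# The two numbers should agree up to Monte Carlo error.
print(float(samples.mean()), float(gengamma_mean(alpha, lam, p)))
```

With 200,000 draws the sample mean should land well within 1% of the closed form for these parameters.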

Notice that at the end of `sample_fn` we have `rng["rng_state"] = rng_key`. Every other implementation instead has `rng["jax_state"] = rng_key`. I just wanted to point it out; it's a very, very tiny error!
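A minimal sketch of why the entry name matters, assuming JAX is installed. `gengamma_sample_fn` and `two_draws` are hypothetical helpers that reproduce the quoted `sample_fn` with the name of the written-back entry as a parameter. With `"rng_state"`, `rng["jax_state"]` is never advanced, so threading the returned `rng` into a second call splits the same key again and returns identical samples; with `"jax_state"`, the second call starts from the freshly split key.

```python
import jax
import jax.numpy as jnp


def gengamma_sample_fn(write_key):
    """Hypothetical copy of the quoted `sample_fn`.

    `write_key` is the dict entry the updated key is stored under:
    "rng_state" reproduces the reported bug, "jax_state" is the fix.
    """

    def sample_fn(rng, size, dtype, *parameters):
        rng_key = rng["jax_state"]
        rng_key, sampling_key = jax.random.split(rng_key, 2)

        alpha, lam, p = parameters
        d = alpha / p
        samples = jax.random.gamma(sampling_key, d, size, dtype)
        samples = lam * samples ** (1.0 / p)

        # Mutates the dict in place, as the original does.
        rng[write_key] = rng_key
        return rng, samples

    return sample_fn


def two_draws(write_key):
    # Thread the returned state through two successive calls.
    sample_fn = gengamma_sample_fn(write_key)
    rng = {"jax_state": jax.random.PRNGKey(0)}
    rng, first = sample_fn(rng, (3,), jnp.float32, 2.0, 1.0, 1.5)
    rng, second = sample_fn(rng, (3,), jnp.float32, 2.0, 1.0, 1.5)
    return first, second


buggy_first, buggy_second = two_draws("rng_state")
fixed_first, fixed_second = two_draws("jax_state")
print(bool(jnp.array_equal(buggy_first, buggy_second)))  # True: "jax_state" never advances
print(bool(jnp.array_equal(fixed_first, fixed_second)))  # False: second draw uses a new key
```

This only models the dict-threading inside `sample_fn`; how the stale key would then propagate depends on how Aesara feeds the returned state back into its shared RNG variables, which the sketch does not reproduce.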

brandonwillard commented 1 year ago

Thanks for pointing that out!