My JAX internals knowledge is less than stellar at the moment, so I haven't dug deep, but this looks like a transparent bug.
I noticed that in `aesara/link/jax/dispatch/random.py` the `GenGammaRV` implementation is:
```python
@jax_sample_fn.register(aer.GenGammaRV)
def jax_sample_fn_gengamma(op):
    r"""Provide a JAX implementation of `GenGammaRV`.

    Samples are obtained from inverse sampling using the following:

    .. math::

        F^{-1}(q; a, d, p) = a \left( G^{-1}(q) \right)^{1/p}

    where :math:`G` is the CDF of a gamma distribution with
    :math:`\alpha = d/p` and :math:`\beta = 1`.

    .. note::

        Here we use the parametrization :math:`\alpha = d/p`.

    """

    def sample_fn(rng, size, dtype, *parameters):
        rng_key = rng["jax_state"]
        rng_key, sampling_key = jax.random.split(rng_key, 2)
        alpha, lam, p = parameters
        d = alpha / p
        samples = jax.random.gamma(sampling_key, d, size, dtype)
        samples = lam * samples ** (1.0 / p)
        rng["rng_state"] = rng_key
        return (rng, samples)

    return sample_fn
```
Notice that at the end we have `rng["rng_state"] = rng_key`, whereas every other implementation in that file uses `rng["jax_state"] = rng_key` instead. I just wanted to point it out; it's a very, very tiny error!
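To illustrate why the key name might matter, here is a minimal standalone sketch that copies the body of `sample_fn` outside Aesara. `make_sample_fn` and `two_draws` are hypothetical helpers written only for this illustration, and the sketch assumes the returned `rng` dict is fed back into the next call (my understanding of how the state is threaded, not something I've verified in Aesara). Under that assumption, writing the new key to `"rng_state"` leaves `rng["jax_state"]` untouched, so every call splits the same key and returns the same draws.

```python
import jax
import jax.numpy as jnp


def make_sample_fn(state_key_to_write):
    """Hypothetical stand-in for the GenGamma sampler; `state_key_to_write`
    is the dict key that the advanced PRNG key is stored under."""

    def sample_fn(rng, size, dtype, alpha, lam, p):
        rng_key = rng["jax_state"]
        rng_key, sampling_key = jax.random.split(rng_key, 2)
        d = alpha / p
        samples = jax.random.gamma(sampling_key, d, size, dtype)
        samples = lam * samples ** (1.0 / p)
        # "rng_state" mirrors the current code; "jax_state" mirrors the other samplers.
        rng[state_key_to_write] = rng_key
        return rng, samples

    return sample_fn


def two_draws(state_key_to_write):
    """Call the sampler twice, feeding the returned state back in (assumption)."""
    sample_fn = make_sample_fn(state_key_to_write)
    rng = {"jax_state": jax.random.PRNGKey(0)}
    rng, first = sample_fn(rng, (3,), jnp.float32, 2.0, 1.0, 1.5)
    rng, second = sample_fn(rng, (3,), jnp.float32, 2.0, 1.0, 1.5)
    return first, second


buggy_first, buggy_second = two_draws("rng_state")
fixed_first, fixed_second = two_draws("jax_state")

# With "rng_state" the key never advances, so both draws are identical.
print(bool(jnp.array_equal(buggy_first, buggy_second)))
# With "jax_state" the key advances, so the draws differ.
print(bool(jnp.array_equal(fixed_first, fixed_second)))
```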