AFg6K7h4fhy2 closed this 1 week ago.
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 94.76% (comparing base 657a206 to head dd5de7e).
In `test_hospitalizations.py` (and other test files), would people like me to also change the following behavior (example), or is it fine as is?
```python
model1.run(
    num_warmup=500,
    num_samples=500,
    rng_key=jax.random.PRNGKey(272),
    observed_hosp_admissions=model1_samp.sampled_observed_hosp_admissions,
)
```
I grandfathered in the type annotation `jax.random.PRNGKey` in:
```python
def run(
    self,
    num_warmup,
    num_samples,
    rng_key: jax.random.PRNGKey | None = None,
    nuts_args: dict = None,
    mcmc_args: dict = None,
    **kwargs,
) -> None:
    ...
```
However, the tests fail here because of `jax.random.PRNGKey`, which is a function rather than a type. Inspection suggests that `ArrayLike` (`from jax.typing import ArrayLike`) is the correct type to use:
```python
>>> import jax.random as jr
>>> rand_int_key = jr.PRNGKey(0)
>>> rand_int = jr.randint(
...     rand_int_key, shape=(), minval=0, maxval=100000
... )
>>> rng_key = jr.PRNGKey(rand_int)
>>> type(rng_key)
<class 'jaxlib.xla_extension.ArrayImpl'>
```
It might be a good time to switch to the new style of typed random keys: https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html#jep-9263
Some links:
For the following, `ArrayLike` still works, but (see last link) `KeyArray` might be implemented as a type soon.
```python
>>> import jax.random as jr
>>> import numpy as np
>>> rand_int = np.random.randint(0, 100000)
>>> rng_key = jr.key(rand_int)
>>> type(rng_key)
<class 'jax._src.prng.PRNGKeyArray'>
```
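For comparison, a short sketch of how new-style typed keys behave next to legacy keys (assuming a JAX version that provides `jax.random.key`, `jax.random.key_data`, `jax.random.wrap_key_data`, and `jax.dtypes.prng_key`). Typed keys are scalars with a dedicated key dtype, are split and consumed the same way, and can be converted to and from the raw `uint32` form during a migration.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

typed_key = jr.key(0)       # new-style typed key (JEP 9263)
legacy_key = jr.PRNGKey(0)  # legacy raw uint32 key

print(typed_key.shape, legacy_key.shape)  # () (2,)
print(jnp.issubdtype(typed_key.dtype, jax.dtypes.prng_key))  # True

# Typed keys are split and consumed exactly like legacy keys.
key, subkey = jr.split(typed_key)
sample = jr.normal(subkey, shape=(3,))

# Converting between the two styles, e.g. while migrating call sites.
raw = jr.key_data(typed_key)       # uint32 array of shape (2,)
rewrapped = jr.wrap_key_data(raw)  # typed key again
print(jnp.array_equal(raw, legacy_key))  # True
```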
NOTE: At some point, having `rng_key: ArrayLike | None = None` might be problematic. While `jax.random.key()` and `jax.random.PRNGKey()` both return `ArrayLike` values (for the time being; a `KeyArray` type in `jax.typing` might be added soon), the value actually has to come from `jax.random.key()` or `jax.random.PRNGKey()`, not from a typical JAX array such as `jax.numpy.array([2,3,1])`, which the `ArrayLike` annotation also accepts.
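Since the annotation alone does not reject arbitrary arrays, one option is a runtime check. The sketch below uses a hypothetical helper, `looks_like_prng_key`, which is a heuristic, not a JAX API: it accepts typed keys by their key dtype and legacy keys by being `uint32` with a trailing dimension of 2 (assuming the default threefry2x32 implementation). A hand-built `uint32` array of shape `(2,)` would still pass.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.typing import ArrayLike


def looks_like_prng_key(rng_key: ArrayLike) -> bool:
    """Hypothetical heuristic: does rng_key look like a usable PRNG key?"""
    dtype = getattr(rng_key, "dtype", None)
    # New-style typed keys (jax.random.key) carry a dedicated key dtype.
    if dtype is not None and jnp.issubdtype(dtype, jax.dtypes.prng_key):
        return True
    # Legacy raw keys (jax.random.PRNGKey) are uint32 arrays whose last
    # dimension is 2, assuming the default threefry2x32 implementation.
    arr = jnp.asarray(rng_key)
    return bool(
        jnp.issubdtype(arr.dtype, jnp.uint32) and arr.ndim >= 1 and arr.shape[-1] == 2
    )


print(looks_like_prng_key(jr.PRNGKey(0)))         # True
print(looks_like_prng_key(jr.key(0)))             # True
print(looks_like_prng_key(jnp.array([2, 3, 1])))  # False
```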
Can we add a test that checks that two models with no rng key do not produce the same results?
In `metaclass.py`, the MCMC run requires a pseudo-random key (`jax.random.PRNGKey()`). Without this PR, if the user does not specify their own PRNGKey, MSR instantiates its own PRNGKey using a magic number (54). With this PR, when the `rng_key` argument is `None`, a random integer between 0 and 100000 is generated and used to seed the PRNGKey.