CDCgov / multisignal-epi-inference

Python package for statistical inference and forecast of epi models using multiple signals
https://cdcgov.github.io/multisignal-epi-inference/

Pseudo-Random Number Generation For None Input For JAX PRNGKey In Metaclass #192

Closed: AFg6K7h4fhy2 closed this 1 week ago

AFg6K7h4fhy2 commented 2 weeks ago

In metaclass.py, an MCMC run requires a pseudo-random key (jax.random.PRNGKey()). Without this PR, if the user does not supply their own PRNGKey, MSR creates one from a hard-coded magic number (54). With this PR, when the rng_key argument is None, a random integer between 0 and 100000 is generated and used to seed the PRNGKey.
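
A minimal sketch of the None-handling described above, assuming the seed is drawn host-side with NumPy (as in a later comment in this thread). The helper name resolve_rng_key is hypothetical; this is not the actual metaclass.py code.

```python
import jax
import numpy as np


def resolve_rng_key(rng_key=None):
    """Hypothetical helper: return rng_key unchanged, or build a fresh
    PRNGKey from a random integer when rng_key is None."""
    if rng_key is None:
        # Host-side draw from NumPy's global generator; randint's upper
        # bound is exclusive, so seeds fall in 0..99999.
        seed = np.random.randint(0, 100000)
        rng_key = jax.random.PRNGKey(seed)
    return rng_key


explicit = resolve_rng_key(jax.random.PRNGKey(54))  # user-supplied key is kept
implicit = resolve_rng_key()                        # fresh key for this call
```

Drawing the seed with jax.random.randint from a fixed key such as jr.PRNGKey(0) would return the same integer on every call, which would defeat the purpose; a host-side source such as NumPy avoids that.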

codecov[bot] commented 2 weeks ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 94.76%. Comparing base (657a206) to head (dd5de7e).

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #192      +/-   ##
==========================================
+ Coverage   94.73%   94.76%   +0.02%
==========================================
  Files          40       40
  Lines         893      898       +5
==========================================
+ Hits          846      851       +5
  Misses         47       47
```

| Flag | Coverage Δ | |
|---|---|---|
| unittests | `94.76% <100.00%> (+0.02%)` | :arrow_up: |

Flags with carried forward coverage won't be shown.


AFg6K7h4fhy2 commented 2 weeks ago

In test_hospitalizations.py (and other test files), should I also change the following behavior (example), or is it fine as is?

```python
model1.run(
    num_warmup=500,
    num_samples=500,
    rng_key=jax.random.PRNGKey(272),
    observed_hosp_admissions=model1_samp.sampled_observed_hosp_admissions,
)
```

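A short sketch of why an explicit key can still be useful in tests after this PR: JAX sampling is a pure function of the key, so the same key always reproduces the same draws. This uses plain jax.random calls only and does not touch the MSR model API.

```python
import jax
import jax.numpy as jnp

# Same key in, same draws out: this is what lets a fixed key such as
# PRNGKey(272) pin down results in a test.
draws_a = jax.random.normal(jax.random.PRNGKey(272), shape=(5,))
draws_b = jax.random.normal(jax.random.PRNGKey(272), shape=(5,))

# A different key gives different draws.
draws_c = jax.random.normal(jax.random.PRNGKey(273), shape=(5,))
```
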
AFg6K7h4fhy2 commented 2 weeks ago

I grandfathered in the type annotation jax.random.PRNGKey in the signature of run:

```python
def run(
    self,
    num_warmup,
    num_samples,
    rng_key: jax.random.PRNGKey | None = None,
    nuts_args: dict = None,
    mcmc_args: dict = None,
    **kwargs,
) -> None:
    ...  # body omitted
```

However, the tests fail here because jax.random.PRNGKey is a function rather than a type, so jax.random.PRNGKey | None cannot be evaluated as an annotation. Inspection suggests that ArrayLike (from jax.typing import ArrayLike) is the correct type to use:

```python
>>> import jax.random as jr
>>> rand_int_key = jr.PRNGKey(0)
>>> rand_int = jr.randint(
...     rand_int_key, shape=(), minval=0, maxval=100000
... )
>>> rng_key = jr.PRNGKey(rand_int)
>>> type(rng_key)
<class 'jaxlib.xla_extension.ArrayImpl'>
```
damonbayer commented 2 weeks ago

Might be a good time to switch to the new style of random keys https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html#jep-9263
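
A sketch contrasting the two key styles, assuming a JAX release recent enough to include the typed-key API from JEP 9263 (jax.random.key, jax.random.key_data, jax.random.wrap_key_data, jax.dtypes.prng_key) and the default threefry implementation.

```python
import jax
import jax.numpy as jnp

# New-style (typed) key: a scalar array with a special key dtype.
typed_key = jax.random.key(0)
# Legacy (raw) key: a uint32 array of shape (2,) under threefry.
raw_key = jax.random.PRNGKey(0)

# Typed keys work with the usual splitting and sampling functions.
typed_key, subkey = jax.random.split(typed_key)
sample = jax.random.uniform(subkey, shape=(3,))

# Typed keys are recognizable by dtype; raw keys are plain uint32 arrays.
is_typed = jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
raw_is_typed = jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)

# Converting between the two representations.
as_raw = jax.random.key_data(jax.random.key(0))
as_typed = jax.random.wrap_key_data(raw_key)
```
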

AFg6K7h4fhy2 commented 2 weeks ago

Some links:

For the following, ArrayLike still works, but (see the last link) a dedicated KeyArray type might be added soon.

```python
>>> import jax.random as jr
>>> import numpy as np
>>> rand_int = np.random.randint(0, 100000)
>>> rng_key = jr.key(rand_int)
>>> type(rng_key)
<class 'jax._src.prng.PRNGKeyArray'>
```
AFg6K7h4fhy2 commented 2 weeks ago

NOTE: At some point, annotating rng_key: ArrayLike | None = None might become problematic. Although jax.random.key() and jax.random.PRNGKey() both currently return ArrayLike values (a KeyArray type might soon be added to jax.typing), the value passed must come from jax.random.key() or jax.random.PRNGKey(), not be an ordinary JAX array such as jax.numpy.array([2, 3, 1]).
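
Because ArrayLike alone cannot tell a key from an ordinary array, a runtime check is one option. The helper is_prng_key below is hypothetical; it assumes the typed-key dtype API and the default threefry implementation (raw keys are uint32 arrays of shape (2,)).

```python
import jax
import jax.numpy as jnp


def is_prng_key(x) -> bool:
    """Hypothetical check: accept typed keys or legacy raw threefry keys."""
    dtype = getattr(x, "dtype", None)
    if dtype is None:
        return False  # e.g. None or a Python int
    if jax.dtypes.issubdtype(dtype, jax.dtypes.prng_key):
        return True  # new-style key from jax.random.key()
    # Legacy PRNGKey(): only shape and dtype are checkable.
    return isinstance(x, jax.Array) and dtype == jnp.uint32 and x.shape == (2,)
```

The legacy branch is necessarily a heuristic: any uint32 array of shape (2,) passes it, which is one argument for moving to typed keys, where the dtype itself identifies a key.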

damonbayer commented 2 weeks ago

Can we add a test that checks that two models with no rng key do not produce the same results?
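
A sketch of such a test against a hypothetical stand-in (run_without_key), since the real model API is not shown here; only its default-key logic mirrors the PR, and the "model" is a single normal draw.

```python
import jax
import numpy as np


def run_without_key(rng_key=None, num_samples=4):
    """Hypothetical stand-in for a model run with the PR's default-key logic."""
    if rng_key is None:
        rng_key = jax.random.PRNGKey(np.random.randint(0, 100000))
    return jax.random.normal(rng_key, shape=(num_samples,))


def test_default_keys_differ():
    # Two key-less runs collide essentially only if both draw the same
    # seed (probability 1e-5 per pair), so comparing several runs keeps
    # the test from being flaky.
    results = [tuple(np.asarray(run_without_key()).tolist()) for _ in range(5)]
    assert len(set(results)) > 1
```

A strict two-run comparison would fail roughly once in 100000 runs; requiring only that not all five results match makes a spurious failure vanishingly unlikely.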