CDCgov / multisignal-epi-inference

Python package for statistical inference and forecast of epi models using multiple signals
https://cdcgov.github.io/multisignal-epi-inference/

Demonstrate use of predictive distributions in tutorials using ArviZ #221

Status: Open. damonbayer opened this issue 3 days ago.

damonbayer commented 3 days ago

Goal

Our use of predictive distributions should be compatible with ArviZ, so we can easily make compelling plots using plot_lm, plot_ppc, plot_ts, and/or plot_hdi.
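As a rough sketch of what "compatible with ArviZ" could mean in practice: if the model can hand back posterior predictive draws as an array of shape (chain, draw, time), the draws and the observations can be wrapped in a single `InferenceData` object with `az.from_dict`. The array sizes, the variable name `observed_hosp`, and the dimension name `time` below are hypothetical placeholders, not names used by the package.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(0)
n_chains, n_draws, n_time = 2, 100, 30  # hypothetical sizes

# Hypothetical posterior predictive draws with shape (chain, draw, time)
pp_draws = rng.poisson(lam=20, size=(n_chains, n_draws, n_time))
# Hypothetical observed series with shape (time,)
observed = rng.poisson(lam=20, size=n_time)

idata = az.from_dict(
    posterior_predictive={"observed_hosp": pp_draws},
    observed_data={"observed_hosp": observed},
    coords={"time": np.arange(n_time)},
    # The same named dimension is attached in both groups
    dims={"observed_hosp": ["time"]},
)

print(idata.posterior_predictive["observed_hosp"].dims)
print(idata.observed_data["observed_hosp"].dims)
```

Because both groups share the `time` coordinate, this is the layout that functions such as `az.plot_ppc` and `az.plot_hdi` read from, and no manual index offsets are needed to line draws up with observations.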

Example output from plot_lm: [image not reproduced]

However, we don't need the uncertainty in the mean, and the colors should be more distinct from each other.

I would like something closer to this:

[image: posterior_predictive_ifr_age_structure_plot]
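A minimal ArviZ sketch of that look, assuming several series are plotted on one axis: `az.plot_hdi` draws only the posterior predictive band, the mean is a plain line with no band of its own, and each series gets a different default-cycle color (`C0`, `C1`, ...). The three strata, their levels, and the Poisson draws are hypothetical stand-ins for model output.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import arviz as az
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
n_chains, n_draws, n_time = 2, 200, 40
time = np.arange(n_time)
# Hypothetical strata, each with its own expected level
levels = {"group A": 10.0, "group B": 25.0, "group C": 50.0}

fig, ax = plt.subplots(figsize=(6, 4))
for i, (name, level) in enumerate(levels.items()):
    color = f"C{i}"  # distinct color from Matplotlib's default cycle
    # Hypothetical posterior predictive draws, shape (chain, draw, time)
    draws = rng.poisson(lam=level, size=(n_chains, n_draws, n_time))
    # Shaded 90% HDI band of the posterior predictive
    az.plot_hdi(
        time,
        draws,
        hdi_prob=0.9,
        smooth=False,
        color=color,
        fill_kwargs={"alpha": 0.3},
        ax=ax,
    )
    # Mean pooled over chains and draws, drawn as a single line
    ax.plot(time, draws.mean(axis=(0, 1)), color=color, label=name)

ax.set_xlabel("Time")
ax.set_ylabel("Count")
ax.legend()
```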

If using ArviZ for this is too hacky, I would be open to creating our own solution.
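If ArviZ turns out to be a poor fit, one possible do-it-yourself version needs only NumPy and Matplotlib. This sketch swaps ArviZ's HDI for equal-tailed quantile intervals from `np.quantile`. The helper names `interval` and `plot_predictive` are hypothetical, and the draws are assumed to be already flattened over chains into shape (n_samples, n_time).

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np


def interval(draws, prob):
    """Equal-tailed `prob` interval of `draws` (n_samples, n_time) over samples."""
    lower, upper = np.quantile(draws, [(1 - prob) / 2, (1 + prob) / 2], axis=0)
    return lower, upper


def plot_predictive(ax, time, draws, observed=None, color="C0"):
    """Hypothetical helper: nested interval ribbons, a mean line, observations."""
    # Widest band first and lightest, so the narrower band is drawn on top
    for prob, alpha in ((0.9, 0.3), (0.5, 0.6)):
        lower, upper = interval(draws, prob)
        ax.fill_between(
            time, lower, upper, color=color, alpha=alpha, label=f"{prob:.0%} interval"
        )
    # Mean of the predictive draws at each time point (no band around it)
    ax.plot(time, draws.mean(axis=0), color=color, label="Mean")
    if observed is not None:
        ax.scatter(time, observed, color="black", s=10, label="Observed")
    ax.legend()
    return ax


# Usage with hypothetical Poisson draws standing in for model output
rng = np.random.default_rng(2)
time = np.arange(30)
draws = rng.poisson(lam=20, size=(400, time.size))
observed = rng.poisson(lam=20, size=time.size)

fig, ax = plt.subplots(figsize=(6, 5))
plot_predictive(ax, time, draws, observed)
ax.set_xlabel("Time")
ax.set_ylabel("Count")
```

Equal-tailed intervals are simpler to compute than HDIs, but for skewed count distributions the two can differ noticeably, so the choice should be stated in the figure caption.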

damonbayer commented 3 days ago

Relevant comment from @sbidari: https://github.com/CDCgov/multisignal-epi-inference/pull/218#issuecomment-2195504780

Potential starter code from an earlier version of https://github.com/CDCgov/multisignal-epi-inference/pull/218/:

# | label: fig-output-infections-distribution
# | fig-cap: Posterior Latent Infections
import arviz as az
import matplotlib.pyplot as plt

# `hosp_model_weekday`, `gen_int_array`, and `days_to_impute` are defined
# earlier in the tutorial this snippet was taken from.
idata_weekday = az.from_numpyro(
    hosp_model_weekday.mcmc,
    posterior_predictive=hosp_model_weekday.posterior_predictive(
        n_timepoints_to_simulate=1
    ),
)
pp = idata_weekday.posterior_predictive

fig, axes = plt.subplots(figsize=(6, 5))
# Nested HDI bands: a lighter 90% band with a darker 50% band on top
for hdi_prob, alpha in ((0.9, 0.3), (0.5, 0.6)):
    az.plot_hdi(
        pp["negbinom_rv_dim_0"],
        pp["negbinom_rv"],
        hdi_prob=hdi_prob,
        color="C0",
        smooth=False,
        fill_kwargs={"alpha": alpha, "label": f"{hdi_prob:.0%} HDI"},
        ax=axes,
    )
# Add the posterior predictive mean, pooled over all chains and draws
mean_latent_infection = pp["negbinom_rv"].mean(dim=("chain", "draw"))
axes.plot(pp["negbinom_rv_dim_0"], mean_latent_infection, color="C0", label="Mean")
# Observed data, shifted onto the same time axis as the predictive draws
axes.scatter(
    idata_weekday.observed_data["negbinom_rv_dim_0"]
    + gen_int_array.size
    + days_to_impute,
    idata_weekday.observed_data["negbinom_rv"],
    color="black",
    label="Observed",
)
axes.legend()
axes.set_title("Posterior Predictive Infections", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Observed Infections", fontsize=10);