CDCgov / PyRenew

Python package for multi-signal Bayesian renewal modeling with JAX and NumPyro.
https://cdcgov.github.io/PyRenew/
Apache License 2.0

Issues related to using predictive methods multiple times on the same model #282

Closed: damonbayer closed this issue 1 day ago

damonbayer commented 1 month ago

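Calling a predictive method more than once on the same model leads to an error. For example, the following snippet calls `model.prior_predictive` twice with identical arguments (other similar exercises fail the same way):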

```python
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro as npro
import numpyro.distributions as dist
import pyrenew.transformation as t
from numpy.testing import assert_array_equal, assert_raises
from pyrenew.deterministic import DeterministicPMF
from pyrenew.latent import (
    InfectionInitializationProcess,
    Infections,
    InitializeInfectionsZeroPad,
)
from pyrenew.metaclass import DistributionalRV
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import RtRandomWalkProcess

# Generation interval: a fixed, uniform PMF over 4 time steps
pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25])
gen_int = DeterministicPMF(pmf_array, name="gen_int")

# Initial infections: LogNormal(0, 1) prior on I0,
# initialized by InitializeInfectionsZeroPad over gen_int.size() timepoints
I0 = InfectionInitializationProcess(
    "I0_initialization",
    DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
    InitializeInfectionsZeroPad(n_timepoints=gen_int.size()),
    t_unit=1,
)

# Latent infections, observed with Poisson noise
latent_infections = Infections()
observed_infections = PoissonObservation("poisson_rv")

# Rt: random walk on the log scale (Rt_transform is the inverse of exp)
rt = RtRandomWalkProcess(
    Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
    Rt_transform=t.ExpTransform().inv,
    Rt_rw_dist=dist.Normal(0, 0.025),
)

model = RtInfectionsRenewalModel(
    I0_rv=I0,
    gen_int_rv=gen_int,
    latent_infections_rv=latent_infections,
    infection_obs_process_rv=observed_infections,
    Rt_process_rv=rt,
)

n_tp = 30

# First prior predictive call on the model
model.prior_predictive(
    numpyro_predictive_args={"num_samples": 20},
    n_timepoints_to_simulate=n_tp,
)

# Calling it again on the same model raises the error below
model.prior_predictive(
    numpyro_predictive_args={"num_samples": 20},
    n_timepoints_to_simulate=n_tp,
)
```
```
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was <lambda> at /Users/damon/Library/Caches/pypoetry/virtualenvs/pyrenew-GjeTh4Fr-py3.12/lib/python3.12/site-packages/jax/_src/lax/control_flow/loops.py:2111 traced for scan.
```
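The error text says that a value traced inside `jax.lax.scan` was stored outside the traced function and reused after tracing ended. The sketch below is a minimal, generic JAX illustration of that failure mode and of the usual fix, assuming only that JAX is installed. It does not reproduce the PyRenew/NumPyro bug itself, whose actual cause was addressed upstream; the names `leaked`, `leaky_step`, and `pure_step` are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical global list, used only to demonstrate a pattern JAX forbids.
leaked = []


def leaky_step(carry, x):
    new = carry + x
    # Side effect: stash a traced intermediate value in global state.
    leaked.append(new)
    return new, new


# scan traces leaky_step once, so `leaked` now holds a tracer, not a number.
jax.lax.scan(leaky_step, jnp.zeros(()), jnp.arange(3.0))

try:
    # Using the escaped tracer after scan has finished tracing is an error.
    leaked[0] + 1.0
    error_name = None
except jax.errors.UnexpectedTracerError as err:
    error_name = type(err).__name__


def pure_step(carry, x):
    new = carry + x
    # Return everything needed outside scan instead of saving it globally.
    return new, new


# The carry ends at 0 + 1 + 2; the stacked outputs are the running sums.
final, history = jax.lax.scan(pure_step, jnp.zeros(()), jnp.arange(3.0))
print(error_name)
print(final, history)
```

The fix in the sketch follows the rule stated in the error message: anything computed inside a JAX transformation should leave it through the function's return values rather than through global state.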
------------------------------
dylanhmorris commented 1 month ago

This should be resolved by the upstream NumPyro fix in https://github.com/pyro-ppl/numpyro/pull/1843

dylanhmorris commented 1 day ago

This is resolved.