ciemss / program-milestones

Repository for materials related to program milestone hackathon and evaluation events

NovDemo-4: Hierarchical modeling #16

Open sabinala opened 1 week ago

sabinala commented 1 week ago

Scenario 4: Hierarchical Modeling (see here)

In this scenario, we have simulated data from a geographic region. This region is made up of three counties, A, B, and C, and each county consists of two cities. The cities are numbered 1-6. Cities in the same county have similar transmission rates (drawn from the same distribution), but transmission rates vary stochastically over time. Counties do not share similar transmission rates with other counties.
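The nesting described above can be sketched as a small generative draw. This is only an illustration of the structure, not the process used to simulate the dataset: the uniform range, the log-normal spread, and the placement of City 2 in County A and City 5 in County C are assumptions (the questions only place City 1 in A, Cities 3 and 4 in B, and City 6 in C).

```python
import random

rng = random.Random(0)  # fixed seed so the sketch is reproducible

# Assumed city-to-county mapping (Cities 2 and 5 are inferred, not stated).
counties = {"A": [1, 2], "B": [3, 4], "C": [5, 6]}

county_rate = {}
city_beta = {}
for county, cities in counties.items():
    # Each county gets its own typical rate, drawn independently of the others
    # (hypothetical range).
    county_rate[county] = rng.uniform(0.2, 0.8)
    for city in cities:
        # Cities scatter around their county's typical rate
        # (hypothetical multiplicative log-normal noise).
        city_beta[city] = county_rate[county] * rng.lognormvariate(0.0, 0.1)
```

Cities that share a county end up with nearby values of `city_beta`, while values in different counties are unrelated, which is the pattern the hierarchical questions ask you to exploit.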

The simulated dataset has 100 days of data for each city. Each day contains 10 timesteps, so there are 1000 observations for each city.

Screenshot 2024-11-12 at 4 41 16 PM

Epidemiological data can be imperfect, and you will find that to be the case within these counties. Specifically:

The governor of the region is interested in answering a series of questions about the cities and counties. The data is recorded in an SEIR format, which follows these equations:

\begin{align}
\frac{dS}{dt} &= - \beta S \frac{I}{N} \\
\frac{dE}{dt} &= \beta S \frac{I}{N} - \alpha E \\
\frac{dI}{dt} &= \alpha E - \gamma I \\
\frac{dR}{dt} &= \gamma I
\end{align}
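As a minimal sketch, the equations above can be integrated with a forward-Euler step. The parameter values and initial state here are illustrative only and are not taken from S4_data.csv; treating $N$ as $S + E + I + R$ is an assumption, and the step size of 0.1 day simply mirrors the 10 timesteps per day in the dataset.

```python
def seir_derivatives(S, E, I, R, beta, alpha, gamma):
    """Right-hand side of the SEIR equations."""
    N = S + E + I + R  # assumption: N is the total population of the city
    new_exposed = beta * S * I / N
    return (
        -new_exposed,             # dS/dt
        new_exposed - alpha * E,  # dE/dt
        alpha * E - gamma * I,    # dI/dt
        gamma * I,                # dR/dt
    )


def simulate(state, beta, alpha, gamma, n_steps=1000, dt=0.1):
    """Forward-Euler integration; dt=0.1 day corresponds to 10 timesteps per day."""
    trajectory = [state]
    for _ in range(n_steps):
        derivs = seir_derivatives(*state, beta, alpha, gamma)
        state = tuple(x + dt * dx for x, dx in zip(state, derivs))
        trajectory.append(state)
    return trajectory


# Illustrative values only (not taken from S4_data.csv).
trajectory = simulate((99.0, 1.0, 0.0, 0.0), beta=0.5, alpha=0.2, gamma=0.1)
```

Because the four derivatives sum to zero, the total population stays constant along the trajectory, which is a quick sanity check on any implementation.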

The initial conditions of the cities and the data for each city are provided in the S4_data.csv file. (Initial conditions are also listed below)

Screenshot 2024-11-12 at 4 56 48 PM

Problems:

  1. Estimate city-level transmission rates. (a) Use the data provided to estimate $\beta$ for each of Cities 1-5 without incorporating information from any other city. (b) For County C, City 6, note that the intervention was implemented starting at t = 30 days. Estimate the transmission rate before and after the intervention took place.
  2. Using the data provided, estimate both county and regional transmission rates. Additionally, estimate transmission rates for each city, this time by incorporating information from other cities. Note any differences from the result in Q1.
  3. For County A, City 1, impute the missing chunk of data using (1) the model that incorporated information from other cities and (2) the model that did not. Plot both estimated SEIR curves and compare.
  4. Repeat the same exercise in Q3 for County B, City 3.
  5. (Counterfactuals) Imagine that the intervention applied in County C, City 6 had also been applied to County B, City 3 and City 4, at the same time. By how much would this have reduced the total number of infections in County B over the period in which the data were collected?
  6. (Optimization – choose a geography) The governor of the region is aware of a new variant that has started to spread in a different region, and fears that it may harm the cities and counties they oversee. The best estimate is that this new variant is 1.2x as transmissible as previous variants.

    The governor has funding to apply an intervention in 2 of the 6 cities and is interested in minimizing the total number of infections in each city. The intervention is expected to reduce the transmission rate of the new variant to 80% of its previous transmissibility rate. The intervention would start at t=100 timesteps (day 10) and can run for the rest of the period (up to day 100). For this question, initial conditions are the same as Table 1.

    In which cities should the governor implement the intervention? With the intervention in place, what is the total number of infections the governor would expect over a 1000-timestep (100 day) period, and what is the total number of infections the governor would expect if they chose to do nothing? Provide 90% intervals on all projections.
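One way to frame the search in Question 6 is to enumerate all 15 pairs of cities and compare simulated totals. The sketch below is deterministic and rests on several assumptions: the per-city rates, `alpha`, `gamma`, and the initial state are hypothetical placeholders (not estimates from S4_data.csv or values from Table 1), cities are simulated independently, and "infections" are counted as cumulative new exposures. It yields point values only; the 90% intervals asked for would need posterior samples of the rates.

```python
from itertools import combinations

DT = 0.1          # one timestep = 0.1 day (10 timesteps per day)
N_STEPS = 1000    # 100 days
START_STEP = 100  # intervention begins at timestep 100 (day 10)


def total_infections(beta, intervene, alpha=0.2, gamma=0.1,
                     S=99.0, E=1.0, I=0.0, R=0.0):
    """Cumulative new exposures for one city over the horizon (forward Euler)."""
    N = S + E + I + R
    cumulative = 0.0
    for step in range(N_STEPS):
        # The intervention scales transmission to 80% once it has started.
        b = beta * 0.8 if (intervene and step >= START_STEP) else beta
        new_exposed = b * S * I / N * DT
        new_infectious = alpha * E * DT
        new_recovered = gamma * I * DT
        S -= new_exposed
        E += new_exposed - new_infectious
        I += new_infectious - new_recovered
        R += new_recovered
        cumulative += new_exposed
    return cumulative


# Hypothetical baseline rates per city (placeholders, not fitted values).
baseline_beta = {1: 0.30, 2: 0.32, 3: 0.45, 4: 0.47, 5: 0.60, 6: 0.62}
# The new variant is taken to be 1.2x as transmissible.
variant_beta = {city: 1.2 * b for city, b in baseline_beta.items()}

do_nothing = sum(total_infections(b, False) for b in variant_beta.values())

totals = {}
for pair in combinations(sorted(variant_beta), 2):
    totals[pair] = sum(
        total_infections(b, city in pair) for city, b in variant_beta.items()
    )

best_pair = min(totals, key=totals.get)
```

With independent cities the regional total is a sum of per-city totals, so the best pair is simply the two cities with the largest individual reductions; the explicit enumeration is kept because it still works if cities are later coupled.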

djinnome commented 6 days ago

OK, I am running into issues with render_model. Copilot says the following:

The error message indicates a shape mismatch inside the counties plate at the S_obs site. This is likely due to the incorrect handling of the dimensions when defining the plates and the observation model. To fix this, we need to ensure that the dimensions of the plates and the observation model match correctly.
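One possible reading of the `3 vs 2` in the error message quoted further down, offered as a sketch rather than a confirmed diagnosis: `seir_observation_model` sets `event_dim = 1` whenever the last dimension of `X["I"]` is larger than 1, so `.to_event(1)` absorbs the trailing size-3 dimension of the `(2, 3)` state and leaves a batch shape of `(2,)`, whose dim -1 no longer matches a `counties` plate of size 3. The pure-Python model below mimics the check visible in the `broadcast_messenger.py` frame of the traceback; it is a simplification and not Pyro's code, and the `cities` plate at dim -2 with size 2 is an assumption (the traceback only confirms `counties` at dim -1 with size 3).

```python
def split_shape(shape, event_dim):
    """Split a tensor shape into (batch_shape, event_shape), treating the
    last `event_dim` dimensions as event dimensions."""
    if event_dim == 0:
        return tuple(shape), ()
    return tuple(shape[:-event_dim]), tuple(shape[-event_dim:])


def check_plates(site, batch_shape, plates):
    """Simplified model of the plate check; `plates` holds (name, dim, size)
    tuples with negative dims. Returns the broadcast batch shape."""
    target = [None if size == 1 else size for size in batch_shape]
    for name, dim, size in plates:
        # Left-pad with None so that index `dim` exists.
        target = [None] * (-dim - len(target)) + target
        if target[dim] is not None and target[dim] != size:
            raise ValueError(
                f"Shape mismatch inside plate('{name}') at site {site} "
                f"dim {dim}, {size} vs {target[dim]}"
            )
        target[dim] = size
    return tuple(target)


state_shape = (2, 3)  # (n_cities_per_county, n_counties_per_region)
plates = [("counties", -1, 3), ("cities", -2, 2)]  # cities plate is assumed

# Same rule as in seir_observation_model: last dimension > 1 gives event_dim = 1.
event_dim = 1 if state_shape and state_shape[-1] > 1 else 0
batch_shape, event_shape = split_shape(state_shape, event_dim)

try:
    check_plates("S_obs", batch_shape, plates)
    mismatch = None
except ValueError as err:
    mismatch = str(err)

# With no event dimensions the same state shape passes the check.
ok_shape = check_plates("S", state_shape, plates)
```

If this reading is right, the `Delta` sites pass because they keep the full `(2, 3)` batch shape, and the observation sites would also pass if `event_dim` were 0 inside the plates. That has not been verified against the notebook.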

@SamWitty said:

I can help diagnose with the actual error message and code

At the bottom of: https://github.com/ciemss/program-milestones/blob/16-novdemo-4/Nov-24-monthly-epi-demo/Scenario-4/dynamical_multi.ipynb

def rendering_model(n_counties_per_region, n_cities_per_county) -> State[torch.Tensor]:

    alpha, beta, gamma, counties, cities = bayesian_multilevel_seir_prior(n_counties_per_region, n_cities_per_county)
    n_cities = n_counties_per_region * n_cities_per_county
    seir = SEIRDynamics(alpha, beta, gamma)
    state = dict(
        S=torch.ones( n_cities_per_county, n_counties_per_region) * 99, 
        E=torch.ones( n_cities_per_county, n_counties_per_region), 
        I=torch.zeros( n_cities_per_county, n_counties_per_region,), 
        R=torch.zeros( n_cities_per_county, n_counties_per_region)
    )
    deriv = seir(state)
    state = {k: v + deriv[k] * 0.1 for k, v in state.items()}
    deriv = seir(state)
    state = {k: v + deriv[k] * 0.1 for k, v in state.items()}
    with counties:
        with cities:
            state = {k: pyro.sample(k, dist.Delta(v)) for k, v in state.items()}
            with pyro.condition(
                data={"I_obs": torch.ones(n_cities_per_county, n_counties_per_region), 
                      "R_obs": torch.zeros(n_cities_per_county, n_counties_per_region)}
            ):
                seir_observation_model(state)
pyro.render_model(rendering_model, model_args=(n_counties_per_region, n_cities_per_county), render_deterministic=True)

results in:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/poutine/trace_messenger.py:191, in TraceHandler.__call__(self, *args, **kwargs)
    190 try:
--> 191     ret = self.fn(*args, **kwargs)
    192 except (ValueError, RuntimeError) as e:

Cell In[52], line 23
     19 with pyro.condition(
     20     data={"I_obs": torch.ones(n_cities_per_county, n_counties_per_region),
     21           "R_obs": torch.zeros(n_cities_per_county, n_counties_per_region)}
     22 ):
---> 23     seir_observation_model(state)

Cell In[10], line 22
     21 event_dim = 1 if X["I"].shape and X["I"].shape[-1] > 1 else 0
---> 22 pyro.sample(
     23     "S_obs", dist.Normal(X["S"], torch.as_tensor(1.0)).to_event(event_dim)
     24 )  # noisy number of susceptible actually observed
     25 pyro.sample(
     26     "E_obs", dist.Normal(X["E"], torch.as_tensor(1.0)).to_event(event_dim)
     27 )  # noisy number of exposed actually observed

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/primitives.py:189, in sample(name, fn, obs, obs_mask, infer, *args, **kwargs)
    188 # apply the stack and return its return value
--> 189 apply_stack(msg)
    190 assert msg["value"] is not None  # type narrowing guaranteed by apply_stack

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/poutine/runtime.py:378, in apply_stack(initial_msg)
    376 pointer = pointer + 1
--> 378 frame._process_message(msg)
    380 if msg["stop"]:

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/poutine/plate_messenger.py:25, in PlateMessenger._process_message(self, msg)
     24 super()._process_message(msg)
---> 25 BroadcastMessenger._pyro_sample(msg)

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     80 with self._recreate_cm():
---> 81     return func(*args, **kwds)

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/poutine/broadcast_messenger.py:74, in BroadcastMessenger._pyro_sample(msg)
     70 if (
     71     target_batch_shape[f.dim] is not None
     72     and target_batch_shape[f.dim] != f.size
     73 ):
---> 74     raise ValueError(
     75         "Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
     76             f.name,
     77             msg["name"],
     78             f.dim,
     79             f.size,
     80             target_batch_shape[f.dim],
     81         )
     82     )
     83 target_batch_shape[f.dim] = f.size

ValueError: Shape mismatch inside plate('counties') at site S_obs dim -1, 3 vs 2

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[52], line 24
     19             with pyro.condition(
     20                 data={"I_obs": torch.ones(n_cities_per_county, n_counties_per_region),
     21                       "R_obs": torch.zeros(n_cities_per_county, n_counties_per_region)}
     22             ):
     23                 seir_observation_model(state)
---> 24 pyro.render_model(rendering_model, model_args=(n_counties_per_region, n_cities_per_county), render_deterministic=True)

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/infer/inspect.py:630, in render_model(model, model_args, model_kwargs, filename, render_distributions, render_params, render_deterministic)
    627 # Get model relations.
    628 if not isinstance(model_args, list) and not isinstance(model_kwargs, list):
    629     relations = [
--> 630         get_model_relations(
    631             model,
    632             model_args,
    633             model_kwargs,
    634             include_deterministic=render_deterministic,
    635         )
    636     ]
    637 else:  # semisupervised
    638     if isinstance(model_args, list):

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/infer/inspect.py:302, in get_model_relations(model, model_args, model_kwargs, include_deterministic)
    300 with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
    301     with TrackProvenance(include_deterministic=include_deterministic):
--> 302         trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
    304 sample_sample = {}
    305 sample_param = {}

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/poutine/trace_messenger.py:216, in TraceHandler.get_trace(self, *args, **kwargs)
    208 def get_trace(self, *args, **kwargs) -> Trace:
    209     """
    210     :returns: data structure
    211     :rtype: pyro.poutine.Trace
   (...)
    214     Calls this poutine and returns its trace instead of the function's return value.
    215     """
--> 216     self(*args, **kwargs)
    217     return self.msngr.get_trace()

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.__call__(self, *args, **kwargs)
    196         exc = exc_type("{}\n{}".format(exc_value, shapes))
    197         exc = exc.with_traceback(traceback)
--> 198         raise exc from e
    199     self.msngr.trace.add_node(
    200         "_RETURN", name="_RETURN", type="return", value=ret
    201     )
    202 return ret

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/poutine/trace_messenger.py:191, in TraceHandler.__call__(self, *args, **kwargs)
    187 self.msngr.trace.add_node(
    188     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
    189 )
    190 try:
--> 191     ret = self.fn(*args, **kwargs)
    192 except (ValueError, RuntimeError) as e:
    193     exc_type, exc_value, traceback = sys.exc_info()

Cell In[52], line 23
     18 state = {k: pyro.sample(k, dist.Delta(v)) for k, v in state.items()}
     19 with pyro.condition(
     20     data={"I_obs": torch.ones(n_cities_per_county, n_counties_per_region),
     21           "R_obs": torch.zeros(n_cities_per_county, n_counties_per_region)}
     22 ):
---> 23     seir_observation_model(state)

Cell In[10], line 22
     18 def seir_observation_model(X: State[torch.Tensor]) -> None:
     19     # Note: Here we set the event_dim to 1 if the last dimension of X["I"] is > 1, as the seir_observation_model
     20     # can be used for both single and multi-dimensional observations.
     21     event_dim = 1 if X["I"].shape and X["I"].shape[-1] > 1 else 0
---> 22     pyro.sample(
     23         "S_obs", dist.Normal(X["S"], torch.as_tensor(1.0)).to_event(event_dim)
     24     )  # noisy number of susceptible actually observed
     25     pyro.sample(
     26         "E_obs", dist.Normal(X["E"], torch.as_tensor(1.0)).to_event(event_dim)
     27     )  # noisy number of exposed actually observed
     28     pyro.sample(
     29         "I_obs", dist.Normal(X["I"], torch.as_tensor(1.0)).to_event(event_dim)
     30     )  # noisy number of infected actually observed

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/primitives.py:189, in sample(name, fn, obs, obs_mask, infer, *args, **kwargs)
    172 msg = Message(
    173     type="sample",
    174     name=name,
   (...)
    186     continuation=None,
    187 )
    188 # apply the stack and return its return value
--> 189 apply_stack(msg)
    190 assert msg["value"] is not None  # type narrowing guaranteed by apply_stack
    191 return msg["value"]

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/poutine/runtime.py:378, in apply_stack(initial_msg)
    375 for frame in reversed(stack):
    376     pointer = pointer + 1
--> 378     frame._process_message(msg)
    380     if msg["stop"]:
    381         break

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/poutine/plate_messenger.py:25, in PlateMessenger._process_message(self, msg)
     23 def _process_message(self, msg: "Message") -> None:
     24     super()._process_message(msg)
---> 25     BroadcastMessenger._pyro_sample(msg)

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     78 @wraps(func)
     79 def inner(*args, **kwds):
     80     with self._recreate_cm():
---> 81         return func(*args, **kwds)

File ~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pyro/poutine/broadcast_messenger.py:74, in BroadcastMessenger._pyro_sample(msg)
     69     target_batch_shape = prefix_batch_shape + target_batch_shape
     70     if (
     71         target_batch_shape[f.dim] is not None
     72         and target_batch_shape[f.dim] != f.size
     73     ):
---> 74         raise ValueError(
     75             "Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
     76                 f.name,
     77                 msg["name"],
     78                 f.dim,
     79                 f.size,
     80                 target_batch_shape[f.dim],
     81             )
     82         )
     83     target_batch_shape[f.dim] = f.size
     84 # Starting from the right, if expected size is None at an index,
     85 # set it to the actual size if it exists, else 1.

ValueError: Shape mismatch inside plate('counties') at site S_obs dim -1, 3 vs 2
  Trace Shapes:      
   Param Sites:      
  Sample Sites:      
  counties dist     |
          value   3 |
alpha_mean dist   3 |
          value   3 |
 beta_mean dist   3 |
          value   3 |
gamma_mean dist   3 |
          value   3 |
    cities dist     |
          value   2 |
     alpha dist 2 3 |
          value 2 3 |
      beta dist 2 3 |
          value 2 3 |
     gamma dist 2 3 |
          value 2 3 |
         S dist 2 3 |
          value 2 3 |
         E dist 2 3 |
          value 2 3 |
         I dist 2 3 |
          value 2 3 |
         R dist 2 3 |
          value 2 3 |

djinnome commented 3 days ago

@SamWitty said:

I'm shocked that cell 16 ran, given that the samples plate is using dim=-2. Try adding explicit dim arguments to each of the plates. E.g.

counties = pyro.plate("counties", size=n_counties_per_region, dim=-1)
cities = pyro.plate("cities", size=n_cities_per_county, dim=-2)
with pyro.plate("samples", num_samples * 5, dim=-3):
    ...  # sampling code goes here

It's possible the issue is in the shape of tensors used for I_obs and R_obs in the pyro.condition handler in the cell that's erroring. If the above suggestion doesn't fix it, try removing the pyro.condition statement to see what the tensor shapes are from the seir_observation_model and then just match those shapes. I'm betting you're missing a tensor dimension somewhere.
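To make the "match those shapes" advice concrete, here is a minimal pure-Python sketch of the bookkeeping. It does not call Pyro; `expected_batch_shape` is a hypothetical helper, and `num_samples = 10` is an illustrative value that does not come from the notebook. It only shows which batch shape a set of explicitly placed plates implies, and it flags two plates that claim the same dim.

```python
from typing import Dict, Tuple


def expected_batch_shape(plates: Dict[str, Tuple[int, int]]) -> Tuple[int, ...]:
    """Return the batch shape implied by plates given as name -> (size, dim).

    Each dim is negative and counted from the right, as in pyro.plate(dim=...).
    """
    ndim = max(-dim for _, dim in plates.values())
    shape = [1] * ndim
    used = set()
    for name, (size, dim) in plates.items():
        if dim in used:
            raise ValueError(f"plate '{name}' reuses dim {dim}")
        used.add(dim)
        shape[dim] = size
    return tuple(shape)


# Sizes taken from the trace above: 3 counties, 2 cities per county.
n_counties_per_region = 3
n_cities_per_county = 2
num_samples = 10  # assumption, for illustration only

plates = {
    "counties": (n_counties_per_region, -1),
    "cities": (n_cities_per_county, -2),
    "samples": (num_samples * 5, -3),
}
print(expected_batch_shape(plates))  # (50, 2, 3)
```

Under this layout, tensors passed as conditioning data for a model that uses only the cities and counties plates would need shape `(2, 3)`, and placing the samples plate at `dim=-2` alongside cities would be reported as a collision by this sketch.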

Based on this advice, I got much further, but the result is still not quite right:

import os

import pyro
import pyro.distributions as dist
import torch

# Add the Homebrew bin directory to PATH so pyro.render_model can find Graphviz.
os.environ["PATH"] += os.pathsep + "/opt/homebrew/bin"

# State, SEIRDynamics, bayesian_multilevel_seir_prior, and seir_observation_model
# are defined in earlier notebook cells.
def rendering_model(n_counties_per_region, n_cities_per_county) -> State[torch.Tensor]:
    alpha, beta, gamma, counties, cities = bayesian_multilevel_seir_prior(
        n_counties_per_region, n_cities_per_county
    )
    seir = SEIRDynamics(alpha, beta, gamma)
    shape = (n_cities_per_county, n_counties_per_region)
    state = dict(
        S=torch.ones(shape) * 99,
        E=torch.ones(shape),
        I=torch.zeros(shape),
        R=torch.zeros(shape),
    )
    # Three explicit Euler steps with step size 0.1.
    for _ in range(3):
        deriv = seir(state)
        state = {k: v + deriv[k] * 0.1 for k, v in state.items()}

    with cities:
        # with counties:  # nesting both plates still fails here
        state = {k: pyro.sample(k, dist.Delta(v)) for k, v in state.items()}
        with pyro.condition(
            data={"I_obs": torch.ones(shape), "R_obs": torch.zeros(shape)}
        ):
            seir_observation_model(state)

    return state


pyro.render_model(
    rendering_model,
    model_args=(n_counties_per_region, n_cities_per_county),
    render_deterministic=True,
)

multi_seir

However, I still can't nest the cities and counties plates, and the repeated state and deriv updates don't seem to create extra edges between the alpha, beta, and gamma parameters and the S, E, I, and R state variables.
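For reference, the repeated state and deriv updates are forward-Euler steps with a step size of 0.1. The sketch below reproduces them with plain Python floats for a single city, assuming SEIRDynamics implements the SEIR equations from the scenario description. The helper names and the parameter values (alpha = 0.2, beta = 0.5, gamma = 0.1) are illustrative assumptions, not values from the notebook.

```python
from typing import Dict

SeirState = Dict[str, float]


def seir_derivative(state: SeirState, alpha: float, beta: float, gamma: float) -> SeirState:
    """Right-hand side of the SEIR equations, with N = S + E + I + R."""
    n = sum(state.values())
    infection = beta * state["S"] * state["I"] / n
    return {
        "S": -infection,
        "E": infection - alpha * state["E"],
        "I": alpha * state["E"] - gamma * state["I"],
        "R": gamma * state["I"],
    }


def euler_step(state: SeirState, dt: float, **params: float) -> SeirState:
    """Advance every compartment by one forward-Euler step of size dt."""
    deriv = seir_derivative(state, **params)
    return {k: v + deriv[k] * dt for k, v in state.items()}


# Same initial values as rendering_model: S=99, E=1, I=0, R=0.
state = {"S": 99.0, "E": 1.0, "I": 0.0, "R": 0.0}
params = {"alpha": 0.2, "beta": 0.5, "gamma": 0.1}  # assumed values
for _ in range(3):
    state = euler_step(state, dt=0.1, **params)
print(state)
```

The four derivatives sum to zero, so the total population stays at 100 after every step, which is a quick sanity check on any hand-rolled integrator like this one.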

djinnome commented 21 hours ago

I think I figured out the problem. Changing the event_dim line in seir_observation_model as follows makes event_dim evaluate to 0, which solves it:

    event_dim = 1 if X["I"].shape and X["I"].shape[-1] > batch_dim else 0

multi_seir
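One plausible reading of why event_dim matters here, sketched in pure Python without calling Pyro: `.to_event(event_dim)` reinterprets the rightmost event_dim batch dimensions as event dimensions, so with a (2, 3) state and event_dim=1 only a batch shape of (2,) is left for the plates to check. The helpers `split_batch_event` and `check_plate` are hypothetical simplifications of that bookkeeping, not Pyro functions.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def split_batch_event(shape: Shape, event_dim: int) -> Tuple[Shape, Shape]:
    """Mimic .to_event(event_dim): the rightmost event_dim dims become event dims."""
    if event_dim == 0:
        return shape, ()
    return shape[:-event_dim], shape[-event_dim:]


def check_plate(batch_shape: Shape, name: str, dim: int, size: int) -> None:
    """Raise if the batch dim occupied by a plate has an incompatible size."""
    if -dim > len(batch_shape):
        return  # this dim is absent from the batch shape, so it can broadcast
    actual = batch_shape[dim]
    if actual not in (1, size):
        raise ValueError(f"plate('{name}') dim {dim}: {size} vs {actual}")


state_shape = (2, 3)  # (cities per county, counties), as in the trace shapes
plates = [("cities", -2, 2), ("counties", -1, 3)]

for event_dim in (1, 0):
    batch_shape, event_shape = split_batch_event(state_shape, event_dim)
    try:
        for name, dim, size in plates:
            check_plate(batch_shape, name, dim, size)
        print(f"event_dim={event_dim}: OK, batch_shape={batch_shape}")
    except ValueError as err:
        print(f"event_dim={event_dim}: {err}")
```

With event_dim=1 the sketch reports a 3 vs 2 conflict at dim -1 for the counties plate, the same numbers as the ValueError earlier in the thread; with event_dim=0 both plates line up with the (2, 3) batch shape.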