BasisResearch / chirho

An experimental language for causal reasoning
https://basisresearch.github.io/chirho/getting_started.html
Apache License 2.0

allow for batched inference with observational SIR model, add test for batched inference #566

Open rfl-urbaniak opened 2 months ago

rfl-urbaniak commented 2 months ago

The SIR observational model wasn't general enough to guide users in building models for inference with batched data. It gestured in this direction by allowing:

# Note: Here we set the event_dim to 1 if the last dimension of X["I"] is > 1, as the sir_observation_model
# can be used for both single and multi-dimensional observations.
event_dim = 1 if X["I"].shape and X["I"].shape[-1] > 1 else 0

But if a model of this type is passed to SVI, the log prob shapes are wrong. To make them correct, we also need to introduce a plate of the appropriate size.
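To make the shape issue concrete, here is a minimal, library-free sketch in plain Python (not Pyro or ChiRho code). It shows what treating the time axis as an event dimension amounts to: the per-entry Poisson log-probabilities of each series are summed over time, leaving one log-probability per observed series. That remaining axis of length n is the batch dimension that a plate has to declare. The helper names and the data below are hypothetical and exist only for this illustration.

```python
import math


def poisson_log_prob(rate: float, count: int) -> float:
    """Log-probability of `count` under a Poisson distribution with mean `rate`."""
    return count * math.log(rate) - rate - math.lgamma(count + 1)


def elementwise_log_probs(rates, counts):
    """One log-probability per (series, time) entry, i.e. shape (n, T)."""
    return [
        [poisson_log_prob(r, c) for r, c in zip(rate_row, count_row)]
        for rate_row, count_row in zip(rates, counts)
    ]


def per_series_log_probs(rates, counts):
    """Sum over the time axis (the event axis), i.e. shape (n,)."""
    return [sum(row) for row in elementwise_log_probs(rates, counts)]


# Hypothetical data: n = 2 observed series with T = 3 time points each.
rates = [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]]
counts = [[1, 2, 2], [3, 4, 5]]

elementwise = elementwise_log_probs(rates, counts)
per_series = per_series_log_probs(rates, counts)

print(len(elementwise), len(elementwise[0]))  # n rows, T columns
print(len(per_series))  # one entry per observed series
```

In the batched case the quantity of interest has one entry per series, so the model needs to mark that axis as independent rather than leave it undeclared.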

To illustrate and to ensure proper functionality, I added dynamical/test_batched_inference.py. For illustration, commenting out the lines introducing the plate in that test model (and un-indenting the pyro.sample statements) will lead to the type of log prob shape error in question.

Accordingly, I revised the dynamical systems notebook. The SIR observational model is now:

def sir_observation_model(X: State[torch.Tensor]) -> None:
    # We don't observe the number of susceptible individuals directly.

    # Note: Here we set the event_dim to 1 if the last dimension of X["I"] is > 1, as the sir_observation_model
    # can be used for both single and multi-dimensional observations.
    event_dim = 1 if X["I"].shape and X["I"].shape[-1] > 1 else 0

    # Note: such plating while not necessary for this example,
    # would be needed to ensure proper log prob shapes
    # in inference with multiple observed time series,
    # so we include it for illustrative purposes.
    n = X["I"].shape[-2] if len(X["I"].shape) >= 2 else 1
    with pyro.plate("data", n, dim=-2):
        pyro.sample("I_obs", dist.Poisson(X["I"]).to_event(event_dim))  # noisy number of infected actually observed
        pyro.sample("R_obs", dist.Poisson(X["R"]).to_event(event_dim))  # noisy number of recovered actually observed
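As a quick check of the shape logic, the two expressions that compute event_dim and the plate size n can be lifted into plain functions over shape tuples. This is only a sketch: infer_event_dim and infer_plate_size are hypothetical names, not part of the notebook or of ChiRho, and plain tuples stand in for tensor shapes.

```python
from typing import Tuple


def infer_event_dim(shape: Tuple[int, ...]) -> int:
    """Mirror of the event_dim expression: 1 if the last dimension is > 1, else 0."""
    return 1 if shape and shape[-1] > 1 else 0


def infer_plate_size(shape: Tuple[int, ...]) -> int:
    """Mirror of the plate-size expression: second-to-last dimension, or 1 if absent."""
    return shape[-2] if len(shape) >= 2 else 1


# Hypothetical shapes: a scalar, a single time point, one series, three series.
for shape in [(), (1,), (10,), (3, 10)]:
    print(shape, infer_event_dim(shape), infer_plate_size(shape))
```

With a single series of shape (10,), the plate has size 1, which matches the comment that the plate is not strictly needed in that case; with three series of shape (3, 10), the plate has size 3.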

The other changes are small. They include adding plt.show() and partially parallelizing prediction:

sir_data = dict(**{k:tr.trace.nodes[k]["value"] for k in ["I_obs", "R_obs"]})

to

sir_data = dict(**{k:tr.trace.nodes[k]["value"].view(-1) for k in ["I_obs", "R_obs"]})

and

plt.xlim(start_time, end_time)
plt.xlabel("Time (Months)")
plt.ylabel("# of Individuals (Millions)")
plt.legend(loc="upper right")

to:

plt.xlim(start_time, end_time)
plt.xlabel("Time (Months)")
plt.ylabel("# of Individuals (Millions)")
plt.legend(loc="upper right")
plt.show()

(I added plt.show() at a few locations to avoid redundant printing of object names before plotting.)

# Generate samples from the posterior predictive distribution
sir_predictive = Predictive(simulated_bayesian_sir, guide=sir_guide, num_samples=num_samples)
sir_posterior_samples = sir_predictive(init_state, start_time, logging_times)

to

# Generate samples from the posterior predictive distribution
sir_predictive = Predictive(simulated_bayesian_sir, guide=sir_guide, num_samples=num_samples, parallel=True)
sir_posterior_samples = sir_predictive(init_state, start_time, logging_times)

and

intervened_sir_predictive = Predictive(intervened_sir, guide=sir_guide, num_samples=num_samples)
intervened_sir_posterior_samples = intervened_sir_predictive(lockdown_start, lockdown_end, lockdown_strength, init_state_lockdown, start_time, logging_times)

to

intervened_sir_predictive = Predictive(intervened_sir, guide=sir_guide, num_samples=num_samples, parallel=True)
intervened_sir_posterior_samples = intervened_sir_predictive(lockdown_start, lockdown_end, lockdown_strength, init_state_lockdown, start_time, logging_times)

There seems to be a small shape-related bug in the notebook that causes a runtime error when parallelization is enabled at a few locations. It remains unfixed. The locations are:

uncertain_intervened_sir_predictive = Predictive(uncertain_intervened_sir, guide=sir_guide, num_samples=num_samples)
uncertain_intervened_sir_posterior_samples = uncertain_intervened_sir_predictive(lockdown_strength, init_state_lockdown, start_time, logging_times)

and


dynamic_intervened_sir_predictive = Predictive(dynamic_intervened_sir, guide=sir_guide, num_samples=num_samples)
dynamic_intervened_sir_posterior_samples = dynamic_intervened_sir_predictive(lockdown_trigger, lockdown_lift_trigger, lockdown_strength, init_state_lockdown, start_time, logging_times)

and

uncertain_dynamic_intervened_sir_predictive = Predictive(uncertain_dynamic_intervened_sir, guide=sir_guide, num_samples=num_samples)
uncertain_dynamic_intervened_sir_posterior_samples = (uncertain_dynamic_intervened_sir_predictive(lockdown_strength, init_state_lockdown, start_time, logging_times))

The whole notebook has been re-run.