The SIR observation model was not general enough to guide users in building models for inference with batched data. It gestured in this direction by allowing:
```python
# Note: Here we set the event_dim to 1 if the last dimension of X["I"] is > 1, as the sir_observation_model
# can be used for both single and multi-dimensional observations.
event_dim = 1 if X["I"].shape and X["I"].shape[-1] > 1 else 0
```
However, if a model of this type is passed to SVI, the log prob shapes are wrong. To make them correct, we also need to introduce a plate of the appropriate size.
To illustrate this and to ensure proper functionality, I added `dynamical/test_batched_inference.py`. For example, commenting out the lines that introduce the plate in that test model (and un-indenting the `pyro.sample` statements) leads to the log prob shape error in question.
Accordingly, I revised the dynamical systems notebook. The SIR observation model is now:
```python
import torch
import pyro
import pyro.distributions as dist

# State is the state container type imported earlier in the dynamical systems notebook.


def sir_observation_model(X: State[torch.Tensor]) -> None:
    # We don't observe the number of susceptible individuals directly.
    # Note: Here we set the event_dim to 1 if the last dimension of X["I"] is > 1,
    # as the sir_observation_model can be used for both single and multi-dimensional observations.
    event_dim = 1 if X["I"].shape and X["I"].shape[-1] > 1 else 0
    # Note: such plating, while not necessary for this example,
    # would be needed to ensure proper log prob shapes
    # in inference with multiple observed time series,
    # so we include it for illustrative purposes.
    n = X["I"].shape[-2] if len(X["I"].shape) >= 2 else 1
    with pyro.plate("data", n, dim=-2):
        pyro.sample("I_obs", dist.Poisson(X["I"]).to_event(event_dim))  # noisy number of infected actually observed
        pyro.sample("R_obs", dist.Poisson(X["R"]).to_event(event_dim))  # noisy number of recovered actually observed
```
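Other changes are minor, including adding `plt.show()` at a few locations (to avoid redundant printing of object names before plotting) and partially parallelizing prediction. Among them, I changed: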
```python
sir_data = dict(**{k:tr.trace.nodes[k]["value"] for k in ["I_obs", "R_obs"]})
```
to
```python
sir_data = dict(**{k:tr.trace.nodes[k]["value"].view(-1) for k in ["I_obs", "R_obs"]})
```
```python
plt.xlim(start_time, end_time)
plt.xlabel("Time (Months)")
plt.ylabel("# of Individuals (Millions)")
plt.legend(loc="upper right")
```
There seems to be a small shape-related bug in the notebook that leads to a runtime error with parallelization at a few locations. It remains unfixed. The locations are:
The whole notebook has been re-run.