rfl-urbaniak closed this issue 4 months ago.
Update: `SIRDynamicsLockdown` also needs a workaround: `dX["l"]` has to be forced to have the right (batched) shape, instead of just being set to `torch.tensor([0.0])`. One way to do this:
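```python
import torch

# SIRDynamics and State are assumed to be defined/imported as in the
# accompanying SIR example; they are not shown in this thread.
class SIRDynamicsLockdown(SIRDynamics):
    def __init__(self, beta0, gamma):
        super().__init__(beta0, gamma)
        self.beta0 = beta0

    def forward(self, X: State[torch.Tensor]):
        # Lockdown strength l scales down the baseline transmission rate.
        self.beta = (1 - X["l"]) * self.beta0
        dX = super().forward(X)
        # l is held constant over time, so its derivative is zero.
        dX["l"] = X["I"] * torch.tensor([0.0])  # notice multiplication here to force the right shape
        return dX
```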
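To see why the multiplication matters, the sketch below compares the shapes in plain PyTorch. The batched state `X` and its `(3, 4)` shape are hypothetical, chosen only to stand in for whatever extra dimensions `Predictive` or plates add. `torch.zeros_like` is shown as an alternative spelling; it is not what the snippet above uses.

```python
import torch

# Hypothetical batched state: the (3, 4) batch shape is arbitrary.
X = {"I": torch.rand(3, 4), "l": torch.zeros(3, 4)}

# Problematic version: a fixed-shape constant ignores the batch shape.
dl_constant = torch.tensor([0.0])

# Workaround: broadcasting (3, 4) against (1,) yields a (3, 4) tensor of zeros.
dl_broadcast = X["I"] * torch.tensor([0.0])

# Alternative spelling: also matches X["I"]'s shape, dtype and device.
dl_zeros_like = torch.zeros_like(X["I"])

print(dl_constant.shape)   # torch.Size([1])
print(dl_broadcast.shape)  # torch.Size([3, 4])
print(torch.equal(dl_broadcast, dl_zeros_like))  # True
```

One difference is worth knowing: `X["I"] * 0.0` propagates `nan`/`inf` entries of `X["I"]` as `nan`, while `torch.zeros_like` always returns exact zeros. For finite states the two are interchangeable.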
Initial states do not broadcast as expected, which leads to shape errors when using `Predictive` with `parallel=True`, or when using both `pyro.plate` and `MultiWorldCounterfactual`. A minimal failing example:

A workaround, discovered together with @SamWitty, is to introduce the initial state using `pyro.sample` statements, which ensures that the initial state is broadcast properly, as follows: