BasisResearch / chirho

An experimental language for causal reasoning
https://basisresearch.github.io/chirho/getting_started.html
Apache License 2.0
164 stars 10 forks source link

Initial states don't broadcast properly in `chirho.dynamical` #524

Closed rfl-urbaniak closed 4 months ago

rfl-urbaniak commented 4 months ago

Initial states do not broadcast as expected, which leads to shape errors when using Predictive with parallel=True, or when using both pyro.plate and MultiWorldCounterfactual. A minimal failing example:

class UnifiedFixtureDynamics(pyro.nn.PyroModule):
    """SIR compartment dynamics parameterized by transmission (beta) and recovery (gamma) rates."""

    def __init__(self, beta=None, gamma=None):
        super().__init__()
        # Rates may be plain floats or tensors sampled upstream (e.g. via pyro.sample).
        self.beta = beta
        self.gamma = gamma

    def forward(self, X: State[torch.Tensor]):
        """Return the instantaneous time-derivatives of the S/I/R compartments."""
        infection_flux = self.beta * X["S"] * X["I"]
        recovery_flux = self.gamma * X["I"]
        dX: State[torch.Tensor] = {}
        dX["S"] = -infection_flux
        dX["I"] = infection_flux - recovery_flux  # noqa
        dX["R"] = recovery_flux
        return dX

# Initial compartment populations and the integration window for the simulation.
init_state = {"S": torch.tensor(10.0), "I": torch.tensor(3.0), "R": torch.tensor(3.0)}

start_time = torch.tensor(0.0)
end_time = torch.tensor(10.0)
# Five evenly spaced observation times strictly inside the integration window.
logging_times = torch.linspace(start_time + 0.01, end_time - 2, 5)

def bayesian_sir():
    """Sample SIR rates and simulate the dynamics inside a counterfactual handler.

    Minimal reproduction of the init-state broadcasting failure: `init_state`
    holds plain scalar tensors, so it does not broadcast over the batch
    dimensions introduced by Predictive(parallel=True) or by
    MultiWorldCounterfactual.
    """
    with TorchDiffEq():
        with MultiWorldCounterfactual() as cf:
            beta = pyro.sample("beta", dist.Beta(18, 600))
            gamma = pyro.sample("gamma", dist.Beta(1600, 1600))
            # Fix: pass the sampled gamma — the original snippet passed beta
            # twice, leaving the sampled gamma unused.
            fixture = UnifiedFixtureDynamics(beta=beta, gamma=gamma)
            cf_state = simulate(
                fixture,
                init_state,
                start_time,
                end_time,
            )

# fails, but will not fail if num_samples = 1
# parallel=True draws all num_samples in one vectorized (batched) trace, which
# requires every value in the model — including init_state — to broadcast over
# the sample dimension; this is where the reported shape error surfaces.
predictive = Predictive(bayesian_sir, num_samples=2, parallel=True)
sir_samples = predictive()

A workaround, discovered together with @SamWitty, is to introduce the init state using pyro.sample statements, which ensures proper init-state broadcasting, as follows:

def bayesian_sir():
    """Workaround version: route init-state values through pyro.sample.

    Wrapping the initial values in Delta distributions lets Pyro's broadcasting
    machinery expand them to the batch shape introduced by plates and
    counterfactual worlds, avoiding the shape errors seen with plain tensors.
    """
    with LogTrajectory(
        times=logging_times, is_traced=True
    ) as dt:
        with TorchDiffEq():
            # NOTE(review): intervene_time / intervene_state are assumed to be
            # defined elsewhere in the surrounding script — confirm.
            with StaticIntervention(time=intervene_time, intervention=intervene_state):
                with StaticIntervention(
                    time=intervene_time + 0.5, intervention=intervene_state
                ):
                    with MultiWorldCounterfactual() as cf:
                        beta = pyro.sample("beta", dist.Beta(18, 600))
                        gamma = pyro.sample("gamma", dist.Beta(1600, 1600))

                        # here we need to pass the init state values to Delta distributions
                        s0 = pyro.sample("S0", dist.Delta(torch.tensor(1.)))
                        i0 = pyro.sample("I0", dist.Delta(torch.tensor(5.)))
                        r0 = pyro.sample("R0", dist.Delta(torch.tensor(3.)))
                        init_state = dict(S=s0, I=i0, R=r0)

                        # Fix: pass the sampled gamma — the original snippet
                        # passed beta twice, leaving gamma unused.
                        fixture = UnifiedFixtureDynamics(beta=beta, gamma=gamma)
                        cf_state = simulate(
                            fixture,
                            init_state,
                            start_time,
                            end_time,
                        )
rfl-urbaniak commented 4 months ago

Update: SIRDynamicsLockdown also needs a workaround: dX["l"] needs to be forced to have the right shape, instead of just being set to torch.tensor([0.0]). One way to do this:

class SIRDynamicsLockdown(SIRDynamics):
    """SIR dynamics with a lockdown state `l` that scales down transmission.

    The effective transmission rate is (1 - l) * beta0, so l = 0 means no
    lockdown and l = 1 fully suppresses transmission.
    """

    def __init__(self, beta0, gamma):
        super().__init__(beta0, gamma)
        # Keep the baseline rate; self.beta is recomputed on every forward call.
        self.beta0 = beta0

    def forward(self, X: State[torch.Tensor]):
        # Effective transmission rate under the current lockdown level.
        self.beta = (1 - X["l"]) * self.beta0
        dX = super().forward(X)
        # l is constant in time (derivative zero). zeros_like pins dX["l"] to
        # exactly the (possibly batched/broadcast) shape of the state, whereas
        # multiplying by torch.tensor([0.0]) forces a stray (1,) dimension on
        # scalar states.
        dX["l"] = torch.zeros_like(X["I"])
        return dX