Closed SamWitty closed 9 months ago
@eb8680, see the state of the branch and the failing tests at 5960b4c for confirmation that this PR addresses the problem.
https://github.com/BasisResearch/chirho/actions/runs/7998151234/job/21843870134
Thanks @SamWitty! Tested on my side, this seems to solve the `init_state` broadcasting issues with the dynamics I've been dealing with, but does not resolve broadcasting issues in general. The derivative broadcasting issue still persists and the `torch.tensor([0.0]) * X["I"]` workaround is still needed for parallel sampling to work with `SIRDynamicsLockdown`.
@rfl-urbaniak, thanks for pointing that out. I believe that error is a result of the model not being written in a way that is broadcastable, as the output is a tensor with shape 1 regardless of the input shape of `X["l"]`. Instead, we should always write these models as follows.
dX["l"] = torch.zeros_like(X["l"])
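To make the shape difference concrete, here is a minimal sketch in plain PyTorch. The two functions `deriv_fixed_shape` and `deriv_broadcastable` are hypothetical stand-ins for the relevant line of a dynamics model, not code from chirho or from `SIRDynamicsLockdown`, and the batch shape `(5, 1)` is an arbitrary assumption.

```python
import torch


def deriv_fixed_shape(X):
    # Hypothetical non-broadcastable version: the result always has
    # shape (1,), regardless of how X["l"] is batched.
    return {"l": torch.tensor([0.0])}


def deriv_broadcastable(X):
    # Broadcastable version: the result takes the shape (and dtype/device)
    # of the corresponding state variable.
    return {"l": torch.zeros_like(X["l"])}


# A state variable carrying an assumed batch of 5 parallel samples.
X = {"l": torch.ones(5, 1)}

fixed_shape = deriv_fixed_shape(X)["l"].shape
broadcastable_shape = deriv_broadcastable(X)["l"].shape

print(fixed_shape, broadcastable_shape)
```

With the `zeros_like` form, the derivative follows whatever batch shape the state has, so no multiplication by a state variable is needed to force the shape.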
This small PR addresses #524 by calling the partially evaluated `deriv` method a single time to determine any tensor broadcasting that occurred inside of its body. We then take the resulting tensor shapes and use them to define the broadcasted tensor shapes of the input state.
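The idea described above can be sketched roughly as follows. This is an illustration under stated assumptions, not the code in this PR: the helper `broadcast_state_to_deriv`, the toy `sir_deriv` function, and the parameter `beta` are all hypothetical names, and the state is assumed to be a plain dict of tensors.

```python
import torch


def broadcast_state_to_deriv(deriv, init_state):
    """Hypothetical helper: expand each state tensor to the shape
    produced by a single evaluation of the derivative function."""
    # Call the derivative once to see what broadcasting happens in its body.
    dX = deriv(init_state)
    expanded = {}
    for name, value in init_state.items():
        # Combine the state shape with the derivative shape for this variable.
        shape = torch.broadcast_shapes(value.shape, dX[name].shape)
        expanded[name] = value.expand(shape)
    return expanded


# Assumed batch of 3 parameter samples, e.g. from parallel sampling.
beta = torch.tensor([0.1, 0.2, 0.3])


def sir_deriv(X):
    # Toy dynamics: the batched parameter broadcasts against scalar states.
    flow = beta * X["S"] * X["I"]
    return {"S": -flow, "I": flow}


init_state = {"S": torch.tensor(99.0), "I": torch.tensor(1.0)}
state = broadcast_state_to_deriv(sir_deriv, init_state)

print({name: tuple(value.shape) for name, value in state.items()})
```

Here the scalar initial state picks up the batch dimension introduced by `beta`, so the state and its derivative have matching shapes before they reach the solver.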