BasisResearch / chirho

An experimental language for causal reasoning
https://basisresearch.github.io/chirho/getting_started.html
Apache License 2.0

Fix TorchDiffEq tensor broadcasting #525

Closed SamWitty closed 4 months ago

SamWitty commented 4 months ago

This small PR addresses #524 by calling the partially evaluated deriv method once to detect any tensor broadcasting that occurs inside its body. We then use the resulting tensor shapes to define the broadcasted tensor shapes of the input state.
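For readers following along, here is a minimal sketch of the idea described above, not the actual chirho implementation. The `deriv` function, the `broadcast_initial_state` helper, and the state names are hypothetical stand-ins chosen for illustration; only `torch.broadcast_shapes` and `torch.broadcast_to` are real PyTorch calls.

```python
import torch


def deriv(state):
    # Hypothetical dynamics: beta carries a batch dimension of size 3,
    # so the derivatives come out with shape (3, 1) even for scalar inputs.
    beta = torch.tensor([[0.1], [0.2], [0.3]])  # shape (3, 1)
    flow = beta * state["S"] * state["I"]
    return {"S": -flow, "I": flow - 0.05 * state["I"]}


def broadcast_initial_state(deriv_fn, state):
    # Hypothetical helper: call the derivative once, only to read off the
    # shapes produced by broadcasting inside its body.
    dstate = deriv_fn(state)
    return {
        name: torch.broadcast_to(
            value, torch.broadcast_shapes(value.shape, dstate[name].shape)
        )
        for name, value in state.items()
    }


# Scalar initial state (shape ()), as a user might naturally write it.
init_state = {"S": torch.tensor(0.99), "I": torch.tensor(0.01)}
broadcasted = broadcast_initial_state(deriv, init_state)
```

After this step, every entry of `broadcasted` has the same shape as the corresponding derivative, which is the precondition a solver needs in order to treat state and derivative consistently.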

SamWitty commented 4 months ago

@eb8680 , see the state of the branch and the failing tests at 5960b4c for confirmation that this PR addresses the problem.

https://github.com/BasisResearch/chirho/actions/runs/7998151234/job/21843870134

rfl-urbaniak commented 4 months ago

Thanks @SamWitty ! Tested on my side, this seems to solve the init_state broadcasting issues with the dynamics I've been dealing with, but does not resolve broadcasting issues in general. The derivative broadcasting issue still persists and the torch.tensor([0.0]) * X["I"] workaround is still needed for parallel sampling to work with SIRDynamicsLockdown.

SamWitty commented 4 months ago

@rfl-urbaniak , thanks for pointing that out. I believe that error is a result of the model not being written in a way that is broadcastable, as the output is a tensor with shape 1 regardless of the input shape of X["l"]. Instead, we should always write these models as follows.

dX["l"] = torch.zeros_like(X["l"])
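To illustrate the difference, here is a small sketch comparing a fixed-shape derivative with a broadcastable one. The two function names and the batch shape `(5, 1)` are assumptions made up for this example; they are not taken from SIRDynamicsLockdown itself.

```python
import torch


def deriv_fixed_shape(X):
    # Not broadcastable: the output has shape (1,) whatever shape X["l"] has.
    return {"l": torch.tensor([0.0])}


def deriv_broadcastable(X):
    # Broadcastable: the output always matches the shape of X["l"].
    return {"l": torch.zeros_like(X["l"])}


# Hypothetical batch of 5 parallel samples of the state variable "l".
X = {"l": torch.rand(5, 1)}

fixed = deriv_fixed_shape(X)["l"]
matching = deriv_broadcastable(X)["l"]
```

Here `fixed` stays at shape `(1,)` while `matching` takes the batched shape of `X["l"]`, so only the second version gives a derivative whose shape agrees with the state under parallel sampling.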