ciemss / pyciemss

Causal and probabilistic reasoning with continuous time dynamical systems

Fix errors with tensor shape broadcasting #488

Closed: SamWitty closed this 4 months ago

SamWitty commented 4 months ago

This small PR adds a workaround for several tensor shape errors (see #487, #480, and #468) that result from TorchDiffEq broadcasting along only a single batch dimension. In ChiRho we address this by collapsing all batch dimensions into a single batch dimension before running the solver, but this interacts badly with handlers that modify tensor shapes when they encounter a sample statement. The workaround replaces plain tensors, which pyro.plate does not modify, with dist.Delta distributions, which plates do broadcast.

SamWitty commented 4 months ago

Closing, as these explorations didn't bear fruit...