This small PR adds a workaround for some tensor shape errors (see #487, #480, and #468) that result from TorchDiffEq only broadcasting along a single batch dimension. We address this limitation in ChiRho by collapsing all batch dimensions into a single batch dimension before running the solver, but this interacts unexpectedly with handlers that modify tensor shapes when they encounter a `sample` statement. This workaround replaces tensors (which `pyro.plate`s do not modify) with `dist.Delta` distributions (which plates do modify).