ciemss / pyciemss

Causal and probabilistic reasoning with continuous time dynamical systems

remove parallel sampling and add PDE to test #491

Closed: SamWitty closed this 4 months ago

SamWitty commented 4 months ago

This small PR adds a workaround for some tensor shape errors (see https://github.com/ciemss/pyciemss/issues/487, https://github.com/ciemss/pyciemss/issues/480, and https://github.com/ciemss/pyciemss/issues/468) that result from TorchDiffEq broadcasting along only a single batch dimension. ChiRho works around this limitation by collapsing all batch dimensions into a single batch dimension before running the solver, but that approach interacts badly with handlers that modify tensor shapes when they encounter a sample statement.

As a short-term fix, this PR removes the `parallel=True` arguments from the predictive handlers. Those handlers parallelize sampling by relying fully on tensor broadcasting, so without the flag sampling is no longer vectorized. Once we move to a fully broadcastable solver, this change can be reverted.
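To make the batch-collapsing idea concrete, here is a minimal PyTorch sketch. It is not ChiRho's implementation: `euler_solve` is a toy fixed-step integrator standing in for TorchDiffEq that, by assumption, accepts exactly one batch dimension, and `solve_with_collapsed_batch` is a hypothetical helper that flattens every leading batch dimension into one before solving and restores the original batch shape afterwards.

```python
import torch


def euler_solve(dynamics, y0, t):
    """Toy fixed-step Euler integrator (stand-in for TorchDiffEq).

    By assumption it supports a single batch dimension only:
    y0 must have shape (batch, state_dim). Returns (len(t), batch, state_dim).
    """
    if y0.dim() != 2:
        raise ValueError(f"expected shape (batch, state_dim), got {tuple(y0.shape)}")
    ys = [y0]
    for t0, t1 in zip(t[:-1], t[1:]):
        # One explicit Euler step: y_{k+1} = y_k + dt * f(t_k, y_k)
        ys.append(ys[-1] + (t1 - t0) * dynamics(t0, ys[-1]))
    return torch.stack(ys)


def solve_with_collapsed_batch(dynamics, y0, t):
    """Hypothetical helper: collapse all batch dims, solve, then restore them.

    y0 has shape (*batch_shape, state_dim); the result has shape
    (len(t), *batch_shape, state_dim).
    """
    batch_shape, state_dim = y0.shape[:-1], y0.shape[-1]
    flat_y0 = y0.reshape(-1, state_dim)          # (prod(batch_shape), state_dim)
    flat_ys = euler_solve(dynamics, flat_y0, t)  # (len(t), prod(batch_shape), state_dim)
    return flat_ys.reshape(len(t), *batch_shape, state_dim)


# Usage: two batch dimensions (e.g. 3 samples x 4 settings), 2-dimensional state,
# with simple state-only exponential-decay dynamics.
y0 = torch.ones(3, 4, 2)
t = torch.linspace(0.0, 1.0, 11)
ys = solve_with_collapsed_batch(lambda t, y: -0.5 * y, y0, t)
print(ys.shape)  # torch.Size([11, 3, 4, 2])
```

The sketch only flattens the initial state, and its dynamics are deliberately parameter-free. In a real model, parameters drawn at sample statements carry their own batch dimensions, and keeping every such tensor consistent with the collapsed batch dimension is plausibly where shape-modifying handlers can interfere.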

See https://github.com/BasisResearch/chirho/issues/524 for some discussion.