This small PR adds a workaround for some tensor shape errors (see https://github.com/ciemss/pyciemss/issues/487, https://github.com/ciemss/pyciemss/issues/480, and https://github.com/ciemss/pyciemss/issues/468), which result from TorchDiffEq only broadcasting along a single batch dimension. ChiRho works around this limitation by collapsing all batch dimensions into a single batch dimension before running the solver, but this interacts badly with handlers that modify tensor shapes when they encounter a sample statement. As a short-term fix, this PR removes the `parallel=True` arguments from the predictive handlers; with `parallel=True`, those handlers parallelize sampling by relying fully on tensor broadcasting. Once we move to a fully broadcastable solver, this change can be reverted. See https://github.com/BasisResearch/chirho/issues/524 for some discussion.
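To make the batch-collapsing idea concrete, here is a minimal sketch of the general technique, not ChiRho's actual implementation. It uses NumPy instead of PyTorch to stay dependency-light, and a hypothetical fixed-step Euler integrator (`euler_solver_single_batch`) stands in for a solver that, like TorchDiffEq as described above, only accepts one leading batch dimension. The wrapper `solve_with_collapsed_batch` is also hypothetical. It flattens all leading batch dimensions into one, solves, and restores the original batch shape.

```python
import numpy as np


def euler_solver_single_batch(f, y0, t):
    """Hypothetical stand-in for a solver that only broadcasts over ONE
    leading batch dimension: y0 must have shape (batch, state)."""
    if y0.ndim != 2:
        raise ValueError(f"expected shape (batch, state), got {y0.shape}")
    ys = [y0]
    for t0, t1 in zip(t[:-1], t[1:]):
        # Fixed-step explicit Euler update, applied to every batch row at once.
        ys.append(ys[-1] + (t1 - t0) * f(t0, ys[-1]))
    return np.stack(ys)  # shape: (len(t), batch, state)


def solve_with_collapsed_batch(f, y0, t):
    """Collapse all leading batch dims of y0 (shape (*batch, state)) into a
    single batch dim, solve, then restore them.

    Returns an array of shape (*batch, len(t), state).
    """
    *batch_shape, state_dim = y0.shape
    flat = y0.reshape(-1, state_dim)              # (prod(batch), state)
    traj = euler_solver_single_batch(f, flat, t)  # (T, prod(batch), state)
    traj = np.moveaxis(traj, 0, -2)               # (prod(batch), T, state)
    return traj.reshape(*batch_shape, len(t), state_dim)


# Usage: exponential decay dy/dt = -0.5 * y with two batch dims (4, 3).
decay = lambda t, y: -0.5 * y
y0 = np.ones((4, 3, 2))
t = np.linspace(0.0, 1.0, 11)
out = solve_with_collapsed_batch(decay, y0, t)
print(out.shape)  # (4, 3, 11, 2)
```

This only works if every tensor the vector field touches is flattened the same way as the state. The sketch's `decay` uses a scalar rate, so the flattening is safe. If a handler instead samples a per-batch parameter with an extra leading dimension (for example, from vectorized `parallel=True` prediction), that parameter no longer lines up with the flattened state. This is the kind of shape mismatch the PR sidesteps.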