Yup, you absolutely can. The trick is either (a) to call .evaluate whilst still within the jax.vmap, or (b) to pass sol_batch back through a jax.vmap: jax.vmap(lambda s: s.evaluate(0.5))(sol_batch).
I hope that helps!
Perfect! Thanks!
Hey @patrick-kidger, I've been using diffrax for many reasons, among them its support for dense solutions. One issue I've been running into is how to evaluate a batched dense solution. That is, if I use vmap to solve the same ODE for many different initial conditions with SaveAt(dense=True), does diffrax support evaluating the resulting batched dense solution?
Here's a simple example to demonstrate:
What I would like is for sol_batch.evaluate(0.5) to return an array of shape (100, 6), representing the interpolated solution at t=0.5 for each of the 100 initial conditions. Instead, I get the broadcasting error
TypeError: lt got incompatible shapes for broadcasting: (100,), (65537,).
This is likely the intended behavior, though I wanted to check whether there is another way to evaluate a (batched) dense solution in this case. Thanks!