bheijden closed this issue 5 months ago.
This is expected: the output is a PyTree (of type Solution), one of whose leaves is a jaxpr, which is not something JAX knows how to vmap (i.e. treat an axis of an array as a batch dimension).
The fix should be to use equinox.filter_vmap, which is basically a wrapper that accomplishes jax.vmap(..., out_axes=<0 for all arrays, None for all non-arrays>).
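To make the out_axes description concrete, here is a minimal pure-JAX sketch of "batch the arrays, pass the non-arrays through unbatched". It is not Equinox's implementation: vmap_arrays_only and solve are hypothetical names, and a string leaf stands in for the jaxpr inside Solution. Only array leaves cross the jax.vmap boundary. The non-array leaves are recorded on the Python side during tracing and reattached afterwards.

```python
import jax
import jax.numpy as jnp


def vmap_arrays_only(fn):
    """Hypothetical helper (not part of JAX or Equinox): vmap `fn` over axis 0
    of its inputs, batching array outputs along axis 0 and returning non-array
    outputs unbatched."""

    def wrapped(*args):
        static = {}

        def inner(*inner_args):
            out = fn(*inner_args)
            leaves, treedef = jax.tree_util.tree_flatten(out)
            # Tracers are instances of jax.Array, so this separates array
            # leaves from non-array ones (strings, functions, jaxprs, ...).
            is_array = [isinstance(x, jax.Array) for x in leaves]
            # Record the tree structure and the non-array leaves in Python.
            static["treedef"] = treedef
            static["is_array"] = is_array
            static["others"] = [x for x, a in zip(leaves, is_array) if not a]
            # Only array leaves are returned through jax.vmap.
            return [x for x, a in zip(leaves, is_array) if a]

        batched = iter(jax.vmap(inner)(*args))
        others = iter(static["others"])
        # Rebuild the original output structure from both kinds of leaves.
        leaves = [next(batched) if a else next(others) for a in static["is_array"]]
        return jax.tree_util.tree_unflatten(static["treedef"], leaves)

    return wrapped


def solve(y0):
    # Hypothetical stand-in for a solver: a Solution-like PyTree with one
    # array leaf and one non-array leaf (playing the role of the jaxpr).
    return {"value": y0 ** 2, "tag": "lm-solver"}


out = vmap_arrays_only(solve)(jnp.arange(3.0))
# out["value"] holds the batched values [0., 1., 4.];
# out["tag"] is the single, unbatched string "lm-solver".
```

By contrast, calling jax.vmap(solve)(jnp.arange(3.0)) directly raises an error, because the string leaf is not a valid JAX type. This mirrors the failure described in the issue. equinox.filter_vmap, the fix suggested above, handles this bookkeeping more generally.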
Thanks!
Hi,
I have issues vectorizing the optx.least_squares function (version 0.0.6) with JAX's vmap function. Vectorization fails unless the sol.state and sol.result fields are removed from the Solution dataclass instance. Perhaps related to this commit? Somehow, vmap does not know that the jaxpr stuff should not be batched (i.e. pytree_node=False).
MWE
In the provided Minimum Working Example (MWE), I attempt to vectorize the least squares optimization using JAX's vmap function. The setup uses a quadratic residual function and the Levenberg-Marquardt solver.
The vectorization attempt fails when trying to return the full Solution object (sol) from the least_squares call. However, if only sol.value is returned, vectorization succeeds. This suggests an incompatibility between the structure of the full Solution instance and JAX's vmap.
Returning the full Solution produces the error: