patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Issue with vmap `optx.least_squares`. #38

Closed by bheijden 5 months ago

bheijden commented 5 months ago

Hi,

I'm having trouble vectorizing optx.least_squares (version 0.0.6) directly with JAX's vmap. The error occurs unless the sol.state and sol.result fields are removed from the returned Solution instance. Perhaps this is related to this commit? Somehow, vmap does not know that the jaxpr fields should not be batched (i.e. pytree_node=False).

MWE

In the provided Minimum Working Example (MWE), I attempt to vectorize the least squares optimization using JAX's vmap function. The process involves a quadratic residual function and the Levenberg-Marquardt solver.

The vectorization attempt fails when trying to return the full Solution object (sol) from the least_squares function. However, if only sol.value is returned, the vectorization succeeds. This suggests a compatibility issue between the full Solution instance structure and JAX's vmap operation.

import jax
import jax.numpy as jnp
import optimistix as optx

# Define a simple quadratic residual function
def residual_fn(params, *args):
    return params[0] * jnp.arange(10) ** 2 + params[1] * jnp.arange(10) + params[2]

# Initialize the Levenberg-Marquardt solver
solver = optx.LevenbergMarquardt(rtol=1e-5, atol=1e-7, norm=optx.rms_norm)

# Define the initial parameters
params_init = jnp.array([1.0, 2.0, 3.0])

# Define a function to perform least squares optimization
def least_squares(params_init):
    sol = optx.least_squares(residual_fn, solver, params_init, max_steps=100, throw=False)
    return sol  # raises an error; returning sol.value instead does not

# Attempt to vectorize the least squares function
vmap_least_squares = jax.vmap(least_squares, in_axes=(0,), out_axes=0)

# Define a batch of initial parameters
batch_params_init = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# Attempt to perform batched least squares optimization
batch_params_final = vmap_least_squares(batch_params_init)

This produces the error:

Traceback (most recent call last):
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-7-521a622dc7fd>", line 27, in <module>
    batch_params_final = vmap_least_squares(batch_params_init)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/api.py", line 1258, in vmap_f
    out_flat = batching.batch(
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/linear_util.py", line 206, in call_wrapped
    ans = gen.send(ans)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 638, in _batch_inner
    out_vals = map(partial(from_elt, trace, axis_size), outs, out_dim_dests)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 270, in from_elt
    return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 1107, in matchaxis
    raise TypeError(f"Output from batched function {x!r} with type "
TypeError: Output from batched function { lambda a:f32[10] b:f32[10]; c:f32[3]. let
    d:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] c
    e:f32[] = squeeze[dimensions=(0,)] d
    f:f32[10] = mul e a
    g:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] c
    h:f32[] = squeeze[dimensions=(0,)] g
    i:f32[10] = mul h b
    j:f32[10] = add f i
    k:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] c
    l:f32[] = squeeze[dimensions=(0,)] k
    m:f32[10] = add j l
  in (m,) } with type <class 'jax._src.core.Jaxpr'> is not a valid JAX type

patrick-kidger commented 5 months ago

This is expected: the output is a PyTree (of type Solution), one of whose leaves is a jaxpr, and a jaxpr is not something JAX knows how to vmap (i.e. it is not an array, so it has no axis to treat as a batch dimension).

The fix should be to use equinox.filter_vmap, which is basically a wrapper that accomplishes jax.vmap(..., out_axes=<0 for all arrays, None for all non-arrays>).

bheijden commented 5 months ago

Thanks!