patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0
326 stars 14 forks source link

FutureWarning (unhashable type) thrown with LM solver #81

Open djbower opened 2 months ago

djbower commented 2 months ago

Just FYI, a fix may be warranted before a future JAX release turns this warning into an error. You can reproduce it with the root-finding example by swapping the solver for the Levenberg–Marquardt (LM) solver. I'm running JAX 0.4.31 and Python 3.10.13.

import jax
import jax.numpy as jnp
import optimistix as optx

# Often important when doing scientific work
jax.config.update("jax_enable_x64", True)

def fn(y, args):
    """Evaluate the two residuals of the example nonlinear system.

    `y` is unpacked into exactly two components; `args` is accepted to match
    the solver's calling convention but is not used here.

    Returns a pair: the first residual compares `tanh` of the summed second
    component against the first component, the second compares the squared
    first component against `sinh` of the shifted second component.
    """
    first_part, second_part = y
    return (
        jnp.tanh(jnp.sum(second_part)) - first_part,
        first_part**2 - jnp.sinh(second_part + 1),
    )

# Swap in the Levenberg-Marquardt solver (relative and absolute tolerances
# both set to 1e-8); this is the change that surfaces the warning.
solver = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)
# Initial guess mirrors the structure `fn` unpacks: a scalar and a 2x2 array
# of zeros.
y0 = (jnp.array(0.0), jnp.zeros((2, 2)))
# This call is the one that appears in the reported stack trace.
sol = optx.root_find(fn, solver, y0)

Trace:

Package initialized with double precision (float64)
Warning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
Category: FutureWarning
File: /Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/core.py, Line: 691
Stack trace:
  File "/Users/dan/Documents/academic/explanetology/atmodeller/scripts/jax_CHO_low_temperature.py", line 439, in <module>
    raise_warning()
  File "/Users/dan/Documents/academic/explanetology/atmodeller/scripts/jax_CHO_low_temperature.py", line 432, in raise_warning
    sol = optx.root_find(fn, solver, y0)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 242, in __call__
    return self._call(False, args, kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 215, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 332, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 180, in _python_pjit_helper
    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 736, in _infer_params
    p, args_flat = _infer_params_impl(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 633, in _infer_params_impl
    jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 352, in memoized_fun
    ans = call(fun, *args)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1277, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2355, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2378, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 51, in fun_wrapped
    out = fun(*args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_root_find.py", line 194, in root_find
    return least_squares(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 242, in __call__
    return self._call(False, args, kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 215, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 332, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 180, in _python_pjit_helper
    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 736, in _infer_params
    p, args_flat = _infer_params_impl(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 633, in _infer_params_impl
    jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 352, in memoized_fun
    ans = call(fun, *args)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1277, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2355, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2378, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 51, in fun_wrapped
    out = fun(*args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_least_squares.py", line 119, in least_squares
    return iterative_solve(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_iterate.py", line 334, in iterative_solve
    ) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_adjoint.py", line 134, in apply
    return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_ad.py", line 59, in implicit_jvp
    root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_ad.py", line 788, in __call__
    return self.fn(static, dynamic)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 261, in __call__
    out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat,
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 361, in bind
    outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers,
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2148, in process_custom_jvp_call
    fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2378, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_ad.py", line 745, in fn_wrapper
    return fn(*args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_ad.py", line 66, in _implicit_impl
    return jtu.tree_map(jnp.asarray, fn_primal(inputs))
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_iterate.py", line 242, in _iterate
    final_carry = while_loop(cond_fun, body_fun, init_carry, max_steps=max_steps)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/internal/_loop/loop.py", line 103, in while_loop
    _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1334, in while_loop
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1317, in _create_jaxpr
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py", line 67, in _initial_style_jaxpr
    jaxpr, consts, out_tree, () = _initial_style_open_jaxpr(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py", line 60, in _initial_style_open_jaxpr
    jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2355, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2378, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/internal/_loop/common.py", line 463, in new_body_fun
    buffer_val2 = body_fun(buffer_val)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_iterate.py", line 232, in body_fun
    new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_root_find.py", line 65, in step
    new_y, new_state, (f, aux) = self.solver.step(fn, y, args, options, state, tags)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_solver/gauss_newton.py", line 339, in step
    y_descent, descent_result = self.descent.step(step_size, descent_state)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_solver/levenberg_marquardt.py", line 130, in step
    sol_value, result = damped_newton_step(
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_solver/levenberg_marquardt.py", line 57, in damped_newton_step
    lm_param = jnp.where(pred, 1 / safe_step_size, jnp.finfo(step_size).max)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/ml_dtypes/_finfo.py", line 414, in __new__
    return super().__new__(cls, dtype)
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/numpy/core/getlimits.py", line 486, in __new__
    obj = cls._finfo_cache.get(dtype)  # most common path
  File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 691, in __hash__
    warnings.warn(
  File "/Users/dan/Programs/anaconda3/envs/py310/lib/python3.10/warnings.py", line 109, in _showwarnmsg
    sw(msg.message, msg.category, msg.filename, msg.lineno,
  File "/Users/dan/Documents/academic/explanetology/atmodeller/scripts/jax_CHO_low_temperature.py", line 63, in warning_handler
    traceback.print_stack()
patrick-kidger commented 2 months ago

Thanks for the report! This should already be fixed in #61, but we haven't done a new release yet :)