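Just FYI, a future fix may be warranted: the Levenberg-Marquardt solver triggers a JAX FutureWarning about hashing a tracer (see the trace below). You can reproduce it with the root-finding example by swapping out the solver for the LM solver. I am running JAX 0.4.31 and Python 3.10.13.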
import jax
import jax.numpy as jnp
import optimistix as optx
# Often important when doing scientific work: enable JAX's x64 mode.
jax.config.update("jax_enable_x64", True)
def fn(y, args):
    """Residual function for the root-finding reproduction.

    Args:
        y: Pair ``(a, b)``; it is unpacked into two values that are then
            used in array arithmetic.
        args: Not used by this function.

    Returns:
        Tuple ``(c, d)`` where ``c = tanh(sum(b)) - a`` and
        ``d = a**2 - sinh(b + 1)``.
    """
    a, b = y
    c = jnp.tanh(jnp.sum(b)) - a
    d = a**2 - jnp.sinh(b + 1)
    return c, d
# Levenberg-Marquardt solver (the "LM solver" swapped into the root-finding example).
solver = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)
# Initial guess with the same (a, b) structure that fn unpacks: a scalar and a 2x2 array of zeros.
y0 = (jnp.array(0.0), jnp.zeros((2, 2)))
# This is the call that appears in the stack trace for the FutureWarning.
sol = optx.root_find(fn, solver, y0)
Trace:
Package initialized with double precision (float64)
Warning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
Category: FutureWarning
File: /Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/core.py, Line: 691
Stack trace:
File "/Users/dan/Documents/academic/explanetology/atmodeller/scripts/jax_CHO_low_temperature.py", line 439, in <module>
raise_warning()
File "/Users/dan/Documents/academic/explanetology/atmodeller/scripts/jax_CHO_low_temperature.py", line 432, in raise_warning
sol = optx.root_find(fn, solver, y0)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 242, in __call__
return self._call(False, args, kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 215, in _call
out = self._cached(dynamic_donate, dynamic_nodonate, static)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 332, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 180, in _python_pjit_helper
p, args_flat = _infer_params(fun, jit_info, args, kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 736, in _infer_params
p, args_flat = _infer_params_impl(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 633, in _infer_params_impl
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 352, in memoized_fun
ans = call(fun, *args)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1277, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2355, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2378, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 51, in fun_wrapped
out = fun(*args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_root_find.py", line 194, in root_find
return least_squares(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 242, in __call__
return self._call(False, args, kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 215, in _call
out = self._cached(dynamic_donate, dynamic_nodonate, static)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 332, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 180, in _python_pjit_helper
p, args_flat = _infer_params(fun, jit_info, args, kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 736, in _infer_params
p, args_flat = _infer_params_impl(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 633, in _infer_params_impl
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 352, in memoized_fun
ans = call(fun, *args)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1277, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2355, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2378, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_jit.py", line 51, in fun_wrapped
out = fun(*args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_least_squares.py", line 119, in least_squares
return iterative_solve(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_iterate.py", line 334, in iterative_solve
) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_adjoint.py", line 134, in apply
return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_ad.py", line 59, in implicit_jvp
root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_ad.py", line 788, in __call__
return self.fn(static, dynamic)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 261, in __call__
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat,
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 361, in bind
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers,
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2148, in process_custom_jvp_call
fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2378, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_ad.py", line 745, in fn_wrapper
return fn(*args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_ad.py", line 66, in _implicit_impl
return jtu.tree_map(jnp.asarray, fn_primal(inputs))
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_iterate.py", line 242, in _iterate
final_carry = while_loop(cond_fun, body_fun, init_carry, max_steps=max_steps)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/internal/_loop/loop.py", line 103, in while_loop
_, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1334, in while_loop
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1317, in _create_jaxpr
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py", line 67, in _initial_style_jaxpr
jaxpr, consts, out_tree, () = _initial_style_open_jaxpr(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py", line 60, in _initial_style_open_jaxpr
jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2355, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2378, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/internal/_loop/common.py", line 463, in new_body_fun
buffer_val2 = body_fun(buffer_val)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_iterate.py", line 232, in body_fun
new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_root_find.py", line 65, in step
new_y, new_state, (f, aux) = self.solver.step(fn, y, args, options, state, tags)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_solver/gauss_newton.py", line 339, in step
y_descent, descent_result = self.descent.step(step_size, descent_state)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_solver/levenberg_marquardt.py", line 130, in step
sol_value, result = damped_newton_step(
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/optimistix/_solver/levenberg_marquardt.py", line 57, in damped_newton_step
lm_param = jnp.where(pred, 1 / safe_step_size, jnp.finfo(step_size).max)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/ml_dtypes/_finfo.py", line 414, in __new__
return super().__new__(cls, dtype)
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/numpy/core/getlimits.py", line 486, in __new__
obj = cls._finfo_cache.get(dtype) # most common path
File "/Users/dan/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 691, in __hash__
warnings.warn(
File "/Users/dan/Programs/anaconda3/envs/py310/lib/python3.10/warnings.py", line 109, in _showwarnmsg
sw(msg.message, msg.category, msg.filename, msg.lineno,
File "/Users/dan/Documents/academic/explanetology/atmodeller/scripts/jax_CHO_low_temperature.py", line 63, in warning_handler
traceback.print_stack()
Just FYI, a future fix may be warranted. You can reproduce with the root finding example by swapping out the solver for the LM solver. Running with JAX 0.4.31 and Python 3.10.13.
Trace: