Open SnowOwl-Hedwig opened 1 week ago
This is a known issue that arose in JAX 0.4.34. The tangent type of integers in custom autodiff was changed from matching the primal to instead being float0.
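For context, float0 is the dtype JAX uses for tangents of integer values. Below is a minimal sketch of that convention using jax.jvp on a hypothetical function `scale` (an illustration only, not code from Equinox or Diffrax). Per the comment above, custom JVP rules in JAX 0.4.34 and later are expected to return the same kind of tangent for integer outputs.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scale(x, n):
    # Hypothetical example function: x is a float, n is an integer.
    # Only x carries a meaningful derivative.
    return x * n


x = jnp.asarray(2.0)
n = jnp.asarray(3)

# Tangent for the float input: an ordinary float array.
x_tangent = jnp.asarray(1.0)
# Tangent for the integer input: a float0 array of the same shape.
n_tangent = np.zeros(n.shape, dtype=jax.dtypes.float0)

primal_out, tangent_out = jax.jvp(scale, (x, n), (x_tangent, n_tangent))
print(primal_out, tangent_out)
```

Passing an integer tangent for `n` instead of the float0 array would be rejected by jax.jvp with a dtype-mismatch TypeError.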
I've updated Equinox to be compatible in https://github.com/patrick-kidger/equinox/pull/871. I'll do a new release soon. In the meantime you can either install Equinox directly from HEAD, or you can downgrade to JAX 0.4.33.
I hope that helps! :)
(I can see that you said you already tried downgrading. I have just double-checked and Equinox v0.11.7 + JAX 0.4.33 works for me, so I think something else has probably gone wrong for you there. :) )
Don't know what I'm doing wrong here. I just tried equinox 0.11.7 and jax 0.4.33 and still get the same issue. Maybe the new release will help ... Fortunately it's not urgent :)
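When a downgrade appears to have no effect, one thing worth ruling out is that the running interpreter or notebook kernel still sees different package versions than the ones just installed. The sketch below uses only the standard library. The helper name `installed_versions` and the package list are assumptions, chosen from the libraries mentioned in this thread.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package name: installed version, or None if not found}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Run this in the same interpreter or kernel that raises the error.
report = installed_versions(["jax", "jaxlib", "equinox", "diffrax", "optimistix"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

If the printed versions differ from what was just installed, the kernel may need a restart or may be pointing at a different environment.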
Just want to note that I get a similar problem:
../sdist/amici/jax.py:123: in _solve
sol = diffrax.diffeqsolve(
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1272: in _stop_gradient_on_unperturbed_jvp
perturb_val = _resolve_perturb_val(
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1241: in _resolve_perturb_val
perturb_val = jax.eval_shape(_resolve_perturb_val_impl).value
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1214: in _resolve_perturb_val_impl
jax.linearize(_to_linearize, dynamic)
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1207: in _to_linearize
_out = _body_fun(_val)
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1272: in _stop_gradient_on_unperturbed_jvp
perturb_val = _resolve_perturb_val(
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1241: in _resolve_perturb_val
perturb_val = jax.eval_shape(_resolve_perturb_val_impl).value
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1214: in _resolve_perturb_val_impl
jax.linearize(_to_linearize, dynamic)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_dynamic = (_ClosureConvert(
jaxpr=None,
consts=[
None,
None,
None,
f64[],
None,
None,
None,
...m(
_value=None,
_enumeration=<class 'optimistix._solution.RESULTS'>
),
step=None
), unused, (None,), ...)))
def _to_linearize(_dynamic):
_body_fun, _val = combine(_dynamic, static)
> _out = _body_fun(_val)
E TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
E primal int64[] with tangent int64[], expecting tangent ShapedArray(float0[])
E primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
E primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
E primal int64[] with tangent int64[], expecting tangent ShapedArray(float0[])
E --------------------
E For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
with jax==0.4.34, jaxlib==0.4.34, diffrax==0.6.0, equinox==0.11.8 on Python 3.13.
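The traceback above mentions setting JAX_TRACEBACK_FILTERING=off to include JAX's internal frames. One way to do that from Python, as a sketch, is to set the environment variable before JAX is first imported. This assumes JAX has not already been imported in the process; in a notebook that usually means restarting the kernel first.

```python
import os

# Set this before the first `import jax` in the process, so that JAX
# picks it up when it initialises its configuration.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

import jax  # noqa: E402  (deliberately imported after setting the variable)
```

With filtering off, the traceback is much longer, which can help identify which custom JVP rule raised the TypeError.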
Issue above is fixed with optimistix 0.0.9 :)
Hi,
based on this tutorial I tried to get started with Jax and neural ODEs: https://colab.research.google.com/drive/1ZlK36VgWy1vBjBNXjSUg6Cb-7zeoa3jh
However, I get a JaxStackTraceBeforeTransformation error (detailed error message below). I boiled the code down to a small working example (also provided below) and noticed that the error only occurs when the equation in test_func contains an argument. Since this issue seemed similar to one raised in an earlier post (https://github.com/jax-ml/jax/issues/13629), I tried downgrading jax to version 0.4.23. I also tried setting up a fresh Python environment with only the necessary packages installed. Nothing has helped so far. I'd appreciate your help :)
(Even though it's labeled JaxStack... error, @dfm pointed out it might actually be a problem with diffrax: "The error reported here is actually a TypeError being raised because of an issue with the return types in a jax.custom_jvp. It's hard to see from this error report exactly which custom_jvp is the culprit, but it seems like it must be something within diffrax or equinox, so I'd recommend opening the issue on the https://github.com/patrick-kidger/diffrax issue tracker." https://github.com/jax-ml/jax/issues/24253)
Working example:
Error message:
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.34
jaxlib: 0.4.34
numpy: 1.26.4
python: 3.11.1 (tags/v3.11.1:a7a450f, Dec 6 2022, 19:58:39) [MSC v.1934 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', release='10', version='10.0.19044', machine='AMD64')
jupyterlab: 4.2.2
diffrax: 0.4.1