SnowOwl-Hedwig opened this issue 1 week ago
The error reported here is actually a TypeError being raised because of an issue with the return types in a jax.custom_jvp. It's hard to see from this error report exactly which custom_jvp is the culprit, but it seems like it must be something within diffrax or equinox, so I'd recommend opening the issue on the diffrax issue tracker: https://github.com/patrick-kidger/diffrax
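To illustrate the kind of failure described above, here is a minimal sketch of a jax.custom_jvp whose JVP rule returns a tangent that does not match the primal output. It is not taken from diffrax or equinox, and the names good_sin and bad_sin are hypothetical. JAX checks that the primal and tangent outputs have corresponding shapes and dtypes when the rule is traced, and raises a TypeError if they do not; the well-formed version differentiates normally.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def good_sin(x):
    return jnp.sin(x)


@good_sin.defjvp
def good_sin_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Tangent output has the same shape and dtype as the primal output.
    return jnp.sin(x), jnp.cos(x) * t


@jax.custom_jvp
def bad_sin(x):
    return jnp.sin(x)


@bad_sin.defjvp
def bad_sin_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Deliberate bug: the primal output is a scalar, but the tangent
    # output has shape (2,), so JAX rejects the rule with a TypeError.
    return jnp.sin(x), jnp.stack([jnp.cos(x) * t, t])


print(jax.grad(good_sin)(0.0))  # derivative of sin at 0, i.e. cos(0)

try:
    jax.grad(bad_sin)(0.0)
except TypeError as err:
    print("TypeError:", err)
```

The mismatched rule only fails once something differentiates through the function (here jax.grad); calling bad_sin(0.0) on its own just runs the primal function and succeeds, which is why this kind of bug can stay hidden until a solver or training loop takes a gradient.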
ok, thanks for pointing this out. I'll try my luck there.
Description
Hi everyone,
Based on this tutorial, I tried to get started with JAX and neural ODEs: https://colab.research.google.com/drive/1ZlK36VgWy1vBjBNXjSUg6Cb-7zeoa3jh
However, I get a JaxStackTraceBeforeTransformation error (detailed error message below). I boiled the code down to a minimal working example (also provided below) and noticed that the error only occurs when the equation in test_func contains an argument. Since this issue seemed similar to one raised in an earlier post (https://github.com/jax-ml/jax/issues/13629), I tried downgrading jax to version 0.4.23. I also tried setting up a fresh Python environment with only the necessary packages installed. Nothing has helped so far. I'd appreciate your help :)
Working example:
Error message:
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.34
jaxlib: 0.4.34
numpy: 1.26.4
python: 3.11.1 (tags/v3.11.1:a7a450f, Dec 6 2022, 19:58:39) [MSC v.1934 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', release='10', version='10.0.19044', machine='AMD64')
jupyterlab: 4.2.2
diffrax: 0.4.1
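The jax, jaxlib, numpy, python, devices and platform entries follow the format printed by jax.print_environment_info(). A minimal sketch for collecting that report together with the versions of the other packages involved might look like the following; the helper name environment_report and the list of extra packages are assumptions for illustration, not part of the original report.

```python
import importlib.metadata

import jax


def environment_report() -> str:
    # Hypothetical helper. With return_string=True, jax.print_environment_info
    # returns the jax/jaxlib/numpy/python/devices summary instead of printing it.
    lines = [jax.print_environment_info(return_string=True).rstrip()]
    # Assumed list of other packages relevant to this report.
    for pkg in ("diffrax", "equinox", "jupyterlab"):
        try:
            lines.append(f"{pkg}: {importlib.metadata.version(pkg)}")
        except importlib.metadata.PackageNotFoundError:
            lines.append(f"{pkg}: not installed")
    return "\n".join(lines)


print(environment_report())
```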