ParticularlyPythonicBS closed this issue 1 month ago.
This looks like the same issue in Equinox that came from the new `weak_type` struct in JAX 0.4.33 (see https://github.com/patrick-kidger/equinox/issues/854, https://github.com/google/jax/issues/23690).
With diffrax 0.6.0, equinox 0.11.6, and jax 0.4.31, it works.
Thank you so much; pinning the JAX version fixes it for now. I hope the upstream issue is fixed soon.
Closing as fixed in Equinox v0.11.7 (https://github.com/patrick-kidger/equinox/pull/856)! Thanks for the report :)
Closing, since this is fixed and the last comment was intended to close it.
Using Diffrax ODE integration within an Equinox neural-network training loop throws the error:
ValueError: Closure-converted function called with different dynamic arguments to the example arguments provided: ...
traceback.txt
stderr is attached, since it's too long to paste into the issue.
This happens on Equinox 0.11.6, JAX 0.4.33, and Diffrax 0.6.0, while it works perfectly fine on Equinox 0.10.6, JAX 0.4.13, and Diffrax 0.4.0.
Here is an MVE that reproduces the attached traceback:
As far as I'm aware, none of this code has emitted any deprecation warnings.
Is there a better way to structure this, where an Equinox neural network provides the argument to the function being integrated with Diffrax?