allen-adastra closed 1 month ago
The basic example fails on JAX v0.4.33: https://docs.kidger.site/diffrax/examples/neural_ode/
We get a ValueError:
ValueError: Closure-converted function called with different dynamic arguments to the example arguments provided.
It seems the issue is that several leaves carry weak_type=True in self_in_dynamic_struct but not in in_dynamic_struct when the two are compared in _check_closure_convert_input:
> venv/lib/python3.11/site-packages/equinox/_ad.py(447)_check_closure_convert_input()
(Pdb) in_dynamic_struct
(((ShapeDtypeStruct(shape=(), dtype=int32),
   ShapeDtypeStruct(shape=(), dtype=bool),
   ShapeDtypeStruct(shape=(), dtype=bool),
   State(
     y=ShapeDtypeStruct(shape=(2,), dtype=float32),
     tprev=ShapeDtypeStruct(shape=(), dtype=float32),
     tnext=ShapeDtypeStruct(shape=(), dtype=float32),
     made_jump=ShapeDtypeStruct(shape=(), dtype=bool),
     solver_state=(
       ShapeDtypeStruct(shape=(), dtype=bool),
       ShapeDtypeStruct(shape=(2,), dtype=float32)
     ),
     controller_state=(
       ShapeDtypeStruct(shape=(), dtype=bool),
       ShapeDtypeStruct(shape=(), dtype=bool),
       ShapeDtypeStruct(shape=(), dtype=float32),
       ShapeDtypeStruct(shape=(), dtype=float32),
       ShapeDtypeStruct(shape=(), dtype=float32)
     ),
     result=EnumerationItem(
       _value=ShapeDtypeStruct(shape=(), dtype=int32),
       _enumeration=<class 'diffrax._solution.RESULTS'>
     ),
     num_steps=ShapeDtypeStruct(shape=(), dtype=int32),
     num_accepted_steps=ShapeDtypeStruct(shape=(), dtype=int32),
     num_rejected_steps=ShapeDtypeStruct(shape=(), dtype=int32),
     save_state=SaveState(
       saveat_ts_index=ShapeDtypeStruct(shape=(), dtype=int32),
       ts=ShapeDtypeStruct(shape=(10,), dtype=float32),
       ys=ShapeDtypeStruct(shape=(10, 2), dtype=float32),
       save_index=ShapeDtypeStruct(shape=(), dtype=int32)
     ),
     dense_ts=None,
     dense_infos=None,
     dense_save_index=None
   )),), {})
(Pdb) self_in_dynamic_struct
(((ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True),
   ShapeDtypeStruct(shape=(), dtype=bool),
   ShapeDtypeStruct(shape=(), dtype=bool),
   State(
     y=ShapeDtypeStruct(shape=(2,), dtype=float32),
     tprev=ShapeDtypeStruct(shape=(), dtype=float32),
     tnext=ShapeDtypeStruct(shape=(), dtype=float32),
     made_jump=ShapeDtypeStruct(shape=(), dtype=bool),
     solver_state=(
       ShapeDtypeStruct(shape=(), dtype=bool),
       ShapeDtypeStruct(shape=(2,), dtype=float32)
     ),
     controller_state=(
       ShapeDtypeStruct(shape=(), dtype=bool),
       ShapeDtypeStruct(shape=(), dtype=bool),
       ShapeDtypeStruct(shape=(), dtype=float32),
       ShapeDtypeStruct(shape=(), dtype=float32),
       ShapeDtypeStruct(shape=(), dtype=float32)
     ),
     result=EnumerationItem(
       _value=ShapeDtypeStruct(shape=(), dtype=int32),
       _enumeration=<class 'diffrax._solution.RESULTS'>
     ),
     num_steps=ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True),
     num_accepted_steps=ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True),
     num_rejected_steps=ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True),
     save_state=SaveState(
       saveat_ts_index=ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True),
       ts=ShapeDtypeStruct(shape=(10,), dtype=float32),
       ys=ShapeDtypeStruct(shape=(10, 2), dtype=float32),
       save_index=ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True)
     ),
     dense_ts=None,
     dense_infos=None,
     dense_save_index=None
   )),), {})
(Pdb)
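For dumps this size, it can be easier to walk the two structures in parallel and print only the paths that disagree. The following is a minimal plain-Python sketch of that idea. Struct is a hypothetical stand-in for ShapeDtypeStruct and diff_leaves is an illustrative helper; neither is part of JAX or Equinox, and the sample data is a cut-down imitation of the dump above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Struct:
    """Hypothetical stand-in for ShapeDtypeStruct (not the real JAX class)."""
    shape: tuple
    dtype: str
    weak_type: bool = False


def diff_leaves(a, b, path=""):
    """Yield (path, a_leaf, b_leaf) wherever two nested structures disagree."""
    if isinstance(a, Struct) or isinstance(b, Struct):
        # Struct instances are treated as leaves and compared field by field.
        if a != b:
            yield path, a, b
    elif isinstance(a, (tuple, list)) and isinstance(b, (tuple, list)) and len(a) == len(b):
        for i, (x, y) in enumerate(zip(a, b)):
            yield from diff_leaves(x, y, f"{path}[{i}]")
    elif isinstance(a, dict) and isinstance(b, dict) and a.keys() == b.keys():
        for key in a:
            yield from diff_leaves(a[key], b[key], f"{path}[{key!r}]")
    elif a != b:
        # Anything else (including containers of different size) is one mismatch.
        yield path, a, b


# Cut-down imitation of the two structures from the debugger session.
expected = (
    Struct((), "int32"),
    {"num_steps": Struct((), "int32"), "ts": Struct((10,), "float32")},
)
actual = (
    Struct((), "int32", weak_type=True),
    {"num_steps": Struct((), "int32", weak_type=True), "ts": Struct((10,), "float32")},
)

mismatches = list(diff_leaves(expected, actual))
for path, want, got in mismatches:
    print(path, "expected", want, "got", got)
```

On the real objects the same walk would need to handle the State and SaveState containers as well; that part is left out here because their exact layout is an implementation detail of Diffrax.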
Try upgrading to Equinox 0.11.7, which should fix this.
(JAX made a breaking change in 0.4.32, which we needed to update to work around.)
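To confirm which versions are actually installed before and after upgrading, something like the following sketch may help. It uses only the standard library; parse_version and installed are illustrative helper names, and the simple numeric comparison is an assumption that ignores pre-release ordering.

```python
from importlib.metadata import PackageNotFoundError, version
from itertools import takewhile


def parse_version(text):
    """Turn '0.11.7' into (0, 11, 7), keeping only leading digits of each part."""
    numbers = []
    for piece in text.split(".")[:3]:
        digits = "".join(takewhile(str.isdigit, piece))
        numbers.append(int(digits) if digits else 0)
    return tuple(numbers)


def installed(name):
    """Return the installed version string of a package, or None if absent."""
    try:
        return version(name)
    except PackageNotFoundError:
        return None


# Minimum Equinox version suggested in this thread.
MIN_EQUINOX = (0, 11, 7)

for package in ("jax", "equinox", "diffrax"):
    print(package, installed(package))

equinox_version = installed("equinox")
if equinox_version is not None and parse_version(equinox_version) < MIN_EQUINOX:
    print("Equinox is older than 0.11.7; upgrading is the suggested fix.")
```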
Problem solved; thanks!