patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.44k stars, 130 forks

Neural ODE Example Fails on JAX v0.4.33 #503

Closed: allen-adastra closed this issue 1 month ago

allen-adastra commented 1 month ago

The basic Neural ODE example fails on JAX v0.4.33: https://docs.kidger.site/diffrax/examples/neural_ode/

It raises a ValueError:

ValueError: Closure-converted function called with different dynamic arguments to the example arguments provided.

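The cause appears to be a weak_type mismatch. Stopping in _check_closure_convert_input shows that several int32 scalars (including num_steps, num_accepted_steps, num_rejected_steps, saveat_ts_index and save_index) have weak_type=True in the example arguments (self_in_dynamic_struct) but not in the arguments actually passed (in_dynamic_struct):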

> venv/lib/python3.11/site-packages/equinox/_ad.py(447)_check_closure_convert_input()

(Pdb) in_dynamic_struct
(((ShapeDtypeStruct(shape=(), dtype=int32), ShapeDtypeStruct(shape=(), dtype=bool), ShapeDtypeStruct(shape=(), dtype=bool), State(
  y=ShapeDtypeStruct(shape=(2,), dtype=float32),
  tprev=ShapeDtypeStruct(shape=(), dtype=float32),
  tnext=ShapeDtypeStruct(shape=(), dtype=float32),
  made_jump=ShapeDtypeStruct(shape=(), dtype=bool),
  solver_state=(
    ShapeDtypeStruct(shape=(), dtype=bool),
    ShapeDtypeStruct(shape=(2,), dtype=float32)
  ),
  controller_state=(
    ShapeDtypeStruct(shape=(), dtype=bool),
    ShapeDtypeStruct(shape=(), dtype=bool),
    ShapeDtypeStruct(shape=(), dtype=float32),
    ShapeDtypeStruct(shape=(), dtype=float32),
    ShapeDtypeStruct(shape=(), dtype=float32)
  ),
  result=EnumerationItem(
    _value=ShapeDtypeStruct(shape=(), dtype=int32),
    _enumeration=<class 'diffrax._solution.RESULTS'>
  ),
  num_steps=ShapeDtypeStruct(shape=(), dtype=int32),
  num_accepted_steps=ShapeDtypeStruct(shape=(), dtype=int32),
  num_rejected_steps=ShapeDtypeStruct(shape=(), dtype=int32),
  save_state=SaveState(
    saveat_ts_index=ShapeDtypeStruct(shape=(), dtype=int32),
    ts=ShapeDtypeStruct(shape=(10,), dtype=float32),
    ys=ShapeDtypeStruct(shape=(10, 2), dtype=float32),
    save_index=ShapeDtypeStruct(shape=(), dtype=int32)
  ),
  dense_ts=None,
  dense_infos=None,
  dense_save_index=None
)),), {})
(Pdb) self_in_dynamic_struct
(((ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True), ShapeDtypeStruct(shape=(), dtype=bool), ShapeDtypeStruct(shape=(), dtype=bool), State(
  y=ShapeDtypeStruct(shape=(2,), dtype=float32),
  tprev=ShapeDtypeStruct(shape=(), dtype=float32),
  tnext=ShapeDtypeStruct(shape=(), dtype=float32),
  made_jump=ShapeDtypeStruct(shape=(), dtype=bool),
  solver_state=(
    ShapeDtypeStruct(shape=(), dtype=bool),
    ShapeDtypeStruct(shape=(2,), dtype=float32)
  ),
  controller_state=(
    ShapeDtypeStruct(shape=(), dtype=bool),
    ShapeDtypeStruct(shape=(), dtype=bool),
    ShapeDtypeStruct(shape=(), dtype=float32),
    ShapeDtypeStruct(shape=(), dtype=float32),
    ShapeDtypeStruct(shape=(), dtype=float32)
  ),
  result=EnumerationItem(
    _value=ShapeDtypeStruct(shape=(), dtype=int32),
    _enumeration=<class 'diffrax._solution.RESULTS'>
  ),
  num_steps=ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True),
  num_accepted_steps=ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True),
  num_rejected_steps=ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True),
  save_state=SaveState(
    saveat_ts_index=ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True),
    ts=ShapeDtypeStruct(shape=(10,), dtype=float32),
    ys=ShapeDtypeStruct(shape=(10, 2), dtype=float32),
    save_index=ShapeDtypeStruct(shape=(), dtype=int32, weak_type=True)
  ),
  dense_ts=None,
  dense_infos=None,
  dense_save_index=None
)),), {})
(Pdb) 
patrick-kidger commented 1 month ago

Try upgrading to Equinox 0.11.7, which should fix this.

(JAX 0.4.32 made a breaking change, which Equinox needed an update to work around.)
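To check whether an environment falls in the affected range, a small version comparison is enough. This is a hypothetical helper, assuming from the reply above that JAX 0.4.32 or later combined with an Equinox release older than 0.11.7 hits this error. It compares release numbers only and ignores pre-release suffixes.

```python
import re
from importlib import metadata

# Assumed thresholds, taken from the maintainer's reply in this thread.
JAX_BREAKING_RELEASE = (0, 4, 32)   # JAX release that introduced the breaking change
EQUINOX_FIXED_RELEASE = (0, 11, 7)  # Equinox release suggested as the fix


def release_tuple(version: str) -> tuple[int, ...]:
    """Parse the leading dotted-integer part of a version, e.g. '0.4.33.dev1' -> (0, 4, 33)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def needs_equinox_upgrade(jax_version: str, equinox_version: str) -> bool:
    """True if JAX includes the breaking change but Equinox predates the workaround."""
    return (
        release_tuple(jax_version) >= JAX_BREAKING_RELEASE
        and release_tuple(equinox_version) < EQUINOX_FIXED_RELEASE
    )


if __name__ == "__main__":
    try:
        installed_jax = metadata.version("jax")
        installed_equinox = metadata.version("equinox")
    except metadata.PackageNotFoundError as err:
        print(f"package not installed: {err}")
    else:
        if needs_equinox_upgrade(installed_jax, installed_equinox):
            print('Upgrade Equinox, e.g. pip install --upgrade "equinox>=0.11.7"')
        else:
            print("These versions are outside the affected range assumed here")
```

Comparing integer tuples rather than raw strings matters here: as strings, "0.11.10" sorts before "0.11.7".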

allen-adastra commented 1 month ago

Problem solved; thanks!