patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

ValueError when updating a State value from "weak_f32" to "f32" with jax >= 0.4.33 #885

Open elinorbgr opened 1 month ago

elinorbgr commented 1 month ago

From what I can tell, this may be related to https://github.com/patrick-kidger/equinox/pull/856.

For context: I'm working with a model that has a State, and the value I'm trying to update is a scalar array initialized to jnp.array(0.0).

My code runs fine with jax 0.4.30, but with newer versions (I've tested 0.4.33 and 0.4.35) it raises this exception as soon as it tries to update the state value:

https://github.com/patrick-kidger/equinox/blob/689c35a794f4ada89821a4220f024cb755284623/equinox/nn/_stateful.py#L175-L181

The error message I got from this was:

ValueError: Old and new values have different structures/shapes/dtypes. The old value is weak_f32[] and the new value is f32[].

elinorbgr commented 1 month ago
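The error prints the two values as weak_f32[] and f32[], even though both have dtype float32. The small JAX-only sketch below shows why they are treated as different. A weakly typed value (here, one created from a bare Python float) defers to the other operand during type promotion, while an explicitly typed float32 does not. It assumes JAX's default 32-bit mode (jax_enable_x64 off), and bfloat16 appears only as a convenient contrasting dtype.

```python
import jax.numpy as jnp

weak = jnp.array(0.0)                       # weak_f32[]: dtype inferred from a Python float
strong = jnp.array(0.0, dtype=jnp.float32)  # f32[]: dtype given explicitly

# Both report the same dtype...
assert weak.dtype == strong.dtype == jnp.float32

# ...but the weakly typed value defers to the other operand during promotion.
print((weak + jnp.bfloat16(1)).dtype)    # bfloat16
print((strong + jnp.bfloat16(1)).dtype)  # float32
```

Because the two can promote differently, a strict "old vs new value" comparison that also tracks weak types will reject swapping one for the other.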

Actually, changing my initialization to jnp.array(0.0, dtype=np.float32) fixes the problem. Is that the intended behaviour?

patrick-kidger commented 4 weeks ago

Yup, this was a breaking change in JAX 0.4.34.

Indeed, creating this array with an explicit dtype should fix the problem. :)

Alternatively, I'd be happy to take a pull request that adjusts this check to ignore weak dtypes (in the same way as #856).
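To illustrate what "ignoring weak dtypes" could mean, here is a hypothetical standalone check. It is not Equinox's actual code. It compares pytree structure, shapes and dtypes, but not the weak-type flag, and it assumes every leaf is a JAX array. It relies on .dtype not recording weakness, so weak_f32[] and f32[] compare equal while a genuine shape or dtype change is still rejected.

```python
import jax
import jax.numpy as jnp


def same_structure_ignoring_weak_types(old, new) -> bool:
    """Hypothetical check: same pytree structure, shapes and dtypes.

    Assumes every leaf is a JAX array. `.dtype` does not record the
    weak-type flag, so weak_f32[] and f32[] are treated as equal.
    """
    if jax.tree_util.tree_structure(old) != jax.tree_util.tree_structure(new):
        return False
    return all(
        a.shape == b.shape and a.dtype == b.dtype
        for a, b in zip(jax.tree_util.tree_leaves(old), jax.tree_util.tree_leaves(new))
    )


old = jnp.array(0.0)          # weak_f32[]
new = old + jnp.float32(1.0)  # f32[]
print(same_structure_ignoring_weak_types(old, new))  # True
```

A real patch would presumably keep the existing error message and only relax the dtype comparison.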