elinorbgr opened this issue 1 month ago.
Actually, changing my initialization to `jnp.array(0.0, dtype=np.float32)` fixes the problem. Is that the intended behaviour?
Yup, this was a breaking change in JAX 0.4.34.
Indeed, creating this as an array with an explicit dtype should fix the problem. :)
Alternatively, I'd be happy to take a pull request adjusting this check so that it ignores weak dtypes (in the same way as #856).
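As a rough illustration of a check that ignores weak dtypes, here is a sketch of a standalone shape/dtype comparison. `check_state_update` and its error message are hypothetical and are not Equinox's actual code at `_stateful.py`. The idea is simply to compare `.shape` and `.dtype` while never looking at the weak-type flag.

```python
import jax.numpy as jnp
import numpy as np


def check_state_update(old, new):
    """Hypothetical check: reject shape/dtype changes, ignore weak typing.

    A weakly typed float32 and a strongly typed float32 of the same shape
    compare equal here, because only `.shape` and `.dtype` are inspected.
    """
    old = jnp.asarray(old)
    new = jnp.asarray(new)
    if old.shape != new.shape or old.dtype != new.dtype:
        raise ValueError(
            f"state update changes shape or dtype: "
            f"{old.shape}/{old.dtype} -> {new.shape}/{new.dtype}"
        )


# Weak float32 -> strong float32 with the same shape: accepted.
check_state_update(jnp.array(0.0), jnp.array(1.0, dtype=np.float32))

# A genuine dtype change (float32 -> int32) is still rejected.
try:
    check_state_update(jnp.array(0.0), jnp.array(1, dtype=np.int32))
except ValueError as err:
    print("rejected:", err)
```

Dropping the weak-type bit keeps genuine shape or dtype changes as errors while tolerating the weak/strong difference shown by `jnp.array(0.0)` versus `jnp.array(0.0, dtype=np.float32)`.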
From what I can gather, this may be related to https://github.com/patrick-kidger/equinox/pull/856 ?
The context is that I'm working with a model with a `State`, and the particular value I'm trying to update is a scalar array, initialized to the value `jnp.array(0.0)`. My code runs fine with JAX 0.4.30, but when I try to run it with newer versions (I've tested 0.4.33 and 0.4.35), it triggers this exception as soon as it tries to update the state value:
https://github.com/patrick-kidger/equinox/blob/689c35a794f4ada89821a4220f024cb755284623/equinox/nn/_stateful.py#L175-L181
The error message I got from this was: