patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

issue with the introductory example "CNN on MNIST" #880

pasq-cat closed this issue 1 month ago

pasq-cat commented 1 month ago

Hi, I am a beginner with JAX and was trying Equinox to solve an exercise. I ran the introductory example "CNN on MNIST" from the docs website, copying the code exactly, but the new version of JAX seems to produce an error:

(...)
     21     y: Int[Array, " batch"],
     22 ):
     23     loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
---> 24     updates, opt_state = optim.update(grads, opt_state, model)
     25     model = eqx.apply_updates(model, updates)
     26     return model, opt_state, loss_value
...

ValueError: Expected None, got <PjitFunction of <function log_softmax at 0x000002675EE5A840>>.

In previous releases of JAX, flatten-up-to used to consider None to be a tree-prefix of non-None values. To obtain the previous behavior, you can usually write: jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)

I don't understand what I should do to fix it.
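To see what the error message describes, here is a minimal sketch using two hypothetical trees, `a` and `b`. `a` has `None` where `b` holds a function object, much as gradients (which use `None` for non-array leaves) differ from an unfiltered model that still contains `jax.nn.log_softmax`. Mapping over both trees without `is_leaf` fails on recent JAX versions, because `None` is no longer treated as a prefix of a non-`None` value. Passing `is_leaf=lambda x: x is None`, as the message suggests, makes `jax.tree.map` treat each `None` in `a` as a leaf, so the two structures line up again.

```python
import jax
import jax.numpy as jnp

# Hypothetical trees: `a` has None where `b` has a non-array leaf (a function),
# similar to gradients vs. an unfiltered model.
a = {"weight": jnp.ones(3), "activation": None}
b = {"weight": jnp.full(3, 2.0), "activation": jax.nn.log_softmax}


def f(x, y):
    # The operation we actually want to apply to matching array leaves.
    return x * y


# Treat each None in `a` as a leaf so it pairs with whatever sits at the
# same position in `b` (here, the function), then skip it in the lambda.
result = jax.tree.map(
    lambda x, y: None if x is None else f(x, y),
    a,
    b,
    is_leaf=lambda x: x is None,
)
```

Here `result["weight"]` is the elementwise product and `result["activation"]` stays `None`.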

patrick-kidger commented 1 month ago

Thank you for the report! Indeed, it looks like this broke with recent JAX versions. This should now be fixed as of #881. :)