Hi, I am a beginner with JAX and I was trying Equinox to solve an exercise.
I was trying the introductory example "CNN on MNIST" presented on the docs website.
I copied the code from the website, but it seems that the new version of JAX produces an error:
ValueError: Expected None, got <PjitFunction of <function log_softmax at 0x000002675EE5A840>>.
In previous releases of JAX, flatten-up-to used to consider None to be a tree-prefix of non-None values. To obtain the previous behavior, you can usually write:
jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
and I don't understand what I should do to fix it.
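A minimal sketch of what the error message seems to describe, using only JAX. The names `params`, `grads` and `decay` are hypothetical stand-ins, not code from the tutorial: `params` plays the role of a model whose leaves mix arrays with a plain function (like `jax.nn.log_softmax` in the CNN's layer list), and `grads` plays the role of the gradient tree, which has `None` where the non-array leaf was. Mapping over both trees fails in recent JAX because `None` is no longer treated as a prefix of a non-`None` leaf. The sketch then shows two ways around it: the `is_leaf` workaround quoted in the error, and filtering the parameter tree so that both trees have `None` in the same places. In the Equinox example, the analogous change would presumably be to pass `eqx.filter(model, eqx.is_array)` instead of the raw `model` as the params argument of `optim.update`; that is an assumption about where the mismatch comes from, not something confirmed here.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins (not from the tutorial): a "model" tree with one array
# leaf and one non-array leaf (a function), and the matching "gradient" tree,
# where the non-array leaf has been replaced by None.
params = {"weight": jnp.ones(3), "activation": jax.nn.log_softmax}
grads = {"weight": jnp.full(3, 0.5), "activation": None}


def decay(g, p):
    # Toy update combining gradient and parameter (weight-decay style).
    return g + 0.1 * p


# 1. Recent JAX: None in `grads` does not match the function leaf in `params`.
try:
    jax.tree.map(decay, grads, params)
except ValueError as e:
    print("ValueError:", str(e).splitlines()[0])

# 2. Workaround quoted in the error: treat None as a leaf and skip it explicitly.
out_a = jax.tree.map(
    lambda g, p: None if g is None else decay(g, p),
    grads,
    params,
    is_leaf=lambda x: x is None,
)

# 3. Filtering approach (similar in spirit to eqx.filter(model, eqx.is_array)):
# replace every non-array leaf of `params` with None so both trees line up.
filtered_params = jax.tree.map(
    lambda p: p if isinstance(p, jax.Array) else None, params
)
out_b = jax.tree.map(decay, grads, filtered_params)
```

Option 3 keeps the update function unchanged and fixes the input instead, which is usually the simpler change when the parameter tree is built in only one place.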