kimbochen closed this issue 2 years ago.
I opened PR #10, which at least makes the example run. Take a look if you're interested.
Hey @lkhphuc, sorry for the late reply; I somehow missed this issue in my inbox.
The issue is related to a recent change in jax (see google/jax#8017), and it will be fixed in the next version of Treex. I think the easiest fix for you right now is to roll back to jax==0.2.20, as this issue also affects Flax as used by Treex.
@lkhphuc I believe this should be fixed in treex==0.5.0.
Thanks for the replies, the problem is fixed.
Hi, thanks for the great work! I am trying to learn how to use JAX and treex, so I followed the tutorial.
However, I always get this assertion error.
After digging into the code, I found out that `jax.random.split(key)` seems to return keys of type `numpy.ndarray`. Replacing `jnp.ndarray` with `np.ndarray` still creates problems: `key` is originally of type `jaxlib.xla_extension.DeviceArray`. I would love to make a PR, but I am not sure how to fix this. Here's a Colab notebook that replicates the issue.
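If the failing assertion is an `isinstance(key, jnp.ndarray)` check (the description above suggests so, but the tutorial code is not shown here), one way to stop depending on the concrete array class is to check the key's shape and dtype instead. The sketch below uses a hypothetical helper, `looks_like_prng_key`, and assumes legacy raw keys as produced by `jax.random.PRNGKey` (shape `(2,)`, dtype `uint32`). It is an illustration, not the fix that landed in Treex.

```python
import numpy as np


def looks_like_prng_key(x) -> bool:
    """Hypothetical duck-typed check for a legacy raw PRNG key.

    Both numpy.ndarray and JAX arrays expose .shape and .dtype, so this
    accepts either one without caring which concrete class `x` is.
    Objects without those attributes (e.g. plain lists) are rejected.
    """
    shape = getattr(x, "shape", None)
    dtype = getattr(x, "dtype", None)
    return shape == (2,) and dtype == np.uint32


# A numpy array with the layout of a raw key passes the check.
print(looks_like_prng_key(np.array([0, 42], dtype=np.uint32)))
```

With jax installed, `key, subkey = jax.random.split(jax.random.PRNGKey(0))` should yield two keys for which the helper returns `True`, whether they come back as `numpy.ndarray` or as `DeviceArray`.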