wangqiwei313 closed this issue 4 months ago.
Hi @wangqiwei313,
Thank you for reporting this issue!
To address the problem, please try updating your `jax`, `jaxlib`, `jaxopt`, and potentially other related packages. I have tested and found that using `jax==0.4.25` and `jax==0.4.27` with `jaxlib==0.4.23` and `jaxopt==0.8.3` works well.
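Before and after upgrading, it can help to confirm which versions are actually installed in the active environment. The sketch below uses only the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical, and the list assumes the PyPI distribution names `jax`, `jaxlib`, `jaxopt`, and `mellon`.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent.

    Hypothetical helper for checking the jax/jaxlib/jaxopt combination.
    """
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in the current environment.
            found[name] = None
    return found


if __name__ == "__main__":
    # Assumed PyPI distribution names for the packages discussed in this thread.
    for name, ver in installed_versions(["jax", "jaxlib", "jaxopt", "mellon"]).items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Comparing this output with the combination above can show whether an older `jaxopt` (or another outdated package) is still being picked up.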
It is challenging to pinpoint where `jax.tree_util.tree_multimap` is being called and which package might be causing this error, since `mellon` does not use this function directly. Could you please provide the full traceback of the error? This will help us identify the problematic package combination and might allow us to offer a more precise solution.
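One way to capture the full traceback is to wrap the failing import and keep everything Python reports. This is a minimal sketch using the standard library's `importlib` and `traceback`; the helper name `import_traceback` is hypothetical.

```python
import importlib
import traceback


def import_traceback(module_name):
    """Try to import a module; return the formatted traceback on failure, else None.

    Hypothetical helper: the full traceback shows which package actually
    references the missing attribute (e.g. tree_multimap).
    """
    try:
        importlib.import_module(module_name)
    except Exception:
        # format_exc() returns the complete traceback of the exception being handled.
        return traceback.format_exc()
    return None


if __name__ == "__main__":
    tb = import_traceback("mellon")
    print(tb if tb is not None else "mellon imported without errors")
```

The frames just before the final `AttributeError` line usually point to the package that still calls the removed function.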
Thank you!
Hi katosh, following your suggestion I upgraded my `jaxopt` version to 0.40 (the older version was 0.3), and that fixed it. Thank you so much!
Dear authors,

When I run `import mellon` (mellon version 1.4.3), the following error occurs:

`module 'jax.tree_util' has no attribute 'tree_multimap'`

My `jax` and `jaxlib` versions are 0.4.20. A Stack Overflow answer mentioned that `jax.tree_multimap` was deprecated in JAX 0.3.5 and removed in JAX 0.3.16. So I ran `pip install jaxlib==0.3.14`, but this version is no longer available on PyPI:

`ERROR: Could not find a version that satisfies the requirement jaxlib==0.3.14 (from versions: 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28, 0.4.29, 0.4.30)`
`ERROR: No matching distribution found for jaxlib==0.3.14`

How can I solve this issue? Thanks!
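For context on why older packages break: in newer JAX releases, `jax.tree_util.tree_map` accepts additional trees (`tree_map(f, tree, *rest)`), which covers what `tree_multimap` used to do. The sketch below shows the replacement call on a small, made-up parameter/gradient pytree; the data and the 0.1 step size are illustrative only.

```python
import jax.numpy as jnp
from jax import tree_util

# Illustrative pytrees with matching structure (made-up values).
params = {"w": jnp.ones(3), "b": jnp.zeros(3)}
grads = {"w": jnp.full(3, 0.5), "b": jnp.ones(3)}

# Removed in JAX 0.3.16:
#   tree_util.tree_multimap(lambda p, g: p - 0.1 * g, params, grads)
# Equivalent call with tree_map, which maps over several trees leaf by leaf:
updated = tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

print(updated["w"])  # each leaf is 1.0 - 0.1 * 0.5
print(updated["b"])  # each leaf is 0.0 - 0.1 * 1.0
```

Because the old call site lives inside a dependency rather than in user code, upgrading that dependency (as was done here with `jaxopt`) is generally preferable to patching it locally.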