google-deepmind / chex

https://chex.readthedocs.io
Apache License 2.0

ImportError: cannot import name 'DeviceLocalLayout' from 'jax._src.layout' (/opt/conda/lib/python3.10/site-packages/jax/_src/layout.py) #355

Open · maming109 opened this issue 3 months ago

maming109 commented 3 months ago

```
/tmp/ipykernel_34/2874194604.py:15: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display
  from IPython.core.display import display, HTML
```

```
ImportError                               Traceback (most recent call last)
Cell In[27], line 20
     18 # Import model definition from big_vision
     19 from big_vision.models.proj.paligemma import paligemma
---> 20 from big_vision.trainers.proj.paligemma import predict_fns
     22 # Import big vision utilities
     23 import big_vision.datasets.jsonl

File /kaggle/working/big_vision_repo/big_vision/trainers/proj/paligemma/predict_fns.py:20
     17 import functools
     19 from big_vision.pp import registry
---> 20 import big_vision.utils as u
     21 import einops
     22 import jax

File /kaggle/working/big_vision_repo/big_vision/utils.py:38
     36 import flax.jax_utils as flax_utils
     37 import jax
---> 38 from jax.experimental.array_serialization import serialization as array_serial
     39 import jax.numpy as jnp
     40 import ml_collections as mlc

File /opt/conda/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py:36
     34 from jax._src import sharding
     35 from jax._src import sharding_impls
---> 36 from jax._src.layout import Layout, DeviceLocalLayout as DLL
     37 from jax._src import typing
     38 from jax._src import util

ImportError: cannot import name 'DeviceLocalLayout' from 'jax._src.layout' (/opt/conda/lib/python3.10/site-packages/jax/_src/layout.py)
```
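Reading the traceback from the bottom, the failing statement sits inside jax itself. `jax/experimental/array_serialization/serialization.py` asks `jax/_src/layout.py` for `DeviceLocalLayout`, and both files live in the same `/opt/conda/lib/python3.10/site-packages/jax` directory. chex does not appear anywhere in the chain. One plausible explanation, which nobody in this thread has confirmed, is version skew within the jax package. For example, jax might have been upgraded with pip inside a notebook kernel that had already imported an older jax. The sketch below checks for that situation. The helper names and the "stale" heuristic are hypothetical. The private module `jax._src.layout` is queried only because the traceback names it.

```python
import importlib
import importlib.metadata
import sys


def installed_version(dist_name):
    """Version recorded in the distribution metadata on disk, or None if not installed."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def loaded_version(module_name):
    """__version__ of a module already imported into this process, or None."""
    return getattr(sys.modules.get(module_name), "__version__", None)


def has_symbol(module_name, symbol):
    """True/False if the module imports and does/doesn't define `symbol`; None if it can't be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, symbol)


def is_stale(on_disk, in_memory):
    """Hypothetical heuristic: differing versions suggest the kernel should be restarted."""
    return on_disk is not None and in_memory is not None and on_disk != in_memory


if __name__ == "__main__":
    # Intended to run inside the affected notebook kernel, after the failing cell,
    # so that the jax already held in sys.modules is the one being inspected.
    disk, memory = installed_version("jax"), loaded_version("jax")
    print("jax on disk:  ", disk)
    print("jax in memory:", memory)
    print("possibly stale:", is_stale(disk, memory))
    print("DeviceLocalLayout in jax._src.layout:",
          has_symbol("jax._src.layout", "DeviceLocalLayout"))
```

For development or nightly builds, `jax.__version__` and the metadata version may legitimately differ, so treat a mismatch as a hint only. If the check does report one, restarting the kernel after installing packages is a cheap first thing to try.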

jntdst commented 2 months ago

I have the same error. What can we do?

stompchicken commented 2 months ago

Hi there. I don't really understand this issue; it doesn't look like you're importing chex here. Could you provide a step-by-step reproduction?
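A reproduction like the one requested here could leave big_vision out entirely and trigger only the import that fails at `big_vision/utils.py` line 38. The following is a minimal sketch, assuming Python 3.8+ (for `importlib.metadata`). The package list and the `collect_report` helper are illustrative choices and are not part of chex or any other project's tooling.

```python
import importlib
import importlib.metadata
import platform

# Distributions that appear in the traceback, plus chex to show whether it is even involved.
PACKAGES = ("jax", "jaxlib", "flax", "chex")
# The module imported at big_vision/utils.py line 38 in the reported traceback.
FAILING_IMPORT = "jax.experimental.array_serialization.serialization"


def collect_report():
    """Return installed versions and the outcome of the failing import as a dict."""
    report = {"python": platform.python_version()}
    for name in PACKAGES:
        try:
            report[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            report[name] = None  # not installed in this environment
    try:
        importlib.import_module(FAILING_IMPORT)
        report["import_error"] = None
    except ImportError as exc:  # also covers ModuleNotFoundError
        report["import_error"] = str(exc)
    return report


if __name__ == "__main__":
    for key, value in collect_report().items():
        print(f"{key}: {value}")
```

Suppose `import_error` reproduces the `DeviceLocalLayout` message in a fresh process with no big_vision or chex code involved. That would point at the jax installation rather than at chex.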