NbAiLab / notram

Norwegian Transformer Model
Apache License 2.0

cannot import name 'mesh' when trying to train GPT Neo #3

Closed: AliNajafi1998 closed this issue 2 years ago

AliNajafi1998 commented 2 years ago

Hi, I am trying to train GPT Neo on a TPU, following the guide you provide in configure_flax.md.

However, I get the following error:

File "run_clm_mp.py", line 41, in <module>
    from jax.experimental.maps import mesh
ImportError: cannot import name 'mesh' from 'jax.experimental.maps' (/home/ali.najafi/flax/lib/python3.8/site-packages/jax/experimental/maps.py)
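The error says the installed JAX (0.3.24 here) does not export the lowercase `mesh` that run_clm_mp.py imports. Besides downgrading, one possible workaround is to resolve a mesh context manager from whichever location the installed JAX provides. This is a minimal sketch, not the repository's fix: the candidate locations (`jax.sharding.Mesh`, `jax.experimental.maps.Mesh`, and the older lowercase `jax.experimental.maps.mesh`) are assumptions about how JAX moved this API between releases, so check the JAX changelog for your version. `first_available` and `MESH_CANDIDATES` are hypothetical helper names.

```python
import importlib
from typing import Any, Optional, Sequence, Tuple


def first_available(candidates: Sequence[Tuple[str, str]]) -> Optional[Any]:
    """Return the first attribute importable from the (module, name) pairs, else None."""
    for module_name, attr_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # Module is absent in this environment (e.g. jax not installed,
            # or a submodule that a given JAX release does not ship).
            continue
        # getattr with a default also covers names a module no longer exports.
        value = getattr(module, attr_name, None)
        if value is not None:
            return value
    return None


# Assumed locations of a device-mesh context manager, preferring the newer one.
MESH_CANDIDATES = [
    ("jax.sharding", "Mesh"),
    ("jax.experimental.maps", "Mesh"),
    ("jax.experimental.maps", "mesh"),  # the name run_clm_mp.py imports
]

# None if no candidate exists in the installed JAX (or JAX is missing).
mesh = first_available(MESH_CANDIDATES)
```

If the resolved object is the `Mesh` class, it would be entered the same way in a `with` statement, e.g. `with mesh(devices, ("dp", "mp")):`, assuming the script passes a NumPy array of devices plus axis names; the axis names shown are illustrative.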

Library versions:

flax==0.6.1
importlib-resources==5.10.0
jax==0.3.24
jaxlib==0.3.24
keras==2.10.0
Keras-Preprocessing==1.1.2
kiwisolver==1.4.4
libclang==14.0.6
libtpu-nightly==0.1.dev20221103
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.6.2
ml-collections==0.1.1
msgpack==1.0.4
multidict==6.0.2
multiprocess==0.70.14
numpy==1.23.4
oauthlib==3.2.2
opt-einsum==3.3.0
optax==0.1.3
packaging==21.3
pandas==1.5.1
Pillow==9.3.0
promise==2.3
protobuf==3.19.6
pyarrow==10.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
Pygments==2.13.0
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2022.6
PyYAML==6.0
regex==2022.10.31
requests==2.28.1
requests-oauthlib==1.3.1
responses==0.18.0
rich==12.6.0
rsa==4.9
scipy==1.9.3
six==1.16.0
tensorboard==2.10.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.10.0
tensorflow-cpu==2.10.0
tensorflow-datasets==4.7.0
tensorflow-estimator==2.10.0
tensorflow-io-gcs-filesystem==0.27.0
tensorflow-metadata==1.10.0
termcolor==2.1.0
tokenizers==0.13.1
toml==0.10.2
toolz==0.12.0
tqdm==4.64.1
transformers==4.25.0.dev0
typing-extensions==4.4.0
urllib3==1.26.12
Werkzeug==2.2.2
wrapt==1.14.1
xxhash==3.1.0
yarl==1.8.1
zipp==3.10.0
AliNajafi1998 commented 2 years ago

So I installed flax==0.4.0 and the issue got resolved! :D
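After reinstalling, it can help to confirm which flax, jax and jaxlib versions are actually active, since pip may or may not change jax alongside flax. A small sketch using the standard library's `importlib.metadata` (Python 3.8+); `installed_versions` is a hypothetical helper name.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(["flax", "jax", "jaxlib"]).items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```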