kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

AttributeError: module 'jax.random' has no attribute 'KeyArray' while fine tuning. #221

Closed samyakai closed 2 years ago

samyakai commented 2 years ago

I am following your repo to fine-tune GPT-J on a TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I created, I get the error "AttributeError: module 'jax.random' has no attribute 'KeyArray'". Here are some of the specs:

OS: Ubuntu 20.04
jax version: 0.2.12
TPU: v3-8
Zone: us-central1-b

The error is raised at line 7 of device_train.py, where optax is imported ("import optax").

This is the error stack:

WARNING: Logging before InitGoogle() is written to STDERR
I0420 11:47:44.856002 10240 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
Traceback (most recent call last):
  File "device_train.py", line 7, in <module>
    import optax
  File "/usr/local/lib/python3.8/dist-packages/optax/__init__.py", line 17, in <module>
    from optax._src.alias import adabelief
  File "/usr/local/lib/python3.8/dist-packages/optax/_src/alias.py", line 21, in <module>
    from optax._src import base
  File "/usr/local/lib/python3.8/dist-packages/optax/_src/base.py", line 18, in <module>
    import chex
  File "/usr/local/lib/python3.8/dist-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts_internal.py", line 32, in <module>
    from chex._src import pytypes
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/pytypes.py", line 36, in <module>
    PRNGKey = jax.random.KeyArray
AttributeError: module 'jax.random' has no attribute 'KeyArray'

Any help is appreciated!

mosmos6 commented 2 years ago

I've just encountered exactly the same error and I was about to open an issue about this.

KD1903 commented 2 years ago

I am facing the same error. Kindly solve this!

jagruti-samyak commented 2 years ago

I am facing the same error on "import optax".

abdelatifsd commented 2 years ago

Exact same issue here!

hxiaoyang commented 2 years ago

same issue!

vfbd commented 2 years ago

Chex 0.1.3 doesn't support JAX 0.2.12. You need to downgrade to Chex 0.1.2:

pip3 install chex==0.1.2

mosmos6 commented 2 years ago

@vfbd It worked for me for inference, but apparently not for fine-tuning.

samyakai commented 2 years ago

@mosmos6 @vfbd Now it is giving me this error: "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". @mosmos6 How did it work for you? Are you training on TPU v3-8?

mosmos6 commented 2 years ago

@samyakai I just noticed that you hit this error during fine-tuning. I hit the same error during inference. Sorry for the confusion; I have edited my previous comment. The issue hasn't been resolved for fine-tuning.

samyakai commented 2 years ago

As suggested by @vfbd, if I downgrade chex to 0.1.2 I get "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". To overcome this, https://github.com/google/brax/issues/187 suggests upgrading to the latest version. If I do that, I again get the error "AttributeError: module 'jax.random' has no attribute 'KeyArray'".

jagruti-samyak commented 2 years ago

I am following your repo to fine-tune GPT-J on a TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I created, I get the error "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". Here are some of the specs:

OS: Ubuntu 20.04
jax version: 0.2.12
chex version: 0.1.2
TPU: v3-8
Zone: us-central1-b

The error is raised at line 7 of device_train.py, where optax is imported ("import optax").

This is the error stack:

WARNING: Logging before InitGoogle() is written to STDERR
I0421 10:06:19.047791 8679 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
Traceback (most recent call last):
  File "device_train.py", line 7, in <module>
    import optax
  File "/usr/local/lib/python3.8/dist-packages/optax/__init__.py", line 17, in <module>
    from optax import experimental
  File "/usr/local/lib/python3.8/dist-packages/optax/experimental/__init__.py", line 20, in <module>
    from optax._src.experimental.complex_valued import split_real_and_imaginary
  File "/usr/local/lib/python3.8/dist-packages/optax/_src/experimental/complex_valued.py", line 32, in <module>
    import chex
  File "/usr/local/lib/python3.8/dist-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts_internal.py", line 32, in <module>
    from chex._src import pytypes
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/pytypes.py", line 40, in <module>
    CpuDevice = jax.lib.xla_extension.CpuDevice
AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'

Please help us resolve this as soon as possible. Thank you.

vfbd commented 2 years ago

Which version of jaxlib (not jax) do you have? Maybe try again with jaxlib==0.1.68

samyakai commented 2 years ago

@vfbd These are the library versions that resolve the error:

jax==0.2.16
jaxlib==0.1.68
optax==0.1.2
chex==0.1.2
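
For anyone comparing their own environment against these pins, here is a minimal sketch using only the Python standard library. The `PINS` dict just restates the versions from this comment; `installed_version`, `check_pins`, and the `lookup` parameter are hypothetical names for illustration and are not part of this repo.

```python
from importlib import metadata

# Versions reported to work in this thread (not an official compatibility matrix).
PINS = {
    "jax": "0.2.16",
    "jaxlib": "0.1.68",
    "optax": "0.1.2",
    "chex": "0.1.2",
}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_pins(pins=PINS, lookup=installed_version):
    """Return {package: (expected, found)} for each package that differs from its pin."""
    mismatches = {}
    for package, expected in pins.items():
        found = lookup(package)
        if found != expected:
            mismatches[package] = (expected, found)
    return mismatches


if __name__ == "__main__":
    # Print one line per mismatch; no output means every pin matches.
    for package, (expected, found) in check_pins().items():
        print(f"{package}: expected {expected}, found {found or 'not installed'}")
```

An empty result means all four packages match the pins. If something differs, installing the versions listed above (for example with `pip3 install jax==0.2.16 jaxlib==0.1.68 optax==0.1.2 chex==0.1.2`) would be the obvious thing to try.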

kufton commented 1 year ago

I just paid for colab pro to play around with this and found the same issues described here. I added !pip install for the lib versions mentioned and then I got this error:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
in <module>
      4 from jax.experimental import maps
      5 import numpy as np
----> 6 import optax
      7 import transformers
      8

6 frames
/usr/local/lib/python3.8/dist-packages/jax/_src/api.py in <module>
     42 from . import dtypes
     43 from ..core import eval_jaxpr
---> 44 from ..api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
     45                         flatten_fun_nokwargs2, argnums_partial,
     46                         argnums_partial_except, flatten_axes, donation_vector,

ImportError: cannot import name '_ensure_str_tuple' from 'jax.api_util' (/usr/local/lib/python3.8/dist-packages/jax/api_util.py)

---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

To view examples of installing some common dependencies, click the
"Open Examples" button below.
---------------------------------------------------------------------------

I appreciate that this is, like, alpha, so I'll go play with GPT-3 in the meantime. Thank you for your work.

mystiverv commented 4 months ago

    from jax_md import rigid_body
  File "C:\Users....\env\Lib\site-packages\jax_md\rigid_body.py", line 76, in <module>
    KeyArray = random.KeyArray

module 'jax.random' has no attribute 'KeyArray'

I get this error when trying to run import jax_md
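
This looks like the same kind of mismatch as above: the installed JAX does not expose `jax.random.KeyArray`, which `jax_md` reads at import time. One possible workaround, sketched here as an assumption and not a confirmed fix, is to define the missing alias before importing the dependent library. `ensure_attr` is a hypothetical helper, and the demo uses a stand-in module so that it runs without JAX installed. With a real install, the analogous call might be `ensure_attr(jax.random, "KeyArray", jax.Array)` placed before `import jax_md`, assuming your JAX version provides `jax.Array`.

```python
import types


def ensure_attr(module, name, fallback):
    """Set module.<name> to fallback only if it is missing; return the value in use."""
    if not hasattr(module, name):
        setattr(module, name, fallback)
    return getattr(module, name)


# Stand-in for a module that lacks an attribute a dependent library expects.
fake_random = types.ModuleType("fake_random")


class FakeArray:
    """Placeholder type standing in for an array class."""


# First call installs the alias because the attribute is missing.
alias = ensure_attr(fake_random, "KeyArray", FakeArray)

# A second call leaves the existing attribute untouched.
unchanged = ensure_attr(fake_random, "KeyArray", int)
```

Pinning the library versions to a set that is known to work together, as in the earlier comments, is the more conservative route; a shim like this only papers over the missing name.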