samyakai closed this issue 2 years ago.
I've just encountered exactly the same error and I was about to open an issue about this.
WARNING: Logging before InitGoogle() is written to STDERR
I0420 11:47:44.856002 10240 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
Traceback (most recent call last):
File "device_train.py", line 7, in <module>
import optax
File "/usr/local/lib/python3.8/dist-packages/optax/__init__.py", line 17, in <module>
from optax._src.alias import adabelief
File "/usr/local/lib/python3.8/dist-packages/optax/_src/alias.py", line 21, in <module>
from optax._src import base
File "/usr/local/lib/python3.8/dist-packages/optax/_src/base.py", line 18, in <module>
import chex
File "/usr/local/lib/python3.8/dist-packages/chex/__init__.py", line 17, in <module>
from chex._src.asserts import assert_axis_dimension
File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts.py", line 26, in <module>
from chex._src import asserts_internal as _ai
File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts_internal.py", line 32, in <module>
from chex._src import pytypes
File "/usr/local/lib/python3.8/dist-packages/chex/_src/pytypes.py", line 36, in <module>
PRNGKey = jax.random.KeyArray
AttributeError: module 'jax.random' has no attribute 'KeyArray'
I am facing the same error. Kindly solve this!
I'm facing the same error on "import optax".
Exact same issue here!
Same issue!
Chex 0.1.3 doesn't support JAX 0.2.12. You need to downgrade to Chex 0.1.2:
pip3 install chex==0.1.2
@vfbd That worked for me for inference, but apparently not for fine-tuning.
@mosmos6 @vfbd Now it is giving me this error: "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". @mosmos6 How did it work for you? Are you training on TPU v3-8?
@samyakai I've just noticed that you hit this error while fine-tuning; I hit the same error during inference. Sorry for the confusion, I've edited my previous comment. The issue is still unresolved for fine-tuning.
As suggested by @vfbd, I downgraded chex to 0.1.2, but then I get "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". To get past this, https://github.com/google/brax/issues/187 suggests upgrading to the latest version, but if I do that, I'm back to "AttributeError: module 'jax.random' has no attribute 'KeyArray'".
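The two AttributeErrors in this thread each name a specific attribute: chex 0.1.3 reads `jax.random.KeyArray` at import time (`chex/_src/pytypes.py`, line 36, in the traceback above), and the `CpuDevice` error means something in the installed stack expects `jaxlib.xla_extension.CpuDevice`. A quick way to see which of the two your environment actually provides is to probe them directly. This is a hypothetical diagnostic sketch, not part of the repo, chex or JAX; `probe` and `CHECKS` are made-up names, and it only relies on `importlib.import_module` and `hasattr` from the standard library.

```python
import importlib


def probe(module_name, attr):
    """Return True/False depending on whether `module_name` exposes `attr`,
    or None if the module itself cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except Exception:  # ImportError, or a broken install raising AttributeError on import
        return None
    return hasattr(module, attr)


# Attributes named in the two AttributeErrors reported in this thread.
CHECKS = [
    ("jax.random", "KeyArray"),             # read by chex 0.1.3 (chex/_src/pytypes.py, line 36)
    ("jaxlib.xla_extension", "CpuDevice"),  # reported missing after downgrading chex to 0.1.2
]

if __name__ == "__main__":
    labels = {True: "present", False: "MISSING", None: "module not importable"}
    for module_name, attr in CHECKS:
        print(f"{module_name}.{attr}: {labels[probe(module_name, attr)]}")
```

Run it with the same interpreter you use for `device_train.py`: a missing `KeyArray` points at the chex 0.1.3 / JAX mismatch, while a missing `CpuDevice` points at the jax / jaxlib pairing, which helps decide which package to pin.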
I am following your repo to fine-tune GPT-J on TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I created, I get "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". Here are some of my specs:
OS: Ubuntu 20.04
jax version: 0.2.12
chex version: 0.1.2
TPU: v3-8
Zone: us-central1-b
The error is raised at line 7 of device_train.py, where optax is imported: "import optax".
This is the error stack:
WARNING: Logging before InitGoogle() is written to STDERR
I0421 10:06:19.047791 8679 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
Traceback (most recent call last):
File "device_train.py", line 7, in <module>
Please help us resolve this ASAP. Thank you!
Which version of jaxlib (not jax) do you have? Maybe try again with jaxlib==0.1.68.
@vfbd These library versions solve the error: jax==0.2.16 jaxlib==0.1.68 optax==0.1.2 chex==0.1.2
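To confirm that an environment actually matches the combination reported above before launching `device_train.py`, you can compare the installed distribution versions against those pins. This is a minimal sketch, assuming Python 3.8+ (for `importlib.metadata`); `KNOWN_GOOD` and `compare_to_pins` are hypothetical names, and the pins are simply the ones from the comment above, not an officially supported version matrix. Because jaxlib is a separate distribution with its own version number, this also answers the "jaxlib, not jax" question.

```python
from importlib import metadata

# Pins reported in this thread as working (not an official compatibility matrix).
KNOWN_GOOD = {"jax": "0.2.16", "jaxlib": "0.1.68", "optax": "0.1.2", "chex": "0.1.2"}


def compare_to_pins(pins=KNOWN_GOOD):
    """Map each package name to (installed_version_or_None, pinned_version, matches)."""
    report = {}
    for name, pinned in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package not installed in this interpreter
        report[name] = (installed, pinned, installed == pinned)
    return report


if __name__ == "__main__":
    for name, (installed, pinned, ok) in compare_to_pins().items():
        print(f"{name:7} installed={installed} pinned={pinned} {'OK' if ok else 'MISMATCH'}")
```

The pins themselves can be installed with `pip install jax==0.2.16 jaxlib==0.1.68 optax==0.1.2 chex==0.1.2`. As the next comment shows, installing them was not enough on Colab, so treat a full match here as a sanity check rather than a guarantee.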
I just paid for Colab Pro to play around with this and ran into the same issues described here. I added !pip install lines for the library versions mentioned above, and then got this error:
`ImportError Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/jax/_src/api.py in <module>
ImportError: cannot import name '_ensure_str_tuple' from 'jax.api_util' (/usr/local/lib/python3.8/dist-packages/jax/api_util.py)`
I appreciate that this is, like, alpha, so I'll go play with GPT-3 for now, but thank you for your work.
from jax_md import rigid_body
File "C:\Users....\env\Lib\site-packages\jax_md\rigid_body.py", line 76, in
I get this error when trying to run import jax_md
I am following your repo to fine-tune GPT-J on TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I created, I get "AttributeError: module 'jax.random' has no attribute 'KeyArray'". Here are some of my specs:
OS: Ubuntu 20.04
jax version: 0.2.12
TPU: v3-8
Zone: us-central1-b
The error is raised at line 7 of device_train.py, where optax is imported: "import optax".
This is the error stack:
WARNING: Logging before InitGoogle() is written to STDERR
I0420 11:47:44.856002 10240 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
Traceback (most recent call last):
File "device_train.py", line 7, in <module>
import optax
File "/usr/local/lib/python3.8/dist-packages/optax/__init__.py", line 17, in <module>
from optax._src.alias import adabelief
File "/usr/local/lib/python3.8/dist-packages/optax/_src/alias.py", line 21, in <module>
from optax._src import base
File "/usr/local/lib/python3.8/dist-packages/optax/_src/base.py", line 18, in <module>
import chex
File "/usr/local/lib/python3.8/dist-packages/chex/__init__.py", line 17, in <module>
from chex._src.asserts import assert_axis_dimension
File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts.py", line 26, in <module>
from chex._src import asserts_internal as _ai
File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts_internal.py", line 32, in <module>
from chex._src import pytypes
File "/usr/local/lib/python3.8/dist-packages/chex/_src/pytypes.py", line 36, in <module>
PRNGKey = jax.random.KeyArray
AttributeError: module 'jax.random' has no attribute 'KeyArray'
Any help is appreciated!