google-research / long-range-arena

Long Range Arena for Benchmarking Efficient Transformers
Apache License 2.0

JAX reports "No GPU/TPU found, falling back to CPU" #28

Closed: La-SilverLand closed this issue 3 years ago

La-SilverLand commented 3 years ago

The requirements.txt requires jax>=0.2.4. When I ran into the problem in the title, I checked the JAX GitHub homepage, which says that to get both GPU and CPU support, one needs to install:

pip install --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

"The jaxlib version must correspond to the version of the existing CUDA installation you want to use."
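The "+cuda111" suffix in that jaxlib version is a local tag naming the CUDA release the wheel was built for (11.1 here). As a minimal sketch, assuming tags of the form cudaXY[Z] where the last digit is the CUDA minor version, the hypothetical helpers below (not part of JAX) split such a version and compare it with a local CUDA "major.minor" string:

```python
import re
from typing import Optional, Tuple


def parse_jaxlib_version(version: str) -> Tuple[str, Optional[Tuple[int, int]]]:
    """Split '0.1.67+cuda111' into ('0.1.67', (11, 1)).

    Hypothetical helper. Assumes the local tag looks like 'cuda' + digits, where
    the last digit is the CUDA minor version and the rest is the major version.
    Versions without such a tag (e.g. CPU-only '0.1.65') return (version, None).
    """
    base, _, local = version.partition("+")
    match = re.fullmatch(r"cuda(\d+)(\d)", local)
    if match is None:
        return base, None
    return base, (int(match.group(1)), int(match.group(2)))


def matches_local_cuda(jaxlib_version: str, local_cuda: str) -> bool:
    """Return True if the wheel's CUDA tag equals the local CUDA 'major.minor'."""
    _, cuda = parse_jaxlib_version(jaxlib_version)
    if cuda is None:
        return False
    # Accepts '10.2' or '10.2.89'; only major and minor are compared.
    major, minor = (int(part) for part in local_cuda.split(".")[:2])
    return cuda == (major, minor)


if __name__ == "__main__":
    print(matches_local_cuda("0.1.65+cuda102", "10.2"))  # True
    print(matches_local_cuda("0.1.67+cuda111", "10.2"))  # False
```

Note that the jaxlib base version (0.1.x) is unrelated to the jax version, so a jax>=0.2.4 requirement says nothing about which jaxlib number to look for.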

But that jaxlib release index does not contain any version >= 0.2.4; the newest version there is 0.1.67.

How did you install the environment?

La-SilverLand commented 3 years ago

I got it fixed: jax>=0.2.4 does not imply jaxlib>=0.2.4, because jax and jaxlib are versioned separately. In my case I have CUDA 10.2, and jax==0.2.13 with jaxlib==0.1.65+cuda102 works.
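To confirm that a new jax/jaxlib pair actually picks up the GPU, a quick check is to list the platforms of the devices JAX returns. This is a minimal sketch using the public `jax.devices()` call; the helper name `visible_platforms` is hypothetical, and it returns an empty list when JAX is not installed so it can run anywhere:

```python
from typing import List


def visible_platforms() -> List[str]:
    """Return the sorted platform names of the devices on JAX's default backend.

    Hypothetical helper. Returns [] if JAX is not installed.
    """
    try:
        import jax
    except ImportError:
        return []
    # jax.devices() lists the devices of the default backend; each device
    # exposes a .platform string such as 'cpu'.
    return sorted({device.platform for device in jax.devices()})


if __name__ == "__main__":
    # With a CPU-only jaxlib, JAX warns "No GPU/TPU found, falling back to CPU"
    # and this prints ['cpu'].
    print(visible_platforms())
```

If the output is still only `cpu` after installing a CUDA wheel, the jaxlib CUDA tag most likely does not match the local CUDA installation.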

DaShenZi721 commented 1 year ago

@La-SilverLand Hello! Sorry to bother you. Have you ever encountered the following problem? I think it may be related to the version of flax.

Traceback (most recent call last):
  File "lra_benchmarks/listops/train.py", line 28, in <module>
    from flax.deprecated import nn
ModuleNotFoundError: No module named 'flax.deprecated'
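Since train.py does `from flax.deprecated import nn`, a quick way to check whether the installed flax provides that module, without running the whole script, is the sketch below. It uses the standard-library `importlib.util.find_spec`; the helper name `has_module` is hypothetical. If it prints False, the installed flax does not ship `flax.deprecated`, which matches the traceback above:

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if the (possibly dotted) module `name` can be located.

    Hypothetical helper. For dotted names, find_spec imports the parent package
    first and raises ModuleNotFoundError if that parent is missing; that case
    is reported as False.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False


if __name__ == "__main__":
    # lra_benchmarks/listops/train.py needs this to be True.
    print(has_module("flax.deprecated"))
```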

DaShenZi721 commented 1 year ago

@La-SilverLand This is my setting: