google / evojax

Apache License 2.0

How to use GPU for computing #73

Closed jinyilun718 closed 11 months ago

jinyilun718 commented 11 months ago

Dear developer, thanks for your awesome work. I have a question. When I run the seq2seq example, I get this warning:

Seq2seq: 2023-09-26 20:11:32,443 [INFO] ==============================
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695730292.463880 3694542 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
2023-09-26 20:11:32.491777: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:276] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2023-09-26 20:11:32.491820: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:168] retrieving CUDA diagnostic information for host: user-MD72-HB3-00
2023-09-26 20:11:32.491829: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:175] hostname: user-MD72-HB3-00
2023-09-26 20:11:32.491882: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:199] libcuda reported version is: 535.86.5
2023-09-26 20:11:32.491912: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:203] kernel reported version is: 535.86.5
2023-09-26 20:11:32.491920: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:309] kernel version seems to match DSO: 535.86.5
jax._src.xla_bridge: 2023-09-26 20:11:32,492 [INFO] Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices.
jax._src.xla_bridge: 2023-09-26 20:11:32,492 [INFO] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
jax._src.xla_bridge: 2023-09-26 20:11:32,494 [INFO] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
jax._src.xla_bridge: 2023-09-26 20:11:32,494 [WARNING] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

I have already installed JAX with GPU support. My setup is a Linux machine with a 4090 GPU and CUDA 12.2. How do I fix this problem?

jinyilun718 commented 11 months ago

I'm sure that the GPU version of JAX is installed, since when I run this code,

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
import jax
print(jax.devices())

the following output is observed:

gpu
[gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3)]
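The log above reports `CUDA_ERROR_NO_DEVICE` and "No visible GPU devices" for the seq2seq run, while the short check sees four GPUs. One possible explanation, which is an assumption and not something confirmed in this thread, is that the two runs see different values of the `CUDA_VISIBLE_DEVICES` environment variable. A minimal diagnostic sketch that can be run in both contexts is shown below. The helper names `describe_cuda_visibility` and `report_jax_backend` are hypothetical and not part of EvoJAX or JAX; `jax.default_backend()` is the only JAX call used.

```python
import os


def describe_cuda_visibility(env=None):
    """Summarize how CUDA_VISIBLE_DEVICES restricts GPU visibility.

    Hypothetical helper for illustration; `env` defaults to os.environ.
    """
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return "unset: all GPUs are visible to the process"
    if value.strip() in ("", "-1"):
        # An empty value or an invalid index hides every GPU from CUDA.
        return "hidden: no GPU is visible to the process"
    return "restricted to device(s): " + value.strip()


def report_jax_backend():
    """Return JAX's default backend name, or None if JAX is not importable."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    print("CUDA_VISIBLE_DEVICES ->", describe_cuda_visibility())
    print("JAX default backend  ->", report_jax_backend())
```

If the two contexts print different results, the difference in environment is the first thing to look at. If they print the same result, the cause lies elsewhere and this sketch does not help narrow it down.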