google-deepmind / ferminet

An implementation of the Fermionic Neural Network for ab-initio electronic structure calculations
Apache License 2.0

Jax error running on A100 GPU (everything is okay on CPU) #45

Closed: ltz0120 closed this issue 2 years ago

ltz0120 commented 2 years ago

Hi,

I get an error in train.py at line 229: new_params, state, stats = optimizer.step(......)

The error message is shown below:

2022-04-14 12:46:59.552761: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2141] Execution of replica 1 failed: INTERNAL: CustomCall failed: jaxlib/cusolver_kernels.cc:44: operation cusolverDnCreate(&handle) failed: cuSolver internal error
2022-04-14 12:47:09.554416: F external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2288] Replicated computation launch failed, but not all replicas terminated. Aborting process to work around deadlock. Failure message (there may have been multiple failures, see the error log for all failures):

CustomCall failed: jaxlib/cusolver_kernels.cc:44: operation cusolverDnCreate(&handle) failed: cuSolver internal error
Fatal Python error: Aborted

I don't get any error when running on CPU, but on GPU I always get this error. Could you help me solve this problem? Thank you.

ltz0120 commented 2 years ago

I am using JAX 0.3.5 with jaxlib 0.3.5 (the cuda11.cudnn82 build).
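One way to check whether the installed jaxlib GPU build can use cuSolver at all, independently of FermiNet, is to run a couple of dense linear-algebra routines on their own. This is a minimal sketch under an assumption: on GPU builds of jaxlib from this period, `jnp.linalg.eigh` and `jnp.linalg.slogdet` are lowered to cuSolver kernels (the exact lowering varies between JAX versions). If this snippet fails with the same `cusolverDnCreate` error, the problem lies in the JAX/CUDA installation rather than in FermiNet's code. The matrix size and seed are arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Report what JAX sees; on a working GPU install the backend should not be "cpu".
print("jax version:", jax.__version__)
print("backend:", jax.default_backend())
print("devices:", jax.devices())

# Build a small symmetric positive-definite matrix (arbitrary size and seed).
rng = np.random.default_rng(0)
a = rng.standard_normal((8, 8)).astype(np.float32)
spd = a @ a.T + 8.0 * np.eye(8, dtype=np.float32)

x = jnp.asarray(spd)

# Dense decompositions; on GPU these are assumed to be dispatched to cuSolver.
w, v = jnp.linalg.eigh(x)
sign, logdet = jnp.linalg.slogdet(x)

# Force execution so any backend error surfaces here rather than later.
w.block_until_ready()
logdet.block_until_ready()
print("eigenvalues:", w)
print("sign, log|det|:", sign, logdet)
```

The same script runs on CPU (via LAPACK), which makes it easy to compare the two backends on one machine.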

jsspencer commented 2 years ago

This looks more like a JAX/cuDNN installation or usage issue than something specific to our code. My only suggestion is to try lowering the fraction of GPU memory that JAX preallocates (e.g. to 0.8); see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html.
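The suggestion above can be applied by setting `XLA_PYTHON_CLIENT_MEM_FRACTION`, the environment variable described on the linked JAX page, before JAX is imported. The reasoning, presumably, is that cuSolver allocates its handle outside JAX's preallocated memory pool, so leaving more GPU memory free may let `cusolverDnCreate` succeed. This is a sketch, not a confirmed fix; the 0.8 value comes from the reply and is not tuned.

```python
import os

# These variables are read when JAX initialises its GPU backend, so they must
# be set before the first `import jax` in the process (or exported in the shell).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"  # value suggested in the reply
# Alternative (uncomment to try): allocate on demand instead of preallocating.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

# Any computation now runs with the reduced preallocation (the setting has no effect on CPU).
print(jax.devices())
print(jnp.ones(3).sum())
```

Equivalently, export the variable in the shell before launching the training script; setting it after JAX has already initialised the GPU has no effect.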