Closed: ltz0120 closed this issue 2 years ago
I am using JAX 0.3.5 with jaxlib 0.3.5 (the cuda11.cudnn82 build).
This looks more like a JAX/cuDNN install or usage issue than something specific to our code. My only suggestion is to try lowering the fraction of GPU memory that JAX preallocates (e.g. to 0.8); see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html.
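In case it helps, here is a minimal sketch of that suggestion. It assumes the fraction is controlled through the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable described on the page linked above, and that the variable is set before JAX is imported. The helper name limit_jax_gpu_memory is made up for illustration, and I have not verified that this resolves the cuSolver error.

```python
import os


def limit_jax_gpu_memory(fraction: float = 0.8) -> None:
    """Cap the share of GPU memory that JAX preallocates.

    Hypothetical helper: it only sets an environment variable, so it
    should be called before the first `import jax` in the process.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(fraction)


# Set the limit first, then import JAX and run training as usual.
limit_jax_gpu_memory(0.8)

# import jax  # uncomment in the real script; must come after the call above
```

The same effect can be had without touching the code by exporting the variable in the shell before launching train.py. The idea is that leaving some memory unclaimed by JAX gives other GPU libraries room to allocate their own handles and workspaces.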
Hi,
I get an error in train.py at line 229: new_params, state, stats = optimizer.step(......)
The error output is shown below:
2022-04-14 12:46:59.552761: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2141] Execution of replica 1 failed: INTERNAL: CustomCall failed: jaxlib/cusolver_kernels.cc:44: operation cusolverDnCreate(&handle) failed: cuSolver internal error
2022-04-14 12:47:09.554416: F external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2288] Replicated computation launch failed, but not all replicas terminated. Aborting process to work around deadlock. Failure message (there may have been multiple failures, see the error log for all failures):
CustomCall failed: jaxlib/cusolver_kernels.cc:44: operation cusolverDnCreate(&handle) failed: cuSolver internal error
Fatal Python error: Aborted
I don't get any error when running on CPU, but on GPU I always get this error. Could you help me solve this problem? Thank you.