Open imoneoi opened 2 years ago
Hi @imoneoi
Looks like this issue is resolved in the latest version of JAX (0.4.23), which raises an XlaRuntimeError instead of hanging.
I reproduced the mentioned code in Google Colab. The program didn't hang; instead it gave the following error message:
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
<ipython-input-1-ff273c4cdc26> in <cell line: 21>()
20
21 if __name__ == "__main__":
---> 22 main()
[... skipping hidden 10 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in __call__(self, *args)
1157 self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
1158 else:
-> 1159 results = self.xla_executable.execute_sharded(input_bufs)
1160 if dispatch.needs_check_special():
1161 out_arrays = results.disassemble_into_single_device_arrays()
XlaRuntimeError: INVALID_ARGUMENT: Attempt to use a buffer that was previously donated in the same call to Execute() (second use: flattened argument 1, replica 0). Toy example for this bug: `f(donate(a), a)`.
Kindly find the gist.
Could you please check and confirm?
Thank you.
Description
When a jitted function is called with a donated input and an un-donated input that share the same underlying buffer, the program hangs forever with 0% GPU and CPU utilization.
I think we should raise a warning or an error in this situation. This is especially hard to find when the conflicting buffers are hidden deep in a pytree.
BTW, this may be the underlying issue of https://github.com/google/jax/issues/10737
Minimal example to reproduce:
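The original snippet was not preserved in this report; below is a hedged reconstruction based on the description and the `f(donate(a), a)` toy pattern mentioned in the error message. The function and array names are illustrative, not the reporter's actual code.

```python
from functools import partial

import jax
import jax.numpy as jnp

# A jitted function whose first argument's buffer is donated to the computation.
@partial(jax.jit, donate_argnums=(0,))
def f(x, y):
    return x + y

a = jnp.ones((4,))
b = jnp.ones((4,))

# Fine: the donated and un-donated inputs are distinct buffers.
print(f(a, b))

# Passing the same array as both the donated and the un-donated argument
# (the f(donate(a), a) pattern) hung forever on GPU with jax 0.3.17;
# jax 0.4.23 raises an XlaRuntimeError instead. Uncomment to reproduce:
# f(b, b)
```

Note that donation is a no-op (with a warning) on some backends, so the hang described here was observed on GPU.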
What jax/jaxlib version are you using?
jax v0.3.17, jaxlib v0.3.15
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:36:39) [GCC 10.4.0] on linux
NVIDIA GPU info