jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Program hangs when a donated input shares memory with an un-donated input #12627

Open imoneoi opened 2 years ago

imoneoi commented 2 years ago

Description

When a jitted function is called with a donated input and an un-donated input that share the same underlying memory, the program hangs indefinitely with 0% GPU and CPU utilization.

I think we should raise a warning or an error in this situation. The problem is especially hard to track down when the conflicting buffers are hidden deep inside a pytree.
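As a stopgap, one could check for this kind of aliasing before calling the jitted function. The helper below is a hypothetical sketch (not part of the JAX API, and `shares_leaves` is an invented name): it compares leaves by Python object identity, which catches the common `batch = x` case from the repro below, though not buffers that alias at a lower level.

```python
import jax
import jax.numpy as jnp

def shares_leaves(donated_tree, other_tree):
    """Return True if any leaf object appears in both pytrees.

    Hypothetical helper: compares leaves by Python object identity
    only, so it will miss lower-level buffer aliasing.
    """
    donated_ids = {id(leaf) for leaf in jax.tree_util.tree_leaves(donated_tree)}
    return any(id(leaf) in donated_ids
               for leaf in jax.tree_util.tree_leaves(other_tree))

x = jnp.zeros(4)
print(shares_leaves({"params": x}, (x, 1.0)))        # same object in both trees
print(shares_leaves(x, jnp.array(x, copy=True)))     # copy has its own buffer
```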

BTW, this may be the underlying cause of https://github.com/google/jax/issues/10737

Minimal example to reproduce:

from functools import partial
import jax
import jax.numpy as jnp

def main():
    @partial(jax.jit, donate_argnums=(0,))
    def update(x, batch):
        return x + batch

    # x takes 1GB memory
    x = jnp.zeros((1 * 1024 * 1024 * 1024 // 4), dtype=jnp.float32)

    # A training loop
    for _ in range(10):
        batch = x

        x = update(x, batch)

if __name__ == "__main__":
    main()
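Until JAX detects the aliasing itself, one workaround (a sketch, not an official recommendation) is to break the aliasing by materializing a copy of the shared input before the donating call:

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=(0,))
def update(x, batch):
    return x + batch

x = jnp.ones(8, dtype=jnp.float32)
for _ in range(3):
    # Copying gives `batch` its own buffer, so nothing aliases
    # the donated argument `x`.
    batch = jnp.array(x, copy=True)
    x = update(x, batch)

print(x)  # x doubles each iteration: 1 -> 2 -> 4 -> 8
```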

What jax/jaxlib version are you using?

jax v0.3.17, jaxlib v0.3.15

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:36:39) [GCC 10.4.0] on linux

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
| N/A   39C    P0     5W /  N/A |      9MiB /  4096MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1156      G   /usr/lib/xorg/Xorg                  4MiB |
|    0   N/A  N/A      2033      G   /usr/lib/xorg/Xorg                  4MiB |
+-----------------------------------------------------------------------------+
rajasekharporeddy commented 8 months ago

Hi @imoneoi

Looks like this issue is resolved in the latest version of JAX (0.4.23), which now raises an XlaRuntimeError instead of hanging.

I ran the reported code in Google Colab. The program didn't hang; instead it produced the following error message.

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-1-ff273c4cdc26> in <cell line: 21>()
     20 
     21 if __name__ == "__main__":
---> 22     main()

1 frames
    [... skipping hidden 10 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in __call__(self, *args)
   1157       self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1158     else:
-> 1159       results = self.xla_executable.execute_sharded(input_bufs)
   1160     if dispatch.needs_check_special():
   1161       out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: INVALID_ARGUMENT: Attempt to use a buffer that was previously donated in the same call to Execute() (second use: flattened argument 1, replica 0). Toy example for this bug: `f(donate(a), a)`.

Kindly find the gist.

Could you please check and confirm?

Thank you.