jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[FEA] Implement cupy.cuda.set_allocator-style API to allow sharing a memory allocator between different libraries #5179

Open miguelusque opened 3 years ago

miguelusque commented 3 years ago

Hi,

GPU-accelerated libraries usually pre-allocate a pool of GPU memory in order to improve their performance.

That is fine when those libraries are used in isolation, but when we need to combine different libraries, we are likely to hit out-of-memory errors because their separate memory managers compete for the available GPU memory.

To solve this, some libraries have adopted a mechanism for sharing a single memory manager among them.

For instance, CuPy and RAPIDS cuDF provide a set_allocator() method to specify which memory manager those libraries should use. Numba supports this as well.

It would be great if you could provide the same mechanism in JAX. I am aware you already provide a to_dlpack() method, which presents the data in a format that many libraries, such as cuDF and CuPy, can consume. I think that being able to share a memory manager would make it easier to integrate JAX with other GPU-accelerated libraries.
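For context, the CuPy hook mentioned above works roughly like this; a minimal sketch with a toy allocator, just to show the mechanism (in practice one would plug in a shared memory manager such as RMM here):

```python
import cupy as cp

_pool = cp.cuda.MemoryPool()

def logging_allocator(nbytes: int) -> cp.cuda.MemoryPointer:
    # Toy allocator: delegate to an explicit CuPy memory pool. This hook is
    # where a shared memory manager (e.g. RMM) would be plugged in instead.
    print(f"allocating {nbytes} bytes")
    return _pool.malloc(nbytes)

# From this point on, every CuPy allocation goes through the hook above.
cp.cuda.set_allocator(logging_allocator)
x = cp.arange(10)
```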

Thanks for considering the above.

Regards, Miguel

gravitino commented 3 years ago

I support this request

Zenodia commented 3 years ago

+1 support this request as well

gfiameni commented 3 years ago

+1

skye commented 3 years ago

For my understanding, what other GPU libraries would you like to run with JAX?

One way I can think of to implement this would be to provide a way to pass in a custom allocator instead of using the hardcoded one here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/gpu_device.cc#L323. Then we could plumb through a user-defined Python function (?) from JAX that defines how to allocate/deallocate, and wrap that in the interface the C++ code expects.

This sounds doable, but it's a decent-sized project, so unfortunately I can't commit to any timeline for implementing it. @hawkinsp do you have any further thoughts or ideas?
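To make that plumbing concrete, a purely hypothetical sketch of what a Python-facing hook could look like; none of these names exist in JAX today, and the real work would be on the C++/XLA side:

```python
from typing import Callable

# Hypothetical API sketch only -- JAX does not expose anything like this today.
# The idea from the comment above: the user supplies allocate/deallocate
# callables, and the XLA GPU client wraps them in the C++ allocator interface
# it already uses internally.

AllocateFn = Callable[[int], int]       # nbytes -> device pointer
DeallocateFn = Callable[[int], None]    # device pointer -> None

def set_gpu_allocator(allocate: AllocateFn, deallocate: DeallocateFn) -> None:
    """Register a user-defined GPU memory allocator (illustrative stub only)."""
    raise NotImplementedError("sketch only; see the discussion above")
```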

miguelusque commented 3 years ago

Hi @skye ,

Thank you for considering this Feature Request.

I think there are multiple workloads that would benefit from integration with the RAPIDS libraries (cuDF, cuML, cuGraph, ...).

For instance, NLP workloads like the following would benefit from the GPU-accelerated preprocessing and tokenization provided by cuDF and cuML (https://medium.com/rapids-ai/state-of-the-art-nlp-at-scale-with-rapids-huggingface-and-dask-a885c19ce87b).
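As a simplified sketch of that kind of pipeline, moving a preprocessed column from cuDF into JAX without leaving the GPU; this assumes cuDF's to_dlpack() export and jax.dlpack.from_dlpack(), whose exact spellings may differ across versions:

```python
import cudf
import jax.dlpack

# GPU-side preprocessing in cuDF (stand-in for real tokenization output).
df = cudf.DataFrame({"token_id": [101, 2023, 2003, 102]})
col = df["token_id"].astype("int32")

# Hand the column to JAX via DLPack (zero-copy when the layout allows it).
x = jax.dlpack.from_dlpack(col.to_dlpack())
print(x.shape, x.dtype)
```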

Regards, Miguel

pwuertz commented 1 year ago

Hi @skye

> For my understanding, what other GPU libraries would you like to run with JAX?

I'm also using a combination of CUDA libraries like Numba, CuPy and TensorRT, and I'm wondering if it's possible to fit JAX into the mix. Sharing device arrays is pretty easy thanks to the CUDA Array Interface. But being able to coordinate memory allocations is an important aspect when memory-pooling libraries like CuPy or JAX are involved.

Luckily, user-defined memory allocators are somewhat of a recurring pattern with basically the same mechanics. For reference:

TL;DR: +1 :)

miguelusque commented 1 year ago

Hi all, by the way, let me mention that PyTorch has recently added support for the RAPIDS Memory Manager (RMM) as an external memory allocator.

Therefore, IIRC, so far RMM is supported by Dask, RAPIDS libraries (cuDF, cuML, cuGraph...), PyTorch, Numba and CuPy.
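For reference, this is roughly how those libraries are pointed at a single RMM pool today; a sketch, since the exact module paths have moved between RMM releases (older versions expose e.g. rmm.rmm_cupy_allocator directly):

```python
import rmm
import cupy as cp
import torch
from numba import cuda
from rmm.allocators.cupy import rmm_cupy_allocator
from rmm.allocators.numba import RMMNumbaManager
from rmm.allocators.torch import rmm_torch_allocator

# One RMM-managed pool shared by everything that supports an external allocator.
rmm.reinitialize(pool_allocator=True)

cp.cuda.set_allocator(rmm_cupy_allocator)                         # CuPy
cuda.set_memory_manager(RMMNumbaManager)                          # Numba (EMM plugin)
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)   # PyTorch
```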

It would be great to add JAX (and TensorFlow) to the list! :-)

jakirkham commented 1 year ago

> For my understanding, what other GPU libraries would you like to run with JAX?
>
> I'm also using a combination of CUDA libraries like Numba, CuPy and TensorRT, and I'm wondering if it's possible to fit JAX into the mix. Sharing device arrays is pretty easy thanks to the CUDA Array Interface. But being able to coordinate memory allocations is an important aspect when memory-pooling libraries like CuPy or JAX are involved.

Also worth noting issue https://github.com/google/jax/issues/1100, where some additional functionality is still needed for __cuda_array_interface__ support.
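For anyone not familiar with it, the CUDA Array Interface is just a small dict attached to an object that describes existing device memory (shape, dtype string, raw device pointer); that is how e.g. Numba already consumes CuPy arrays without a copy, and the linked issue is about exposing/consuming the same on JAX arrays. A small illustration:

```python
import cupy as cp
from numba import cuda

a = cp.arange(8, dtype=cp.float32)

# The protocol is just a dict describing the existing device buffer.
print(a.__cuda_array_interface__)   # shape, typestr, data=(ptr, read_only), version, ...

# Numba wraps the same device memory without copying it.
d_a = cuda.as_cuda_array(a)
print(d_a.copy_to_host())
```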

pwuertz commented 1 year ago

If the implementation of custom memory allocators in XLA is indeed a decent-sized project, as @skye mentioned, and won't be happening in the foreseeable future, maybe JAX could allow limited access to the existing XLA allocator via Python?

If we can't adapt memory management in JAX, this would at least give us the option of using the fixed XLA allocator from other libraries, enabling some basic coexistence.

I was thinking about a proof of principle that just uses jnp.empty as an allocator, but according to the docs, JAX doesn't support uninitialized arrays and will always zero-initialize them... which is of course highly undesirable for an allocator :/

So basically, this is a feature request for "np.empty"-like behavior and/or low-level access to the XLA allocator. Would this be easier to do in XLA?
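For what it's worth, a rough sketch of that proof of principle: backing CuPy allocations with buffers obtained from JAX. It assumes a way to get the raw device pointer out of a JAX array (an unsafe_buffer_pointer() accessor has existed on the underlying buffer in some jaxlib versions, but it is an implementation detail and may not be available, or spelled this way, in yours):

```python
import cupy as cp
import jax.numpy as jnp

def jax_backed_cupy_alloc(nbytes: int) -> cp.cuda.MemoryPointer:
    # Allocate from XLA's pool by creating a JAX array (zero-initialized,
    # since jnp.empty cannot return uninitialized memory), then expose that
    # memory to CuPy. The JAX array is kept alive as the owner so XLA does
    # not free the buffer while CuPy still references it.
    arr = jnp.empty((max(nbytes, 1),), dtype=jnp.uint8)
    ptr = arr.device_buffer.unsafe_buffer_pointer()  # implementation detail; varies by jaxlib version
    mem = cp.cuda.UnownedMemory(ptr, nbytes, owner=arr)
    return cp.cuda.MemoryPointer(mem, 0)

# Route all CuPy allocations through XLA's allocator (proof of principle only).
cp.cuda.set_allocator(jax_backed_cupy_alloc)
```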