lengstrom opened this issue 6 months ago
I suspect the difference is that the destination of your transfer is CUDA pinned host memory (cudaMallocHost), to which the device can DMA directly. JAX is transferring to unpinned memory. If you allocate the target buffer with malloc in your CUDA benchmark, how do the two compare?
(We are actually working on adding support for pinned host memory allocations to JAX.)
Thanks for the quick response! With malloc I get 6.3 GB/s throughput, vs. 2.7 GB/s in JAX.
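A minimal sketch for measuring the JAX side of this comparison, assuming a recent JAX where jax.Array has block_until_ready(). It times jax.device_put from the default device (the GPU, when one is present) to the host CPU device and reports the best of several runs. The helper name d2h_throughput_gbps and its parameters are hypothetical, not code from this thread.

```python
import time

import jax
import numpy as np


def d2h_throughput_gbps(num_bytes: int = 1 << 28, repeats: int = 5) -> float:
    """Best-of-`repeats` throughput in GB/s (1 GB = 1e9 bytes) of
    jax.device_put from the default device to the host CPU device."""
    n = num_bytes // 4  # number of float32 elements
    # Source array on the default device (the GPU when one is present).
    x = jax.device_put(np.ones(n, dtype=np.float32))
    x.block_until_ready()  # exclude the host-to-device copy from the timing
    cpu = jax.devices("cpu")[0]
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        y = jax.device_put(x, cpu)  # device-to-host transfer
        y.block_until_ready()       # wait until the copy has finished
        best = min(best, time.perf_counter() - t0)
    # Guard against a zero timer reading when the copy is a no-op.
    return n * 4 / max(best, 1e-12) / 1e9
```

For example, print(d2h_throughput_gbps()) on a GPU machine. On a CPU-only machine the source and destination are the same device, so the reported figure does not measure a real D2H copy.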
Even without official support, is there an easy hack to get JAX to allocate CUDA pinned memory? This problem is very important for my application, and since I'm only using my CPU as a staging area for GPU operations, I'm happy to have JAX use only CUDA pinned memory.
Right now there's no hack to get pinned_host working, because the implementation is missing. We are working on it.
I might be able to get you the 6.3 GB/s without much trouble, though, if that's helpful.
Another workaround for the moment would be to use DLPack to exchange the on-GPU array with another library that already supports pinned host memory (e.g., cupy) and use that library to do the transfer.
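A minimal sketch of that workaround, assuming CuPy 10 or newer (for cupy.from_dlpack) and a JAX version whose arrays implement the __dlpack__ protocol. The JAX array is shared with CuPy without a copy, and CuPy then copies it into a NumPy array backed by pinned host memory allocated with cupyx.empty_pinned. The function name copy_to_pinned_host is hypothetical.

```python
def copy_to_pinned_host(x):
    """Copy a single-device, GPU-resident JAX array to pinned host memory.

    Returns a NumPy array whose storage is page-locked, so the
    device-to-host copy can be done by DMA directly into it.
    """
    # Imported inside the function so this module still loads on machines
    # without CuPy or a CUDA device.
    import cupy as cp
    import cupyx

    # Zero-copy view of the JAX device buffer via the DLPack protocol.
    x_cp = cp.from_dlpack(x)
    # NumPy array backed by pinned (page-locked) host memory.
    out = cupyx.empty_pinned(x_cp.shape, dtype=x_cp.dtype)
    # Device-to-host copy into the pinned buffer (synchronous when no stream is given).
    x_cp.get(out=out)
    return out
```

For example, host = copy_to_pinned_host(jax.device_put(np.ones(1 << 26, np.float32))) on a machine with a CUDA GPU.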
Thank you for the suggestions - bridging to cupy worked!
Any progress on the missing pinned host implementation? Is the TODO "TODO(b/238441608): Use pinned memory here to speed up the transfer." in py_client_gpu.cc related? I also see code under gpu_transfer_manager, in GpuTransferManager::ReadDynamicShapes, that says "Check out pinned memory for each buffer we want to copy". Do you have a design for H2D, D2H, and D2D memcpy? Thank you.
Description
For D2H (GPU to CPU) transfers, jax.device_put has very low throughput: it yields ~2.7 GB/s, whereas a very simple CUDA program yields ~25 GB/s. Is there an alternative approach in JAX that I'm missing?
I also tried the following two approaches, and both performed at least as poorly as jax.device_put: xc.batched_device_put (as detailed in https://github.com/google/jax/issues/16905#issue-1829138543).
Minimal JAX example:
with output
And here is a simple CUDA program for copying:
With output:
System info (python version, jaxlib version, accelerator, etc.)