jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

D2H (gpu -> cpu) transfer via `device_put` is very slow #21438

Open lengstrom opened 5 months ago

lengstrom commented 5 months ago

Description

For D2H (GPU to CPU) transfers, `jax.device_put` has very low throughput: it reaches ~2.7 GB/s, whereas a very simple CUDA program reaches ~25 GB/s. Is there an alternative approach in JAX that I'm missing?

I also tried the following two approaches; both performed at least as poorly as `jax.device_put`:

Minimal JAX example:

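import time
import jax
import jax.numpy as jnp

cpu_dev = jax.devices('cpu')[0]
gpu_dev = jax.devices('gpu')[0]

GBs = 10  # payload size in GiB

def big_tensor():
    # GBs GiB of float32 (4 bytes per element), created directly on the GPU.
    return jnp.ones((GBs * 1024**3 // 4,), dtype=jnp.float32, device=gpu_dev)

def test_transfer(x):
    jax.block_until_ready(x)  # make sure the source array exists before timing
    s = time.perf_counter()

    out = jax.device_put(x, cpu_dev)  # device-to-host (D2H) copy

    # device_put is asynchronous; wait for the copy to finish before stopping the clock.
    jax.block_until_ready(out)
    dur = time.perf_counter() - s
    print(f"Time taken: {dur}; gbps: {GBs/dur}")  # "gbps" here means GB/s, not gigabits/s

for _ in range(5):
    test_transfer(big_tensor())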

With output:

Time taken: 3.7397873401641846; gbps: 2.6739488346310565
Time taken: 3.6008787155151367; gbps: 2.7770999220031807
Time taken: 3.677137613296509; gbps: 2.7195065976970936
Time taken: 3.594850778579712; gbps: 2.781756633567665
Time taken: 3.5868709087371826; gbps: 2.787945330187717

And here is a simple CUDA program for copying:

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

const uint64_t num_bytes = 10ul * (1ul << 30);  // 10 GiB
const uint64_t num_floats = num_bytes / sizeof(float);

int main() {
    cudaStream_t stream0;
    cudaStreamCreate(&stream0);

    // Pinned (page-locked) host buffer: the GPU can DMA to/from it directly.
    float* host_a;
    cudaMallocHost(&host_a, num_floats * sizeof(float));

    float* device_a;
    cudaMalloc(&device_a, num_floats * sizeof(float));

    auto start = std::chrono::high_resolution_clock::now();

    // As written, this copies host -> device (H2D). The device -> host (D2H)
    // direction discussed in this issue would be:
    //   cudaMemcpyAsync(host_a, device_a, num_bytes, cudaMemcpyDeviceToHost, stream0);
    cudaMemcpyAsync(device_a, host_a, num_floats * sizeof(float), cudaMemcpyHostToDevice, stream0);
    cudaStreamSynchronize(stream0);  // wait for the async copy before stopping the clock

    auto end = std::chrono::high_resolution_clock::now();
    printf("Time: %f\n", std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() / 1000.0);

    cudaFreeHost(host_a);
    cudaFree(device_a);
    cudaStreamDestroy(stream0);

    return 0;
}

With output:

Time: 0.401000
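Converting these timings into throughput makes the gap explicit. A small illustrative sketch in plain Python (no JAX needed), assuming the 10 GiB payload both benchmarks use and taking the fastest of the five JAX runs; `throughput_gib_s` is a hypothetical helper name:

```python
GIB = 1024**3  # bytes per GiB; both benchmarks move 10 GiB


def throughput_gib_s(num_bytes: int, seconds: float) -> float:
    """Return the throughput in GiB/s of a transfer of num_bytes taking `seconds`."""
    return num_bytes / GIB / seconds


payload = 10 * GIB
# Fastest of the five JAX device_put runs reported above.
jax_rate = throughput_gib_s(payload, 3.5868709087371826)
# The CUDA program's reported time (millisecond resolution).
cuda_rate = throughput_gib_s(payload, 0.401)

print(f"JAX device_put: {jax_rate:.2f} GiB/s")         # 2.79 GiB/s
print(f"CUDA memcpy:    {cuda_rate:.2f} GiB/s")        # 24.94 GiB/s
print(f"ratio:          {cuda_rate / jax_rate:.1f}x")  # 8.9x
```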

System info (python version, jaxlib version, accelerator, etc.)

jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='deep-h-3.csail.mit.edu', release='5.15.0-107-generic', version='#117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024', machine='x86_64')

$ nvidia-smi
Sun May 26 22:16:33 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 PCIe               On  |   00000000:01:00.0 Off |                    0 |
| N/A   82C    P0            340W /  350W |   61910MiB /  81559MiB |     79%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

hawkinsp commented 5 months ago

I suspect the difference is that the destination of your transfer is CUDA pinned host memory (cudaMallocHost), which the device can DMA into directly, whereas JAX is transferring to unpinned (pageable) memory. If you allocate the target buffer with malloc in your CUDA benchmark instead, how do the two compare?

(We are actually working on adding support for pinned host memory allocations to JAX.)

lengstrom commented 5 months ago

Thanks for the quick response! With malloc I get 6.3 GB/s, vs. 2.7 GB/s in JAX.

Even if there is no official support, is there an easy hack to get JAX to allocate CUDA pinned memory? This is very important for my application, and since I'm only using the CPU as a staging area for GPU operations, I'm happy to have JAX use only CUDA pinned memory.

nouiz commented 5 months ago

Right now, there is no hack to get `pinned_host` working: the implementation is missing. We are working on it.

hawkinsp commented 5 months ago

I might be able to get you the 6.3 GB/s without much trouble, though, if that's helpful.

hawkinsp commented 5 months ago

Another workaround for now would be to use DLPack to exchange the on-GPU array with another library that already supports pinned host memory (e.g., CuPy) and use that library to do the transfer.
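A minimal illustrative sketch of this DLPack-to-CuPy workaround (not the code used in this thread), assuming a CuPy version that provides `cupy.from_dlpack`, a jaxlib whose arrays implement the DLPack protocol (`__dlpack__`), and that the JAX array lives on CuPy's current device. `to_pinned_host` is a hypothetical helper name; `cupy.cuda.alloc_pinned_memory` and `ndarray.get(out=...)` are CuPy's own API:

```python
import numpy as np

try:
    import cupy as cp
except ImportError:  # CuPy is optional here; to_pinned_host() needs it.
    cp = None


def to_pinned_host(x) -> np.ndarray:
    """Copy a GPU array (e.g. a jax.Array) into page-locked host memory.

    Hypothetical helper: bridges to CuPy via DLPack, then lets CuPy perform
    the device-to-host copy into a pinned buffer the GPU can DMA into.
    """
    if cp is None:
        raise RuntimeError("to_pinned_host requires CuPy")
    # Zero-copy view of the GPU buffer as a CuPy array (no data movement).
    x_cp = cp.from_dlpack(x)
    # Allocate pinned (page-locked) host memory and view it as a NumPy array.
    mem = cp.cuda.alloc_pinned_memory(x_cp.nbytes)
    out = np.frombuffer(mem, dtype=x_cp.dtype, count=x_cp.size).reshape(x_cp.shape)
    # D2H copy; with no stream argument CuPy copies synchronously,
    # so `out` holds the data when this returns.
    x_cp.get(out=out)
    return out
```

Page-locking 10 GiB of host memory is itself relatively expensive, so a benchmark comparable to the one above would allocate the pinned buffer once and reuse it across copies rather than calling `to_pinned_host` inside the timed region.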

lengstrom commented 5 months ago

Thank you for the suggestions - bridging to CuPy worked!