jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`lax.scan` on `map_coordinates` slower on GPU than on CPU? #10794

Open regevs opened 2 years ago

regevs commented 2 years ago

I am running the following (simplified for demonstration) code:

import numpy as np
import time

import jax.numpy as jnp
import jax, jax.lax, jax.scipy.ndimage, jax.tree_util

@jax.jit
def flow_field_at(x, y, flow_field):
    new_x = jax.scipy.ndimage.map_coordinates(
        input = flow_field[0, :, :],
        coordinates = jnp.array([x, y]),
        order = 1,           # Linear interpolation
        mode = "constant",   # Fill with cval (default 0) outside the grid
    )

    new_y = jax.scipy.ndimage.map_coordinates(
        input = flow_field[1, :, :],
        coordinates = jnp.array([x, y]),
        order = 1,           # Linear interpolation
        mode = "constant",   # Fill with cval (default 0) outside the grid
    )

    return jnp.array([new_x, new_y])

@jax.jit
def forward(segments, flow_field):
    init_coords = jnp.array([0.0, 0.0])    

    def forward_scan_function(coords, segment_type):   # segment_type is unused in this simplified demo
        res = flow_field_at(
            coords[0], coords[1],
            flow_field    
            )

        return res, res

    _, output = jax.lax.scan(
        f = forward_scan_function,
        init = init_coords,
        xs = segments
        )

    return output

if __name__ == '__main__':

    flow_field = np.ones((2, 51, 50))

    vectorized_forward = jax.vmap(
        jax.tree_util.Partial(
            forward, 
            flow_field = flow_field
            )
        )    

    all_segments = jnp.ones((800, 500000))
    all_segments = jax.device_put(all_segments)   # jnp.ones already lives on the default device, so this is effectively a no-op

    # Run twice to separate compile and run time
    for run in range(2):
        start_time = time.time()

        res = vectorized_forward(all_segments).block_until_ready()

        end_time = time.time()
        print(f"Done ({end_time - start_time:1.5f} seconds).")

When running on a CPU, I get:

2022-05-21 23:45:24.962224: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Done (0.30775 seconds).
Done (0.14178 seconds).

But on a GPU, I get:

Done (6.45742 seconds).
Done (5.82286 seconds).

How come? What could I do to speed it up on a GPU? Happy for any other jax tips. Many thanks.
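
(As an aside, I suspect the two map_coordinates calls per step could be fused into a single vmap over the channel axis. A sketch, untested against the timings above, and flow_field_at_fused is just a name I made up:)

@jax.jit
def flow_field_at_fused(x, y, flow_field):
    # Interpolate both channels of flow_field (shape (2, H, W)) at (x, y)
    # in one vmapped call instead of two separate map_coordinates calls.
    coords = jnp.array([x, y])
    return jax.vmap(
        lambda channel: jax.scipy.ndimage.map_coordinates(
            channel, coords, order=1, mode="constant")
    )(flow_field)  # shape (2,), same as jnp.array([new_x, new_y])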

I have seen elsewhere that XLA by default presents all the cores as one device and uses the cores for its own intra-op parallelism. Is that controllable? I am running on a multi-core machine - how can I check if this is the cause?
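
(For what it's worth, the approach I have seen suggested for limiting XLA's own CPU threading is to set XLA_FLAGS before jax is first imported. I have not verified that these flags are still current, so treat this as a sketch:)

import os

# Must run before jax is first imported; asks XLA's CPU backend to use one thread.
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")

import jax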

jakevdp commented 2 years ago

Generally you'll find that scan and similar serial control-flow operations show poor performance on GPU – this is because, in general, each step of the scan depends on the previous one, so the steps cannot be executed in parallel. CPUs are hardware designed to be fast at this kind of serial operation, but not necessarily fast at parallel vector math. GPUs are chips designed to be fast at parallel vector math, but not particularly fast at serial operations. So fundamentally, you have an algorithm that's not well-suited to execution on a GPU.
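
For instance, even when the per-step computation is embarrassingly parallel, expressing it as a scan serializes it; the GPU only gets to do what it is good at once the steps are independent of the carry (a toy illustration, not your code):

import jax
import jax.numpy as jnp

xs = jnp.arange(10_000, dtype=jnp.float32)

def step(carry, x):
    # The output here happens not to depend on the carry at all.
    return carry, jnp.sin(x) * 2.0

_, ys_scan = jax.lax.scan(step, 0.0, xs)            # 10,000 sequential steps
ys_vmap = jax.vmap(lambda x: jnp.sin(x) * 2.0)(xs)  # one parallel launch
assert jnp.allclose(ys_scan, ys_vmap)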

There's another piece here, but I'll admit I'm a bit murkier on it: I believe that the scan implementation on GPU requires host synchronization at every step, which adds overhead that scales with the number of steps. Someone else might be able to confirm that and/or give more details on that piece.
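
One way to check would be to capture a profiler trace around a timed run and look for per-step gaps between the GPU kernels; a sketch (the log directory is just an example path):

import jax.profiler

# Gaps between consecutive scan-step kernels in the timeline would point
# at launch/synchronization overhead rather than compute.
with jax.profiler.trace("/tmp/jax-trace"):
    vectorized_forward(all_segments).block_until_ready()
# Then inspect with: tensorboard --logdir /tmp/jax-trace (profile plugin)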

patrick-kidger commented 2 years ago

I'd also be curious to know about host sync! I've heard this on the grapevine a few times and would love to know more.

jakevdp commented 2 years ago

@hawkinsp might be able to enlighten us here.

regevs commented 2 years ago

I realized that by default jax utilizes all my CPU cores. When I force it onto a single core (using taskset -c 0 mpiexec -np 1 as described in #1539), the result is still more or less the same. On CPU:

Done (0.86165 seconds).
Done (0.71001 seconds).

On GPU:

Done (6.85969 seconds).
Done (5.89742 seconds).

However, in line with @jakevdp's comment on host synchronization, I do see that if the data shape is (500000, 800) instead of (800, 500000), so that vmap batches over 500000 trajectories while the scan runs only 800 sequential steps instead of 500000, the GPU is much faster. I.e., for CPU:

Done (0.73948 seconds).
Done (0.59321 seconds).

while GPU:

Done (0.24141 seconds).
Done (0.01284 seconds).

So, a ~46x speedup excluding compilation time. This is very useful knowledge for my application - thanks. I am therefore very curious to learn about host synchronization in lax.scan, as this seems to be a limiting performance factor.
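
For concreteness, the fast configuration is just the transposed shape (a sketch of the change, everything else as in the original script):

all_segments = jnp.ones((500000, 800))   # was (800, 500000)
all_segments = jax.device_put(all_segments)

# vmap now batches 500000 independent trajectories, and the scan inside
# forward runs only 800 sequential steps, which suits the GPU much better.
res = vectorized_forward(all_segments).block_until_ready()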