google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Jax scans are slower than expected #2491

Open dionhaefner opened 4 years ago

dionhaefner commented 4 years ago

I am implementing the tridiagonal matrix algorithm (TDMA) to solve many tridiagonal systems of the same shape in two sweeps (one forward and one backward pass).

The shape of each diagonal is something like (100_000, 100), and I vectorize over the leading axis, so this should be reasonably efficient.
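For reference, the two sweeps are the standard Thomas-algorithm recurrences. Written per system, with sub-, main and super-diagonals $a_i, b_i, c_i$, right-hand side $d_i$, and the convention $c'_0 = d'_0 = 0$, this is what the scan-based JAX kernel in this comment computes (the NumPy version eliminates with $w_i = a_i / b_{i-1}$ instead, which is algebraically equivalent):

```latex
\begin{aligned}
c'_i &= \frac{c_i}{b_i - a_i c'_{i-1}}, \qquad
d'_i = \frac{d_i - a_i d'_{i-1}}{b_i - a_i c'_{i-1}}, \qquad i = 1, \dots, n,\\
x_n &= d'_n, \qquad x_i = d'_i - c'_i \, x_{i+1}, \qquad i = n-1, \dots, 1.
\end{aligned}
```

Each step depends only on the previous one, so both sweeps are sequential along the system axis but fully parallel across the batch of systems.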

In pure NumPy, I would do it like this:

import numpy as np

def tdma_naive(a, b, c, d):
    """
    Solves many tridiagonal matrix systems with diagonals a, b, c and RHS vectors d.
    Note: b and d are modified in place.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    n = a.shape[-1]

    # Forward elimination sweep along the last axis
    for i in range(1, n):
        w = a[..., i] / b[..., i - 1]
        b[..., i] += -w * c[..., i - 1]
        d[..., i] += -w * d[..., i - 1]

    out = np.empty_like(a)
    out[..., -1] = d[..., -1] / b[..., -1]

    # Back substitution, from the last row to the first
    for i in range(n - 2, -1, -1):
        out[..., i] = (d[..., i] - c[..., i] * out[..., i + 1]) / b[..., i]

    return out

The JAX implementation looks like this:

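import jax
import jax.numpy as jnp

def tdma_jax_kernel(a, b, c, d):
    # Inputs are 3-D, e.g. (nx, ny, n); each system runs along the last axis.
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x

        # Forward sweep: normalized super-diagonal (cp) and RHS (dp)
        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom

        new_primes = (cp, dp)
        return new_primes, new_primes

    # .T reverses all axes, so scan iterates over the system axis
    diags = (a.T, b.T, c.T, d.T)
    init = jnp.zeros((a.shape[1], a.shape[0]))
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    # Backward sweep over the reversed primes; undo the reversal and transpose
    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))

    return sol[::-1].T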

I implemented the algorithm in a handful of backends (including a sloppily written CUDA kernel). You can see the results in this Gist:

https://gist.github.com/dionhaefner/a97ef80b77e02b36e4b248bb97541161

The executive summary is that Jax is 2.5x slower than Numba on CPU, and 3x slower than my amateurish CUDA kernel on GPU (but is on par with Numba here).

If I eliminate the transposes from the Jax implementation and transpose the inputs beforehand, the implementation becomes about 2x faster on GPU, so it would be nice if scan supported scanning over arbitrary axes.
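A minimal sketch of the pre-transposed variant described above, assuming the caller already stores the diagonals with the system axis first, i.e. shape `(n, *batch)`. `tdma_scan_major` is a hypothetical name, not the code from the Gist. It also uses scan's `reverse=True` instead of explicit `[::-1]` flips for the backward sweep.

```python
import jax
import jax.numpy as jnp


@jax.jit
def tdma_scan_major(a, b, c, d):
    """Thomas algorithm for diagonals stored system-axis-first.

    a, b, c, d have shape (n, *batch): axis 0 runs along each tridiagonal
    system, the remaining axes index independent systems.
    """
    def forward(carry, x):
        last_cp, last_dp = carry
        a_i, b_i, c_i, d_i = x
        denom = 1.0 / (b_i - a_i * last_cp)
        cp = c_i * denom
        dp = (d_i - a_i * last_dp) * denom
        return (cp, dp), (cp, dp)

    init = jnp.zeros(a.shape[1:], dtype=a.dtype)
    _, (cp, dp) = jax.lax.scan(forward, (init, init), (a, b, c, d))

    def backward(last_x, x):
        cp_i, dp_i = x
        new_x = dp_i - cp_i * last_x
        return new_x, new_x

    # reverse=True iterates from the last row to the first and stores the
    # outputs in the original order, so no flips are needed.
    _, sol = jax.lax.scan(backward, init, (cp, dp), reverse=True)
    return sol
```

Here the transpose cost moves to wherever the data is produced; if the diagonals can be built in this layout from the start, it disappears entirely.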

Is this behavior something that is expected, and is there something else I can do to make the Jax implementation more efficient?

mattjj commented 4 years ago

Thanks for raising this, and for the beautifully clear report! Let's make it faster.

I wouldn't be surprised if XLA:TPU eliminates these transposes (though I admit I'm not certain if it does). I pinged the XLA:CPU/GPU folks for some initial thoughts on this issue.

It sounds like we can solve this on the JAX side. Knowing that manual transposing helps gives us a lot of information; as you suggest, we can make scan work over arbitrary axes, though the bookkeeping may get annoying.
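The bookkeeping for scanning over another axis can be sketched as a thin wrapper. `scan_along_axis` is a hypothetical helper, not a JAX API, and it handles only a single array `xs`. It simply moves the scanned axis to the front with `jnp.moveaxis` and moves the stacked outputs back, so inside a jitted program it still expresses the same transposes the report measured; it hides the bookkeeping rather than removing the cost.

```python
import jax
import jax.numpy as jnp


def scan_along_axis(f, init, xs, axis):
    """Hypothetical helper: like jax.lax.scan over a single array xs, but
    iterating over `axis` instead of axis 0."""
    xs_front = jnp.moveaxis(xs, axis, 0)    # bring the scanned axis to the front
    carry, ys = jax.lax.scan(f, init, xs_front)
    return carry, jnp.moveaxis(ys, 0, axis)  # put the stacked outputs back


# Usage: a running sum along axis 1 of a (3, 4) array.
def running_sum(carry, x):
    carry = carry + x
    return carry, carry


x = jnp.arange(12.0).reshape(3, 4)
total, ys = scan_along_axis(running_sum, jnp.zeros(3), x, axis=1)
```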

Do you need to differentiate through this? I'm guessing for now at least the answer may be "no", in which case perhaps you can try using a fori_loop or while_loop with indexing instead of a scan? That would also make it a bit closer to your NumPy code IIUC. WDYT?

shoyer commented 4 years ago

+1 for trying fori/while_loop here. For gradients, you probably want to use the implicit function theorem rather than differentiating through the solve — take a look at lax.custom_linear_solve.
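A minimal sketch of the custom_linear_solve suggestion for a single system, assuming 1-D diagonals of length n with `a[0]` and `c[-1]` ignored. The Thomas sweeps provide the forward solve, and `jax.lax.custom_linear_solve` defines gradients implicitly through `matvec` rather than by differentiating the scans. `tridiag_matvec`, `thomas` and `tridiag_solve` are hypothetical names; batching over many systems could be added with `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def tridiag_matvec(a, b, c, x):
    """y = T x for the tridiagonal T with sub/main/super diagonals a, b, c."""
    y = b * x
    y = y.at[1:].add(a[1:] * x[:-1])    # sub-diagonal contribution
    y = y.at[:-1].add(c[:-1] * x[1:])   # super-diagonal contribution
    return y


def thomas(a, b, c, d):
    """Forward/backward Thomas sweeps with lax.scan for one system."""
    def forward(carry, coeffs):
        cp_prev, dp_prev = carry
        a_i, b_i, c_i, d_i = coeffs
        denom = b_i - a_i * cp_prev
        cp, dp = c_i / denom, (d_i - a_i * dp_prev) / denom
        return (cp, dp), (cp, dp)

    zero = jnp.zeros((), d.dtype)
    _, (cp, dp) = jax.lax.scan(forward, (zero, zero), (a, b, c, d))

    def backward(x_next, coeffs):
        cp_i, dp_i = coeffs
        x_i = dp_i - cp_i * x_next
        return x_i, x_i

    _, x = jax.lax.scan(backward, zero, (cp, dp), reverse=True)
    return x


def tridiag_solve(a, b, c, d):
    """Solve T x = d with implicitly defined gradients."""
    matvec = lambda x: tridiag_matvec(a, b, c, x)
    solve = lambda _matvec, rhs: thomas(a, b, c, rhs)
    # T^T is tridiagonal too: its sub-diagonal is c shifted down by one and its
    # super-diagonal is a shifted up by one (the wrapped entries are ignored).
    transpose_solve = lambda _vecmat, rhs: thomas(jnp.roll(c, 1), b, jnp.roll(a, -1), rhs)
    return jax.lax.custom_linear_solve(matvec, d, solve, transpose_solve)
```

Gradients with respect to `b` (or `a`, `c`) flow through `matvec`, and gradients with respect to `d` use `transpose_solve`, so reverse mode never has to store the intermediate carries of the scans.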

dionhaefner commented 4 years ago

Well, here goes my fori_loop implementation:

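import jax
import jax.numpy as jnp

def tdma_jax_kernel_fori(a, b, c, d):
    # Note: jax.ops.index_add / jax.ops.index_update were removed in later JAX
    # releases; the current spelling is x.at[idx].add(v) / x.at[idx].set(v).
    def compute_primes(i, val):
        # Forward elimination step i, mirroring tdma_naive
        b, d = val
        w = a[..., i] / b[..., i - 1]
        b = jax.ops.index_add(b, jax.ops.index[..., i], -w * c[..., i - 1])
        d = jax.ops.index_add(d, jax.ops.index[..., i], -w * d[..., i - 1])
        return (b, d)

    n = a.shape[2]  # assumes 3-D inputs with the systems along the last axis

    b, d = jax.lax.fori_loop(1, n, compute_primes, (b, d))

    def backsubstitution(ir, sol):
        i = n - ir - 1  # ir = 1..n-1 walks i from n-2 down to 0
        sol = jax.ops.index_update(sol, jax.ops.index[..., i], (d[..., i] - c[..., i] * sol[..., i+1]) / b[..., i])
        return sol

    sol = jnp.empty_like(a)
    sol = jax.ops.index_update(sol, jax.ops.index[..., -1], d[..., -1] / b[..., -1])
    sol = jax.lax.fori_loop(1, n, backsubstitution, sol)

    return sol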

tdma_jax_fori = jax.jit(tdma_jax_kernel_fori, backend='cpu')
tdma_jax_cuda_fori = jax.jit(tdma_jax_kernel_fori, backend='gpu')

Timing CPU:

4.56 s ± 147 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

(~15x slower than before)

Timing GPU:

801 ms ± 68.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

(~60x slower than before)

Did I do anything wrong?

mattjj commented 4 years ago

@dionhaefner are the output values the same between the new implementation and the old?

There must be something weird going on, because scan lowers to a loop that looks just like that, modulo using lax.dynamic_slice and lax.dynamic_update_slice (through convenience wrappers like lax.dynamic_index_in_dim and lax.dynamic_update_index_in_dim) rather than the lax.scatter that jax.ops.index_update generates. We've seen cases where lax.dynamic_slice and lax.dynamic_update_slice are faster, but this seems pretty extreme. It's the only thing I can think of, though...

@hawkinsp any guesses?

dionhaefner commented 4 years ago

Yes, outputs are the same (and this is almost verbatim the NumPy implementation).

I double-checked my code, and found a small problem when calling the GPU version of the function. The official slowdown on GPU is now a factor of 20, not 80:

272 ms ± 15.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

This is also more in line with the CPU slowdown.

mattjj commented 4 years ago

I'm currently wrestling with the fallout of at least one big bug in #2026, so unfortunately I can't dig into this right away. One quick thing I might try next is switching to dynamic_index_in_dim / dynamic_update_index_in_dim to see if that closes the gap with scan. (If it did, I'd consider it a bug in index_update.) Then I'd want to follow up on some ideas the XLA folks suggested, and then consider the longer-term solution of #2509.
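A sketch of that experiment: the same two fori_loop sweeps as `tdma_jax_kernel_fori`, but reading and writing the system axis with `lax.dynamic_index_in_dim` / `lax.dynamic_update_index_in_dim` instead of scatter-based index updates. `tdma_fori_dynamic` is a hypothetical name, and whether it actually closes the gap with scan is not measured here.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def tdma_fori_dynamic(a, b, c, d):
    """Thomas algorithm along the last axis using dynamic slices."""
    axis = a.ndim - 1
    n = a.shape[axis]
    # Read element i of the system axis, dropping that axis
    take = lambda x, i: lax.dynamic_index_in_dim(x, i, axis, keepdims=False)

    def forward(i, val):
        b, d = val
        w = take(a, i) / take(b, i - 1)
        b = lax.dynamic_update_index_in_dim(b, take(b, i) - w * take(c, i - 1), i, axis)
        d = lax.dynamic_update_index_in_dim(d, take(d, i) - w * take(d, i - 1), i, axis)
        return b, d

    b, d = lax.fori_loop(1, n, forward, (b, d))

    def backward(ir, sol):
        i = n - ir - 1  # walks from n - 2 down to 0
        x_i = (take(d, i) - take(c, i) * take(sol, i + 1)) / take(b, i)
        return lax.dynamic_update_index_in_dim(sol, x_i, i, axis)

    sol = jnp.zeros_like(a)
    sol = lax.dynamic_update_index_in_dim(sol, take(d, n - 1) / take(b, n - 1), n - 1, axis)
    return lax.fori_loop(1, n, backward, sol)
```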

bnwebcode commented 4 years ago

I'm looking into lax.scan performance in another context and I came across this issue.

I wanted to repeat the benchmarking in this issue on Google Colab to understand it better. I used the code from @dionhaefner's original gist. However, I find that on Colab the JAX GPU version of the tridiagonal solver is about 6 times slower than @dionhaefner's measurements. The CPU measurements are similar.

My notebook is at: https://colab.research.google.com/drive/1-Vblr_7qQd1SR7QkPdMC3Z0jHyv1OJUB?usp=sharing

Is this just user error (perhaps Colab isn't suitable for this kind of benchmarking), or has there been a performance regression? Is there a JAX performance regression test suite?

dionhaefner commented 4 years ago

That doesn’t really surprise me; I tested on a Tesla P100 which is quite a bit beefier than Colab’s K80.

jakevdp commented 4 years ago

Colab will give you one of several GPU types depending on availability; you can run !cat /var/colab/hostname to quickly see the type of GPU backend you were assigned (P100 is one of the options available).

bnwebcode commented 4 years ago

OK, thanks for that information. I reconnected a few times and hit upon a P100 (I guess I should sign up for Colab Pro...), and I measure 64 ms, around 4.5 times slower than in @dionhaefner's original gist. There does seem to be a discrepancy, but I'm not sure whether it's user error on my part or an actual performance regression.

Justin-Tan commented 4 years ago

Observing similar behaviour with jax.lax.scan when trying to implement the logarithm of the cumulative sum of exponentials using a prefix scan. Here's a minimal working example:

import time

import jax
import jax.numpy as jnp
from jax import jit, random

def lax_logcumsumexp(x):
    def _logaddexpcarry(carry, x):
        out = jnp.logaddexp(carry, x)
        return out, out
    cumsum_final, _logcumsumexp = jax.lax.scan(_logaddexpcarry, init=-jnp.inf, xs=x)
    return _logcumsumexp

def logcumsumexp(x):
    x_max = jnp.sort(x)[-1]  # largest element, subtracted for numerical stability
    return x_max + jnp.log(jnp.cumsum(jnp.exp(x-x_max)))

# Running below block multiple times - not just first JITing (%time is IPython magic)
test_big = random.normal(random.PRNGKey(int(time.time())), (100000,))
%time jit(lax_logcumsumexp)(test_big).block_until_ready()
# About 8-9 ms on CPU, 1.5-2s on GPU
%time jit(logcumsumexp)(test_big).block_until_ready()
# About 75-90 ms on CPU, 600-700 µs on GPU

It's strange that lax.scan is so much faster on CPU than on GPU, even with jit; I'm not sure if I'm doing something verboten here. If it matters, I was using a P100 for the GPU (non-Colab). I also tested it in Colab, with a similar slowdown.
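Since this is a prefix scan with an associative operator, one alternative worth trying (not benchmarked in this thread) is `jax.lax.associative_scan`, which combines elements in a tree of O(log n) depth instead of running an n-step sequential loop. `assoc_logcumsumexp` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def assoc_logcumsumexp(x):
    """log(cumsum(exp(x))) as a parallel prefix scan.

    jnp.logaddexp is associative, which is the only property
    lax.associative_scan requires of the combining function.
    """
    return jax.lax.associative_scan(jnp.logaddexp, x)
```

Unlike the `jnp.cumsum(jnp.exp(x - x_max))` version, it never exponentiates large differences, so it stays stable even when the values span a wide range.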

shoyer commented 4 years ago

XLA GPU currently always executes dynamic control flow on the CPU. So small loop iterations (like what you have here) end up much slower, due to the need to frequently synchronize between the CPU/GPU.

You can find a similar example of this sort of slow down in https://github.com/google/jax/pull/3076

clemisch commented 4 years ago

I'm sorry to chime in on this old issue; maybe this is better suited as a new one.

I was under the impression that using jax.lax.scan eliminates "dynamic" control flow because everything is known at compile time. In this repro, what has to be transferred to the CPU?

mattjj commented 4 years ago

@clemisch Indeed the control flow is staged entirely out to XLA and the trip count is known at compile time, so at the XLA HLO level we've completely eliminated dynamic control flow. But there's more to the story on GPUs.
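One way to see this from Python, assuming a JAX version with the `jax.jit(f).lower(...)` API: lower a scan and search the emitted IR for the loop. `cumulative_sum_scan` is a hypothetical toy function, and the exact op spelling in the text (e.g. `stablehlo.while`) depends on the JAX version.

```python
import jax
import jax.numpy as jnp


def cumulative_sum_scan(x):
    def step(carry, x_i):
        carry = carry + x_i
        return carry, carry
    _, ys = jax.lax.scan(step, jnp.zeros((), x.dtype), x)
    return ys


x = jnp.zeros(1000)
# Lower without executing; the text is the program handed to XLA, in which
# the whole scan is a single while op with a compile-time trip count.
hlo_text = jax.jit(cumulative_sum_scan).lower(x).as_text()
has_while_loop = "while" in hlo_text
```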

Warning: I'm not an expert! This is just my best understanding and I hope others will correct me / fill in the gaps where I mess things up.

An XLA:GPU program is itself ultimately lowered to a hybrid CPU/GPU program, where the GPU parts are kernels (a mix of CUDA / cuDNN kernels and XLA-codegenned ones) and the CPU parts just handle launching the kernels and perhaps other runtime calls. ~(It's all compiled, i.e. the CPU part isn't interpreted like TF or PyTorch, which is why the CPU-side overheads can be much lower as in the NumPyro benchmarks, though for some workloads that doesn't make any difference.)~ (EDIT: removed because IMO it's impossible to define "interpreted" vs "compiled" precisely.)

In this case, jax.lax.scan generates a single XLA HLO While loop with a fixed trip count, so how does that get turned into such a hybrid CPU/GPU program? Perhaps the best thing XLA:GPU could do would be to lower it into a single kernel, since that would minimize overheads and maximize optimization opportunities. But XLA:GPU can't (yet) generate a single kernel for whole loops. Instead, the loop has both CPU and GPU parts. The second best thing we could hope XLA:GPU would do is generate a single GPU kernel for the loop body, so that the CPU part of the program would just be a CPU loop with a fixed trip count launching those kernels, and we'd only pay one launch overhead cost per iteration. Unfortunately, XLA:GPU often has to generate multiple kernels for the loop body, meaning we pay several kernel launch overheads per iteration. (Moreover, I believe that sometimes, but not always, XLA:GPU may generate extra copy operations for the loop carry.)
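One knob that targets this per-iteration overhead, not discussed in the thread, is scan's `unroll` argument: it places several copies of the loop body in each While-loop iteration, so the loop runs fewer times. Whether that yields fewer kernel launches, and whether it helps on a given GPU, would need measuring; treat this as an experiment. `running_logsumexp` is a hypothetical name.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames="unroll")
def running_logsumexp(x, unroll=1):
    def step(carry, x_i):
        out = jnp.logaddexp(carry, x_i)
        return out, out
    init = jnp.array(-jnp.inf, dtype=x.dtype)
    # unroll=k: each loop iteration processes k elements of x
    _, ys = jax.lax.scan(step, init, x, unroll=unroll)
    return ys


x = jax.random.normal(jax.random.PRNGKey(0), (1000,))
ys_rolled = running_logsumexp(x)             # 1000 loop iterations
ys_unrolled = running_logsumexp(x, unroll=8)  # 125 loop iterations
```

Unrolling does not change the result, only the shape of the compiled loop; larger values trade compile time and code size for fewer iterations.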

The upshot of all this is that XLA:GPU doesn't (yet) do the best with some loops. There could be some fundamental limits based on the GPU programming model or the tools NVIDIA provides for generating GPU programs, but I suspect we're not at those limits yet and more can be done with more investment in XLA:GPU. So the best policy is to send love and support towards XLA:GPU developers (both on Google compiler teams and in open source, including at NVIDIA) so we can make this thing we love even better!

(One reason I'm optimistic for the future here is from seeing what XLA:TPU can do, since it's the most developed XLA backend. With XLA:TPU, the whole program is staged out to the TPU, including the control flow for scans and other loops, so things like kernel launch overheads just don't exist.)

Does that make sense?

clemisch commented 4 years ago

Wow, thank you for the very detailed answer @mattjj !

After reading your explanation, it makes sense that multiple kernels are built and controlled from the CPU, creating some overhead. For "day-to-day" JAX use on GPUs, I guess I shouldn't worry about that.

joglekara commented 4 years ago

To echo @clemisch , many many thanks for that answer @mattjj . Really helpful to have that kind of insight !

joglekara commented 3 years ago

Perhaps this should go to a different issue but I guess I'll start here.

Since we're talking about launching CUDA/cuDNN kernels, is there any interest in implementing a wrapper for https://docs.nvidia.com/cuda/cusparse/index.html#gtsv and/or https://docs.nvidia.com/cuda/cusparse/index.html#batch_gtsv?

One reason I think this may be out of scope is that it's not part of NumPy, and I realize that we're mostly interested in reproducing NumPy functionality in JAX. Also, it's not exactly CPU-friendly (but I'm almost sure y'all have tricks to handle that case).

On the plus side, it will provide a nice clean solution to the ubiquitous tridiagonal inversion problem!

mattjj commented 3 years ago

@joglekara great question! I'd say it's certainly worth considering, though we probably want an overall plan for sparse matrix support. JAX implements the NumPy API because it's the best Python API for dense array operations, but that doesn't mean we need to be constrained by it. After all, we already have different APIs for things like convolutions, indexed update operators, and random sampling.

There are a lot of questions involved in sparse matrix support, including what API to provide (scipy.sparse or something else?) and how exactly to integrate it into JAX. Maybe a discussion thread would be a good place to brainstorm.

I suspect that in the near term the best approach may be to build these things in libraries on top of JAX, which can include things like cusparse calls, and then we can figure out whether/how to upstream parts into JAX proper.

cc @jakevdp

joglekara commented 3 years ago

Makes sense. Glad I wasn't totally off-base. Thanks for the prompt response, @mattjj

jcrichard commented 3 years ago

I also tried to optimize the log-likelihood of a Kalman filter, so I do need to differentiate through scan. I observe the same kind of performance issue when switching from CPU to GPU: the longer the filter, the worse it is on GPU. I was hoping to optimize large Kalman filters using JAX. Is this expected? I also use jax.lax.cond and jax.ops.index_add.
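For reference, a minimal sketch of a differentiable Kalman-filter log-likelihood written as a scan, assuming the simplest scalar local-level (random walk plus noise) model. The model, the name `local_level_loglik` and its parameters are illustrative, not the code from this comment. Each step does only a handful of scalar operations, which is exactly the regime where the per-iteration GPU overhead described above dominates; batching many series (e.g. with `jax.vmap`) gives each iteration more work.

```python
import jax
import jax.numpy as jnp


def local_level_loglik(log_q, log_r, ys, m0=0.0, p0=1.0):
    """Log-likelihood of x_t = x_{t-1} + w_t, y_t = x_t + v_t with
    w_t ~ N(0, q), v_t ~ N(0, r), and prior x_0 ~ N(m0, p0)."""
    q, r = jnp.exp(log_q), jnp.exp(log_r)

    def step(carry, y):
        m, p, ll = carry
        m_pred, p_pred = m, p + q        # predict
        s = p_pred + r                   # innovation variance
        v = y - m_pred                   # innovation
        k = p_pred / s                   # Kalman gain
        m_new = m_pred + k * v           # update
        p_new = (1.0 - k) * p_pred
        ll = ll - 0.5 * (jnp.log(2.0 * jnp.pi * s) + v**2 / s)
        return (m_new, p_new, ll), None

    init = (jnp.asarray(m0, ys.dtype), jnp.asarray(p0, ys.dtype), jnp.zeros((), ys.dtype))
    (_, _, ll), _ = jax.lax.scan(step, init, ys)
    return ll


# Value and reverse-mode gradient w.r.t. both noise parameters, through the scan.
loglik_and_grad = jax.jit(jax.value_and_grad(local_level_loglik, argnums=(0, 1)))
```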