jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Surprisingly slow jax.lax.top_k #9940

Open felixchalumeau opened 2 years ago

felixchalumeau commented 2 years ago

Hi,

I was writing some functions with jax.lax.top_k and saw that it was particularly slow. I would expect top_k to be quite fast, since it only needs to go through the data once. Is that right?
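The single-pass intuition can be sketched on CPU with a bounded min-heap. This is only an illustration in plain Python (the helper name top_k_single_pass is made up here). It says nothing about how XLA implements top_k on GPU, where a sequential scan like this does not map directly onto the hardware.

```python
import heapq


def top_k_single_pass(values, k):
    """Return the k largest values in descending order, reading the input once."""
    if k <= 0:
        return []
    heap = []  # min-heap holding the k largest values seen so far
    for v in values:
        if len(heap) < k:
            heapq.heappush(heap, v)
        elif v > heap[0]:
            # v beats the smallest of the current top k, so it replaces it
            heapq.heapreplace(heap, v)
    return sorted(heap, reverse=True)


print(top_k_single_pass([3, 1, 4, 1, 5, 9, 2, 6], k=3))  # [9, 6, 5]
```

Each element costs at most O(log k) work, so for small k the total is close to a single linear pass.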

To get a better idea of its performance compared to other operations, I ran a quick benchmark (code snippet at the end of this comment).

I compared jax.lax.top_k (and jax.lax.approx_max_k) to jnp.argmax and jnp.argsort. I used k=10 and data of shape (256, 80000) (this arbitrary choice reflects the sizes I use in my applications). I am running these on a T4 GPU.

What I observed is that top_k and approx_max_k are almost as slow as argsort, and much slower than argmax.

Given such a gap with jnp.argmax, I decided to implement a naïve top_k on top of argmax, which is faster than top_k (and approx_max_k) for small values of k (with my data size). Indeed, in my use case, this simple implementation (provided in the code snippet below as naive_top_k) is about 7 times faster than top_k (0.04 s vs 0.0056 s).

Is this a known behavior? Is there a reason for this?

Would be very interested in some insight about this (and a potential fix in the future :slightly_smiling_face:)!

import time

import jax
import jax.numpy as jnp
import numpy as np

if __name__ == "__main__":

    # create a random key
    key = jax.random.PRNGKey(seed=0)

    # create some random data
    data = jax.random.uniform(key, shape=(256, 80000)).block_until_ready()

    # create jitted functions with fixed k
    k = 10

    @jax.jit
    def jitted_top_k(data):
        values, _indices = jax.lax.top_k(data, k=k)
        return values

    @jax.jit
    def jitted_approx_max_k(data):
        values, _indices = jax.lax.approx_max_k(data, k=k)
        return values

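    # Note: with no axis argument, jnp.argmax reduces over the flattened
    # array, while jnp.argsort sorts along the last axis (each row).
    jitted_argmax = jax.jit(jnp.argmax)
    jitted_argsort = jax.jit(jnp.argsort)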

    # Let's benchmark those functions

    N = 20
    M = 5  # avoid taking the first times

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_top_k(data).block_until_ready()
        elapsed_time = time.time() - start_time

        if i >= M:
            times.append(elapsed_time)

    print("Time for jax.lax.top_k : ", np.mean(times))

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_approx_max_k(data).block_until_ready()
        elapsed_time = time.time() - start_time

        if i >= M:
            times.append(elapsed_time)

    print("Time for jax.lax.approx_max_k : ", np.mean(times))

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_argsort(data).block_until_ready()
        elapsed_time = time.time() - start_time

        if i >= M:
            times.append(elapsed_time)

    print("Time for jnp.argsort : ", np.mean(times))

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_argmax(data).block_until_ready()
        elapsed_time = time.time() - start_time

        if i >= M:
            times.append(elapsed_time)

    print("Time for jnp.argmax : ", np.mean(times))

    # Surprisingly, top_k and approx_max_k are almost
    # as slow as argsort

    # Surprisingly, top_k and approx_max_k are much
    # slower than argmax

    # Let's build a top_k with argmax

    def naive_top_k(data, k):
        """Top k implementation built with argmax.
        Faster for smaller k."""

        def top_1(data):
            indice = jnp.argmax(data, axis=1)
            value = jax.vmap(lambda x, y: x[y])(data, indice)
            data = jax.vmap(lambda x, y: x.at[y].set(-jnp.inf))(data, indice)
            return data, value, indice

        def scannable_top_1(carry, unused):
            data = carry
            data, value, indice = top_1(data)
            return data, (value, indice)

        data, (values, indices) = jax.lax.scan(scannable_top_1, data, (), k)

        return values.T, indices.T

    @jax.jit
    def jitted_naive_top_k(data):
        values, _indices = naive_top_k(data, k=k)
        return values

    # benchmark our new top k
    times = []
    for i in range(N):
        start_time = time.time()
        jitted_naive_top_k(data).block_until_ready()
        elapsed_time = time.time() - start_time

        if i >= M:
            times.append(elapsed_time)

    print("Time for naive top k : ", np.mean(times))

My output with a GPU T4:

Time for jax.lax.top_k :  0.04046181042989095
Time for jax.lax.approx_max_k :  0.02177149454752604
Time for jnp.argsort :  0.0582158088684082
Time for jnp.argmax :  0.0005081971486409505
Time for naive top k :  0.005642461776733399
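The five timing loops above share the same structure, so they could be folded into one helper. The following is a minimal sketch, not part of the original benchmark: the name benchmark and its parameters n_runs and n_warmup are made up here, and the input shape is deliberately small so that it runs quickly on any backend.

```python
import time

import jax


def benchmark(fn, *args, n_runs=20, n_warmup=5):
    """Mean wall-clock seconds per call, ignoring the first n_warmup calls."""
    times = []
    for i in range(n_runs):
        start = time.perf_counter()
        # JAX dispatches asynchronously, so wait for the result before stopping the clock
        jax.block_until_ready(fn(*args))
        elapsed = time.perf_counter() - start
        if i >= n_warmup:  # the first calls include compilation time
            times.append(elapsed)
    return sum(times) / len(times)


# Small illustrative input; the benchmark above uses shape (256, 80000).
data = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 1024))
top_k_values = jax.jit(lambda x: jax.lax.top_k(x, 10)[0])

mean_seconds = benchmark(top_k_values, data)
print("Time for jax.lax.top_k : ", mean_seconds)
```

time.perf_counter is used instead of time.time because it is a monotonic, high-resolution clock intended for measuring short intervals.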

Thanks!

jakevdp commented 2 years ago

Thanks for the report – I believe this issue is currently being looked at (see e.g. 1d5833d2f15fe81d8866f4b5481e364262a6cb04). I'm going to assign to @LenaMartens because I think this is related to that work.

felixchalumeau commented 2 years ago

Thanks for the quick reply! Great!

felixchalumeau commented 2 years ago

Hey! Any follow-up on this issue? Or any insight about what the problem is (or was)? @LenaMartens Thanks!

LenaMartens commented 2 years ago

Hi, sorry for the slow response! I lost track of this.

I've been trying to improve the performance of top_k specifically for inputs with more than 2 dimensions, which, as I understand from your benchmark, does not apply here. (I'm still trying to land that >2-dimension change. It got rolled back: https://github.com/google/jax/commit/c3a4a6e63da11246611247feac7ff4c00750ae21 due to an unrelated bug that it is blocked on.)

Unfortunately I'm not much of a top_k performance expert beyond that change. From asking around, I understand that top_k on GPU could be better and that there are some JAX/XLA improvements to be made there (@dryman can correct me if I'm wrong). I'm not sure what exactly we can do to improve it, but I can have a closer look at your benchmark now that I'm trying to land this other >2D top_k change.

Maybe we can keep this issue open to track top_k performance on GPU?

dryman commented 2 years ago

We implemented jax.lax.approx_max_k for TPU to address the slowness of top-k. (Up to 148x faster in some scenarios.)
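Since approx_max_k trades exactness for speed, it can help to check how many of the exact top-k entries it actually returns. The sketch below is an illustration on a small random input, assuming the default recall_target of 0.95 is passed explicitly. The measured recall depends on the backend, and no particular value is claimed here.

```python
import jax

data = jax.random.uniform(jax.random.PRNGKey(0), shape=(4, 2048))
k = 10

exact_values, exact_idx = jax.lax.top_k(data, k)
approx_values, approx_idx = jax.lax.approx_max_k(data, k, recall_target=0.95)

# For every exact top-k index, check whether it appears among the
# approximate indices of the same row: shape (rows, k, k) -> (rows, k).
hits = (approx_idx[:, :, None] == exact_idx[:, None, :]).any(axis=1)
recall = float(hits.mean())
print("recall of approx_max_k vs top_k:", recall)
```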

Maybe we can prioritize implementing similar algorithms for GPU in the next quarter.

felixchalumeau commented 2 years ago

Hey!

Thanks for the follow-ups! Sorry for not reacting more quickly to your messages.

Do you know if any improvements to the top-k/max-k functions have been made recently? Or if any have been prioritised for the future? Thanks! :slightly_smiling_face: @LenaMartens @dryman

felixchalumeau commented 2 years ago

Hi there :wave:

Any news @LenaMartens @dryman?

timellemeet commented 9 months ago

Has any progress been made since? Is it now advised to use the approximated version when using a GPU?

james77777778 commented 8 months ago

I can confirm that jax.lax.top_k and jax.lax.approx_max_k are both slower on GPU (RTX 4070) compared to tf.math.top_k (both jitted)

Is there an alternative solution available, or do we need to get a custom implementation similar to what @felixchalumeau did?
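One possible custom implementation, besides the argmax-based one earlier in the thread, is a two-stage top-k: take the top k inside each chunk of a row, then take the top k among those candidates. The sketch below is an assumption-laden illustration (the name chunked_top_k and the num_chunks parameter are made up here, and the row length is assumed to be divisible by num_chunks). It returns exact results, but whether it is any faster depends on hardware, shapes, and JAX/XLA version, so it should be measured before use.

```python
import jax
import jax.numpy as jnp


def chunked_top_k(data, k, num_chunks):
    """Exact top-k along the last axis of a 2D array, computed in two stages."""
    rows, cols = data.shape
    chunk = cols // num_chunks
    # Assumption of this sketch: chunks are equal-sized and hold at least k items.
    assert chunk * num_chunks == cols and k <= chunk

    # Stage 1: top-k inside every chunk -> (rows, num_chunks, k).
    chunked = data.reshape(rows, num_chunks, chunk)
    cand_values, cand_idx = jax.lax.top_k(chunked, k)

    # Convert per-chunk positions into positions within the original row.
    offsets = (jnp.arange(num_chunks) * chunk)[None, :, None]
    cand_idx = (cand_idx + offsets).reshape(rows, num_chunks * k)
    cand_values = cand_values.reshape(rows, num_chunks * k)

    # Stage 2: top-k among the num_chunks * k candidates of each row.
    values, pos = jax.lax.top_k(cand_values, k)
    indices = jnp.take_along_axis(cand_idx, pos, axis=1)
    return values, indices


data = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 1024))
values, indices = jax.jit(chunked_top_k, static_argnums=(1, 2))(data, 10, 16)
ref_values, _ = jax.lax.top_k(data, 10)
print(bool(jnp.array_equal(values, ref_values)))
```

The result is exact because every element of a row's true top k is necessarily in the top k of its own chunk.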

jon-chuang commented 6 months ago

Hello @dryman, do you know what it takes to lower approx_max_k (aka https://arxiv.org/pdf/2206.14286) through Pallas to TPU? Can it be implemented via rudimentary jax.lax ops?

Notably, jax.lax.argmax is also not present on Pallas. Since tpu_custom_call takes a private (i.e. Google only) code path in XLA there is no way of knowing if it takes a custom compilation pathway separate from XLA ops.

As a reminder, here is the algorithm used: [image]

Further, how can I implement bitonic sort on TPU?
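For reference, the compare-exchange network behind bitonic sort can be sketched in plain Python. This is only an illustration of the algorithm for power-of-two lengths. It is not TPU- or Pallas-specific and does not answer how to lower it, and the function name bitonic_sort is made up here.

```python
def bitonic_sort(values):
    """Sort a sequence whose length is a power of two, in ascending order."""
    a = list(values)
    n = len(a)
    if n & (n - 1):
        raise ValueError("length must be a power of two")
    k = 2  # size of the bitonic sequences being merged
    while k <= n:
        j = k // 2  # distance between compared elements
        while j > 0:
            # All compare-exchanges for a given (k, j) are independent,
            # which is what makes the network attractive for parallel hardware.
            for i in range(n):
                partner = i ^ j
                if partner > i:
                    ascending = (i & k) == 0
                    if (a[i] > a[partner]) == ascending:
                        a[i], a[partner] = a[partner], a[i]
            j //= 2
        k *= 2
    return a


print(bitonic_sort([5, 2, 7, 1, 8, 3, 6, 4]))  # [1, 2, 3, 4, 5, 6, 7, 8]
```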

cheshire commented 6 months ago

For the GPU side, could you file a bug on https://github.com/openxla/xla? We use a custom CUDA kernel for topk, so I would have expected it to be quite fast (plus a heuristic to switch between sort and topk).