felixchalumeau opened this issue 2 years ago
Thanks for the report – I believe this issue is currently being looked at (see e.g. 1d5833d2f15fe81d8866f4b5481e364262a6cb04). I'm going to assign this to @LenaMartens because I think it is related to that work.
Thanks for the quick reply! Great!
Hey! Any follow-up on this issue? Or any insight about what the problem is (or was)? @LenaMartens Thanks!
Hi, sorry for the slow response! I lost track of this.
I've been trying to improve the performance of top_k, specifically for inputs that have more than 2 dimensions, which, as I understand from your benchmark, does not apply here. (I'm still trying to land that >2D change; it was rolled back in https://github.com/google/jax/commit/c3a4a6e63da11246611247feac7ff4c00750ae21 due to an unrelated bug it's blocked on.)
Unfortunately I'm not much of a top_k performance expert beyond that change. From asking around, I understand that top_k on GPU could be better and that there are some JAX/XLA improvements to be made there (@dryman can correct me if I'm wrong). I'm not sure exactly what we can do to improve it, but I can take a closer look at your benchmark now that I'm trying to land this other >2D top_k change.
Maybe we can keep this issue open to track top_k performance on GPU?
We implemented jax.lax.approx_max_k for TPU to address the slowness of top-k (up to 148x faster in some scenarios). Maybe we can prioritize implementing similar algorithms for GPU in the next quarter.
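For reference, a minimal usage sketch of the two APIs (the input shape, k, and recall_target values here are illustrative assumptions, not recommendations):

```python
import jax

x = jax.random.normal(jax.random.PRNGKey(0), (256, 80000))

# Exact top-k along the last axis: returns (values, indices).
vals, idx = jax.lax.top_k(x, k=10)

# Approximate top-k: recall_target trades accuracy for speed
# (the speedups quoted above were measured on TPU).
avals, aidx = jax.lax.approx_max_k(x, k=10, recall_target=0.95)
```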
Hey!
Thanks for the follow-ups! Sorry for not reacting sooner to your messages.
Do you know if any improvements to the top-k/max-k functions have been made recently? Or if any are planned for the future? Thanks! :slightly_smiling_face: @LenaMartens @dryman
Hi there :wave:
Any news @LenaMartens @dryman?
Has any progress been made since? Is it now advised to use the approximate version when running on GPU?
I can confirm that jax.lax.top_k and jax.lax.approx_max_k are both slower on GPU (RTX 4070) than tf.math.top_k (both jitted).
Is there an alternative solution available, or do we need a custom implementation similar to what @felixchalumeau did?
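As a rough way to reproduce this kind of comparison on the JAX side, something like the sketch below could be used (shapes, k, and repetition counts are illustrative assumptions, and it only covers the JAX functions, not tf.math.top_k):

```python
import time
import jax

k = 10
x = jax.random.normal(jax.random.PRNGKey(0), (256, 80000))

exact_top_k = jax.jit(lambda a: jax.lax.top_k(a, k))
approx_top_k = jax.jit(lambda a: jax.lax.approx_max_k(a, k))

def bench(fn, a, n=100):
    fn(a)[0].block_until_ready()       # compile + warm up
    start = time.perf_counter()
    for _ in range(n):
        out = fn(a)
    out[0].block_until_ready()         # wait for async dispatch to finish
    return (time.perf_counter() - start) / n

print("jax.lax.top_k       :", bench(exact_top_k, x))
print("jax.lax.approx_max_k:", bench(approx_top_k, x))
```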
Hello @dryman, do you know what it takes to lower approx_max_k (aka https://arxiv.org/pdf/2206.14286) through Pallas to TPU? Can it be implemented via rudimentary jax.lax ops?
Notably, jax.lax.argmax is also not available in Pallas. Since tpu_custom_call takes a private (i.e. Google-only) code path in XLA, there is no way of knowing whether it uses a custom compilation pathway separate from the XLA ops.
As a reminder, the algorithm in question is the one described in the paper linked above.
Further, how can I implement bitonic sort on TPU?
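Not an authoritative answer, but here is a sketch of what a bitonic sorting network looks like when expressed with only elementwise, gather, and where ops in plain jax.numpy (power-of-two length assumed; whether these ops lower well inside a Pallas TPU kernel is exactly the open question here):

```python
import jax.numpy as jnp

def bitonic_sort(x):
    """Bitonic sorting network for a 1D array whose length is a power of two.

    Uses only arange, gather, min/max and where -- no sort/top_k primitives.
    Sketch only; not tested inside a Pallas kernel.
    """
    n = x.shape[-1]
    idx = jnp.arange(n)
    k = 2
    while k <= n:                      # stage size
        j = k // 2
        while j >= 1:                  # compare-exchange distance
            partner = idx ^ j
            px = x[partner]
            ascending = (idx & k) == 0
            # The lower index of an ascending pair (and the upper index of a
            # descending pair) keeps the minimum; the other keeps the maximum.
            keep_min = ascending == (idx < partner)
            x = jnp.where(keep_min, jnp.minimum(x, px), jnp.maximum(x, px))
            j //= 2
        k *= 2
    return x
```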
On the GPU side, could you file a bug on https://github.com/openxla/xla? We use a custom CUDA kernel for top-k, so I would have expected it to be quite fast (plus a heuristic to switch between sort and top-k).
Hi,
I was writing some functions with jax.lax.top_k and noticed that it was particularly slow. I would expect top_k to be quite fast, since it only needs to go through the data once; is that right? To get a better idea of its performance compared to other operations, I ran a quick benchmark (code snippet at the end of the issue).
I compared jax.lax.top_k (and jax.lax.approx_max_k) to jnp.argmax and jnp.argsort. I used k=10 and data of shape (256, 80000) (this somewhat arbitrary choice comes from the kind of sizes I use in my applications), running on a T4 GPU. What I observed is that:
- jax.lax.top_k is almost as fast as jnp.argsort for these parameters (0.04 vs 0.058), whereas I expected it to be much faster;
- jax.lax.top_k is much slower than jnp.argmax (0.04 vs 0.0005), whereas I expected them to be about as fast.
Observing such a difference with jnp.argmax, I decided to implement a naïve top_k based on argmax, which turns out to be faster than top_k (and approx_max_k) for small values of k (with my data size). Indeed, in my use case, this simple implementation (provided in the code snippet below, named naive_top_k) is about 7 times faster than top_k (0.04 vs 0.0056).
Is this a known behavior? Is there a reason for it?
I would be very interested in some insight into this (and a potential fix in the future :slightly_smiling_face:)!
My output on a T4 GPU:
Thanks!
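A minimal sketch of the argmax-based approach described above (not the original naive_top_k snippet, which is not reproduced here; it assumes floating-point input of shape (batch, n)):

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames="k")
def naive_top_k(data, k):
    """Top-k by repeatedly taking the row-wise argmax and masking it out.

    data: float array of shape (batch, n). Returns (values, indices),
    each of shape (batch, k). Runs k sequential argmax passes, so it is
    only attractive for small k.
    """
    def body(carry, _):
        remaining = carry
        idx = jnp.argmax(remaining, axis=-1)                              # (batch,)
        vals = jnp.take_along_axis(remaining, idx[:, None], axis=-1)[:, 0]
        # Mask the selected entries so the next step finds the next-largest value.
        remaining = remaining.at[jnp.arange(remaining.shape[0]), idx].set(-jnp.inf)
        return remaining, (vals, idx)

    _, (values, indices) = jax.lax.scan(body, data, None, length=k)
    # scan stacks outputs along the leading axis -> (k, batch); transpose to (batch, k).
    return values.T, indices.T

# Example use: values, indices = naive_top_k(x, 10)
```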