jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`jnp.unique` along an axis is very slow for large arrays #17370

Open jakevdp opened 1 year ago

jakevdp commented 1 year ago

For example:

import jax.numpy as jnp
x = jnp.zeros((2, 1000))
jnp.unique(x, axis=0)  # takes ~2.5 minutes on a Colab CPU

The issue is that XLA has no unique primitive, so we're forced to implement it from scratch. unique along an axis is essentially implemented as a lexicographic sort followed by a pairwise equality check, and in this case the lex-sort splits the array columns into 1000 key arrays passed to XLA, and this has performance implications.
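To make the "lexicographic sort followed by a pairwise equality check" description concrete, here is a minimal sketch for 2D input. It is not the actual `jnp.unique` implementation; the function name `unique_rows_sketch` is hypothetical, and the boolean-mask indexing at the end means it only works outside of `jit`.

```python
import jax.numpy as jnp

def unique_rows_sketch(x):
    """Unique rows of a 2D array: lex-sort, then compare adjacent rows."""
    x = jnp.asarray(x)
    # jnp.lexsort treats the LAST key as the primary one, so reverse the
    # columns to make column 0 the most significant key. Note that this
    # creates one key array per column.
    keys = tuple(x[:, j] for j in reversed(range(x.shape[1])))
    order = jnp.lexsort(keys)
    xs = x[order]
    # Keep a row if it differs from the previous sorted row; the first
    # sorted row is always kept.
    differs = jnp.any(xs[1:] != xs[:-1], axis=1)
    keep = jnp.concatenate([jnp.array([True]), differs])
    # Boolean-mask indexing has a data-dependent shape (not jit-compatible).
    return xs[keep]

x = jnp.array([[1, 2], [0, 5], [1, 2], [0, 3], [0, 5]])
result = unique_rows_sketch(x)
print(result)  # rows: [0 3], [0 5], [1 2]
```

With a `(2, 1000)` input like the one above, the `keys` tuple in this sketch would hold 1000 arrays of length 2, which is the shape of the problem being described.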

I brainstormed a bit with @mattjj about how we might better compute this in this case, but the answer isn't obvious.

I'm going to leave this open for now.

jakevdp commented 1 year ago

We might implement unique via spectral analysis; here's a rough example:

def unique(x):
  x = jnp.asarray(x)
  d = x.shape[0]

  # Construct adjacency matrix
  A = (x[:, None] == x[None, :]).reshape(d, d, -1).all(-1)

  # Construct Laplacian matrix & compute eigen-decomposition
  L = A - jnp.diag(A.sum(axis=1))
  evals, evecs = jnp.linalg.eigh(L)

  # Size of laplacian null space = number of unique values
  num_unique = (abs(evals) < 0.5).sum()

  # Nonzero entries in null-space eigenvectors => indices of unique values
  i = jnp.argmax(abs(evecs), axis=0)
  i = i[-num_unique:]  # Need to change this for JIT-compatibility

  # Note: return unsorted unique values because lexsort is what's causing issues!
  return x[i]

This works for 1D inputs:

x = jnp.array([1, 2, 3, 2, 3, 2, 1])
print(unique(x))
# [3 2 1]

As well as N-dimensional inputs:

x = jnp.zeros((100, 10)).at[0, 0].set(1)
print(unique(x))  # axis=0 is implied
# [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
#  [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
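
A short derivation of why the null-space count and the 0.5 threshold in the function above work, under the assumption that the d rows fall into k groups of identical rows with sizes m_1, ..., m_k (these symbols are introduced here for illustration and do not appear in the code):

```latex
% After permuting rows so that equal rows are adjacent, A is block diagonal
% with all-ones blocks J_m (size m x m):
A = \operatorname{blockdiag}(J_{m_1}, \dots, J_{m_k}), \qquad
L = A - \operatorname{diag}(A\mathbf{1})
  = \operatorname{blockdiag}(J_{m_1} - m_1 I, \dots, J_{m_k} - m_k I).

% J_m has eigenvalue m (once, eigenvector \mathbf{1}) and 0 (m - 1 times), so
\operatorname{spec}(J_m - m I) = \{\, 0 \ (\text{once}),\ -m \ (m - 1 \text{ times}) \,\}.

% Hence each group contributes exactly one zero eigenvalue:
\dim \ker L = k, \qquad |\lambda| \ge 2 \ \text{for every nonzero eigenvalue } \lambda \text{ of } L.
```

Since every nonzero eigenvalue is at most -2 and `eigh` returns eigenvalues in ascending order, the zero eigenvalues come last, which is why the last `num_unique` columns are selected. One caveat: any orthonormal basis of the null space is a valid eigensolver output, so the `argmax` step relies on the returned null-space eigenvectors each being concentrated on a single group, which is not guaranteed in general.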

MoFHeka commented 5 months ago

@jakevdp Any progress? Would migrating the unique operator from TensorFlow as an independent primitive be a good option, like the TopK operator?

jakevdp commented 5 months ago

There hasn't been any progress on this unfortunately. As mentioned above, the unique implementation is based on XLA's variadic sort, and as the number of variables sorted becomes large, XLA slows down just as it does for any computation with many inputs & operations.

MoFHeka commented 5 months ago

@jakevdp So would it be better to abandon the XLA variadic sort and change it to an XLA custom call kernel?

jakevdp commented 5 months ago

Without more information on what exactly you have in mind, it’s hard to evaluate whether it would be better.

Can you say more about what you have in mind with this custom call solution?

MoFHeka commented 5 months ago

What about porting the code directly from the TensorFlow op kernel? The most common implementation of the Unique op is based on radix sort, which is hard to express in XLA.

https://github.com/tensorflow/tensorflow/blob/abcaba2e5fce335212d1f7617bdb4e3def525c23/tensorflow/core/kernels/unique_op_gpu.cu.h

jakevdp commented 5 months ago

The algorithm implemented in the tensorflow cuda kernel (described here) is essentially identical to how it's implemented in XLA, so I'm not sure whether it would offer much room for improvement. The only difference as far as I can tell is the use of radix sort vs. whatever sort XLA is using under the hood.

Note that as far as I can tell the slowdown here doesn't have much to do with the sorting algorithm, but rather has to do with the fact that lex-sort across k keys involves splitting the input array into k input arrays, and as k gets very large, tracing and compilation become slow. Changing to radix sort doesn't circumvent this, because it scales linearly with k.
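
As a rough illustration of what "splitting the input array into k input arrays" looks like, the sketch below sorts rows lexicographically by passing one operand per column to `jax.lax.sort` with `num_keys=k`. The helper name `lexsort_rows_with_k_keys` is hypothetical, and this is a simplified stand-in for the approach described above, not the code JAX actually runs.

```python
import jax
import jax.numpy as jnp

def lexsort_rows_with_k_keys(x):
    """Sort the rows of a 2D array lexicographically using k sort keys."""
    x = jnp.asarray(x)
    k = x.shape[1]
    # One 1D operand per column: the sort receives k separate key arrays.
    columns = tuple(x[:, j] for j in range(k))
    # The first num_keys operands are compared in order, so column 0 is
    # the primary key, column 1 breaks ties, and so on.
    sorted_columns = jax.lax.sort(columns, num_keys=k)
    return jnp.stack(sorted_columns, axis=1)

x = jnp.array([[1, 2, 0], [0, 5, 1], [1, 1, 9], [0, 5, 0]])
sorted_rows = lexsort_rows_with_k_keys(x)
print(sorted_rows)  # rows: [0 5 0], [0 5 1], [1 1 9], [1 2 0]
```

Here k is 3; for the `(2, 1000)` example at the top of the thread the same construction would produce a sort with 1000 key operands, regardless of which sorting algorithm sits underneath.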

MoFHeka commented 4 months ago

https://github.com/c3sr/tcu_scope What about an algorithm that runs the reduce and scan kernels of radix sort on NVIDIA Tensor Cores?

unik-w commented 4 weeks ago

and in this case the lex-sort splits the array columns into 1000 key arrays passed to XLA, and this has performance implications.

Hi @jakevdp,

Does the statement above have implications for GPU memory usage during the same operation? Specifically, I'm wondering if the memory required for the implementation mentioned is 1000 times greater compared to the other two jnp.unique implementations.

I understand that the memory needed to store arrays of sizes (2, 1000) and (2, 1) is significantly different, but I'm more interested in how this impacts internal operations and GPU memory management.

jakevdp commented 4 weeks ago

I'm not sure about the implications to GPU memory use – you'd probably have to profile it if you're interested in performance at that level.