jakevdp opened 1 year ago
We might implement unique
via spectral analysis; here's a rough example:
```python
import jax.numpy as jnp

def unique(x):
    x = jnp.asarray(x)
    d = x.shape[0]
    # Construct adjacency matrix: entries i, j are connected iff x[i] == x[j].
    A = (x[:, None] == x[None, :]).reshape(d, d, -1).all(-1)
    # Construct (negated) graph Laplacian & compute its eigen-decomposition.
    L = A - jnp.diag(A.sum(axis=1))
    evals, evecs = jnp.linalg.eigh(L)
    # Size of the Laplacian null space = number of unique values.
    num_unique = (abs(evals) < 0.5).sum()
    # Nonzero entries in null-space eigenvectors => indices of unique values.
    i = jnp.argmax(abs(evecs), axis=0)
    i = i[-num_unique:]  # Need to change this for JIT-compatibility
    # Note: return unsorted unique values, because lexsort is what's causing issues!
    return x[i]
```
This works for 1D inputs:
```python
x = jnp.array([1, 2, 3, 2, 3, 2, 1])
print(unique(x))
# [3 2 1]
```
As well as N-dimensional inputs:
```python
x = jnp.zeros((100, 10)).at[0, 0].set(1)
print(unique(x))  # axis=0 is implied
# [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
#  [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
```
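Regarding the JIT-compatibility note in the snippet above: one possible tweak (just a sketch, not a worked-out design; `unique_fixed_size`, `size`, and `fill_value` are made-up names chosen in the spirit of `jnp.unique`'s own static `size`/`fill_value` arguments) is to accept a static upper bound on the number of unique values and mask the surplus slots, rather than slicing by a traced quantity:

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnames="size")
def unique_fixed_size(x, *, size, fill_value=0):
    # Same spectral construction as above, but with a static output size.
    # Assumes num_unique <= size <= len(x).
    x = jnp.asarray(x)
    d = x.shape[0]
    A = (x[:, None] == x[None, :]).reshape(d, d, -1).all(-1)
    L = A - jnp.diag(A.sum(axis=1))
    evals, evecs = jnp.linalg.eigh(L)
    num_unique = (abs(evals) < 0.5).sum()
    i = jnp.argmax(abs(evecs), axis=0)
    i = i[-size:]  # static slice: always `size` candidate indices
    # Only the trailing num_unique slots hold real unique values; pad the rest.
    valid = jnp.arange(size) >= size - num_unique
    valid = valid.reshape((size,) + (1,) * (x.ndim - 1))
    return jnp.where(valid, x[i], fill_value)

x = jnp.array([1, 2, 3, 2, 3, 2, 1])
print(unique_fixed_size(x, size=4, fill_value=-1))
# e.g. [-1  3  2  1]
```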
@jakevdp Any progress? Would migrating the unique operator from TensorFlow as an independent primitive be a good option, like the TopK operator?
There hasn't been any progress on this unfortunately. As mentioned above, the unique
implementation is based on XLA's variadic sort, and as the number of variables sorted becomes large, XLA slows down just as it does for any computation with many inputs & operations.
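For reference, the lexsort-based approach described in this thread can be sketched roughly like this (an illustration only, not the actual library code; `unique_rows_sketch` is a made-up name):

```python
import jax
import jax.numpy as jnp

def unique_rows_sketch(x):
    # Lexicographic sort of the rows: each of the k columns becomes a
    # separate key array handed to XLA's variadic sort (column 0 is the
    # primary key).
    cols = tuple(x[:, j] for j in range(x.shape[1]))
    sorted_cols = jax.lax.sort(cols, num_keys=len(cols))
    y = jnp.stack(sorted_cols, axis=1)
    # Pairwise equality check: keep a row if it differs from the previous
    # sorted row in any column.
    is_new = jnp.concatenate([jnp.array([True]),
                              jnp.any(y[1:] != y[:-1], axis=1)])
    return y[is_new]  # data-dependent output size, so not jit-able as written

x = jnp.zeros((100, 10)).at[0, 0].set(1)
print(unique_rows_sketch(x))  # two unique rows, as in the example above
```

The variadic sort is the part that scales poorly: a `(d, k)` input contributes `k` separate operands to a single sort call.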
@jakevdp So would it be better to abandon the XLA variadic sort and change it to an XLA custom-call kernel?
Without more information on what exactly you have in mind, it’s hard to evaluate whether it would be better.
Can you say more about what you have in mind with this custom call solution?
What about directly porting the code from the TensorFlow op kernel? The most common implementation of the Unique op is radix sort, which is hard to express in XLA.
The algorithm implemented in the TensorFlow CUDA kernel (described here) is essentially identical to how it's implemented in XLA, so I'm not sure whether it would offer much room for improvement. The only difference as far as I can tell is the use of radix sort vs. whatever sort XLA is using under the hood.
Note that as far as I can tell the slowdown here doesn't have much to do with the sorting algorithm, but rather has to do with the fact that lex-sort across k keys involves splitting the input array into k input arrays, and as k gets very large, tracing and compilation becomes slow. Changing to radix sort doesn't circumvent this, because it scales linearly with k.
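To make that concrete, here is one rough way to observe the effect (a sketch; the numbers are machine-dependent, and the 1000-column case may take a while on the first call):

```python
import time
import jax.numpy as jnp

# The first call at each new shape pays the trace/compile cost of the
# underlying variadic sort, which receives one key array per column.
for k in (10, 100, 1000):
    x = jnp.zeros((100, k))
    start = time.perf_counter()
    jnp.unique(x, axis=0).block_until_ready()
    print(f"k={k}: first jnp.unique(..., axis=0) call took "
          f"{time.perf_counter() - start:.2f}s")
```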
https://github.com/c3sr/tcu_scope: what about an algorithm that runs the reduce and scan kernels used in radix sort on NVIDIA Tensor Cores?
> ...and in this case the lex-sort splits the array columns into 1000 key arrays passed to XLA, and this has performance implications.
Hi @jakevdp,
Does the statement above have implications for GPU memory usage during the same operation? Specifically, I'm wondering if the memory required for the implementation mentioned is 1000 times greater compared to the other two jnp.unique implementations.
I understand that the memory needed to store arrays of sizes (2, 1000) and (2, 1) is significantly different, but I'm more interested in how this impacts internal operations and GPU memory management.
I'm not sure about the implications to GPU memory use – you'd probably have to profile it if you're interested in performance at that level.
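If you do want to look into it, a quick starting point (a sketch; which statistics are available depends on the backend, and `memory.prof` is just an example filename) is something like:

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((100, 1000)).at[0, 0].set(1)
_ = jnp.unique(x, axis=0).block_until_ready()

# Per-device allocator statistics (available on accelerator backends; the
# exact keys depend on the backend).
stats = jax.devices()[0].memory_stats()
if stats is not None:
    print(stats.get("peak_bytes_in_use"))

# Full allocation breakdown in pprof format, viewable with `pprof`.
jax.profiler.save_device_memory_profile("memory.prof")
```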
For example:

The issue is that XLA has no `unique` primitive, so we're forced to implement it from scratch. `unique` along an axis is essentially implemented as a lexicographic sort followed by a pairwise equality check, and in this case the lex-sort splits the array columns into 1000 key arrays passed to XLA, and this has performance implications. I brainstormed a bit with @mattjj about how we might better compute this in this case, but the answer isn't obvious.
I'm going to leave this open for now.