josevalim closed this issue 1 year ago.
In case the CUDA version is not feasible (for example, one of its optimizations uses sparse matrices, which we don't currently have), we can consider implementing FIt-SNE. It may even be a better starting point.
@josevalim I've been trying to implement this, but I am afraid it might be very hard. The main reason is computing the pairwise distances needed for the conditional and joint probability distributions. To avoid $O(n^2)$ complexity (as currently done), all of the popular t-SNE methods use (approximate) k-nearest neighbour (k-NN) search to compute the $n \times k$ distances between every data point and its k nearest neighbours. However, the (exact) k-NN method implemented in `Scholar` also computes all $n^2$ pairwise distances (before using `topk`). I don't see a way to implement this in $O(n \log n)$ at the moment.
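To make the quadratic cost concrete, here is a minimal plain-Elixir sketch of exact k-NN (lists instead of Nx tensors; `BruteKNN` is a hypothetical module name, not Scholar's API). The output is only $n \times k$, but every point is compared with every other point, so it performs $n(n-1)$ distance evaluations before the top-k selection:

```elixir
defmodule BruteKNN do
  # Exact k-NN by brute force (hypothetical sketch): for each point, compute the
  # distance to all other points, sort, and keep the indices of the k closest.
  def knn(points, k) do
    indexed = Enum.with_index(points)

    for {p, i} <- indexed do
      indexed
      |> Enum.reject(fn {_q, j} -> j == i end)
      |> Enum.map(fn {q, j} -> {sq_dist(p, q), j} end)
      |> Enum.sort()
      |> Enum.take(k)
      |> Enum.map(fn {_d, j} -> j end)
    end
  end

  # Squared Euclidean distance (same neighbour ordering as Euclidean distance).
  defp sq_dist(p, q) do
    Enum.zip_with(p, q, fn a, b -> (a - b) * (a - b) end) |> Enum.sum()
  end
end
```

For example, `BruteKNN.knn([[0.0], [1.0], [3.0], [7.0]], 2)` returns `[[1, 2], [0, 2], [1, 0], [2, 1]]`. Approximate methods aim to produce the same kind of $n \times k$ result while skipping most of those comparisons.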
Thanks for the initial feedback @krstopro! Follow up questions:
Can we compute the approximate kNN or does that require trees?
Also, to be clear, do you mean both FIt-SNE and Barnes-Hut use approximate k-NN?
> Can we compute the approximate kNN or does that require trees?
Most likely it requires trees. Popular approximate k-NN libraries such as Faiss or Annoy work with trees. NNDescent doesn't use trees (it uses a heap for storing neighbours; maybe this can be avoided by sacrificing some speed), but again it is probably not easy to implement in Nx. Its Python implementation uses trees for the initialisation (I didn't check the details).
> Also, to be clear, do you mean both FIt-SNE and Barnes-Hut use approximate kNN?
Both of them. The original Barnes-Hut t-SNE implementation uses vantage-point trees to calculate the exact k-NN, but I guess nowadays everyone uses the approximate one.
Let me add that the recent alternatives to t-SNE (TriMap, PaCMAP) rely on approximate k-NN calculation. It is also useful for spectral clustering (currently an open issue).
@krstopro UMAP, TriMap, and PaCMAP all seem very promising. I will take a more in-depth look at those too.
> @krstopro UMAP, TriMap, and PaCMAP all seem very promising. I will take a more in-depth look at those too.
I think the issue of efficiently computing the k-nearest neighbours still remains: UMAP is using sklearn.neighbors, TriMap is using pynndescent, PaCMAP is using Annoy.
Perhaps an important thing I forgot to mention: JAX (and probably XLA, though I can't see it in the official code) has an `approx_min_k` function, but it has been optimized for TPUs (probably based on this paper).
@krstopro graph algorithms for kNN may not necessarily be a bad idea because we can represent the graph as a triangular NxN matrix. Or would that be a bad idea?
> @krstopro graph algorithms for kNN may not necessarily be a bad idea because we can represent the graph as a triangular NxN matrix. Or would that be a bad idea?
I think so, since we are trying to avoid $O(N^2)$ complexity. It would also take a lot of memory, since $N$ here is the dataset size.
`approx_max_k`/`approx_min_k` seem to use approximate top-k, which is an MLIR operation: https://www.tensorflow.org/mlir/tf_ops (which we are transitioning to).
@krstopro looking at the Efficient K-Nearest Neighbor Graph Construction paper, we should be able to avoid the quadratic time complexity altogether. I suspect XLA would be able to optimize the matrix updates so we only pay for the additional space allocation once (when we allocate the graph matrix). The matrix should never be copied at any other time.
> @krstopro looking at the Efficient K-Nearest Neighbor Graph Construction paper, we should be able to avoid the quadratic time complexity altogether. I suspect XLA would be able to optimize the matrix updates so we only pay for the additional space allocation once (when we allocate the graph matrix). The matrix should never be copied at any other time.
Sorry, I made a mistake. The k-nearest neighbor graph is an $N \times k$ matrix, which is completely acceptable (even to store).
NNDescent might be doable in Nx. Will have a closer look.
@josevalim Alright, had a look. I think the problem with doing this in Nx would be sampling the reverse graph (whose node degrees depend on the data; they are not equal to k) and performing unions between candidate sets. In PyNNDescent they use sparse matrices and sparse operations to do some of these things (e.g. the element-wise maximum of two sparse matrices). One could use dense $n \times n$ matrices, but I think that would be similar to the current implementation of exact k-NN in Scholar.
One note, though: I didn't try using `Scholar.Neighbors.KNearestNeighbors` for t-SNE. My main concern was computing the $n \times n$ affinity matrix (i.e. the pairwise distances between all n points).
My notes so far:
- We can use an `f32[n,n]` table to store the distance between two points, which is internal to the algorithm. The space requirement is a downside, but the upside is that we can cache lookups and use the symmetry of the distance function to ensure we don't compute both `distance(u1, u2)` and `distance(u2, u1)` for a pair of vertices `u1` and `u2`.
- B[v] is organized as a heap, where updates cost O(log K), so I am assuming B[v] only keeps K elements. This would be a `u32[n,k]` matrix that stores indexes that point to the `f32[n,n]` table. We could use positive/negative signs on the indices as the flag used by incremental search or have a separate `u8[k,n]` matrix for that.
- Section 2.5 outlines a sampling algorithm, and the reverse kNN only considers `p*K` entries (where p is the sampling rate ρ ∈ (0,1]), so the reverse table has size `(p*K)*n`. We will need a sampling algorithm that allows us to pick `p*K` entries from a subset of the `n` rows whose size is not known in advance (I am aware of reservoir sampling, but there may be others).
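A minimal sketch of the symmetric distance cache from the first note, assuming plain Elixir and a map in place of the dense `f32[n,n]` table (`DistCache` is a hypothetical name). Keying by the ordered pair `{min(u1, u2), max(u1, u2)}` means `distance(u1, u2)` and `distance(u2, u1)` hit the same entry:

```elixir
defmodule DistCache do
  # Hypothetical sketch: memoise Euclidean distances keyed by the unordered pair,
  # so each pair of vertices is computed at most once. `points` is a tuple of
  # coordinate lists; returns {distance, updated_cache}.
  def fetch(cache, points, u, v) do
    key = {min(u, v), max(u, v)}

    case cache do
      %{^key => d} ->
        {d, cache}

      _ ->
        d = distance(elem(points, u), elem(points, v))
        {d, Map.put(cache, key, d)}
    end
  end

  defp distance(p, q) do
    Enum.zip_with(p, q, fn a, b -> (a - b) * (a - b) end)
    |> Enum.sum()
    |> :math.sqrt()
  end
end
```

With a dense table, the same idea amounts to filling only the upper triangle and reading `table[min][max]`.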
I may be completely wrong here, but it seems the next step is to define a sampling algorithm which, given:

- a sample size `n`
- a tensor `t` of shape `{w, ...}`
- a function `f` that receives each element of `t` across `w` and returns 0 or 1

will return a vector of dimension `n` pointing to random indexes of `t` across `w` where `f` returned true.
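A minimal plain-Elixir sketch of the sampling function described above, assuming lists instead of Nx tensors and a predicate `f` that returns a boolean rather than 0/1 (`SampleWhere.sample_where/3` is a hypothetical name). It uses reservoir sampling (Algorithm R), so it makes one pass and does not need to know beforehand how many entries satisfy `f`:

```elixir
defmodule SampleWhere do
  # Hypothetical helper (not an Nx or Scholar function): returns `n` slots holding
  # indices of `t` where `f` returned true, chosen uniformly at random with
  # reservoir sampling (Algorithm R); unused slots are padded with -1.
  def sample_where(n, t, f) when n >= 0 do
    {reservoir, _seen} =
      t
      |> Enum.with_index()
      |> Enum.reduce({%{}, 0}, fn {item, idx}, {res, seen} ->
        cond do
          not f.(item) ->
            {res, seen}

          seen < n ->
            # the first n matches fill the reservoir
            {Map.put(res, seen, idx), seen + 1}

          true ->
            # later matches replace a random slot with probability n / (seen + 1)
            j = :rand.uniform(seen + 1) - 1
            res = if j < n, do: Map.put(res, j, idx), else: res
            {res, seen + 1}
        end
      end)

    for slot <- 0..(n - 1)//1, do: Map.get(reservoir, slot, -1)
  end
end
```

For the reverse k-NN case, `t` would be the rows of `B` and `f` would be `&Enum.member?(&1, v)`, giving up to `p*K` reverse neighbours of `v`.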
> B[v] is organized as a heap, where updates cost O(log K), so I am assuming B[v] only keeps K elements. This would be a u32[n,k] matrix that stores indexes that point to the f32[n,n] table. We could use positive/negative signs on the indices as the flag used by incremental search or have a separate u8[k,n] matrix for that.
@josevalim The problem is: how do you reverse such a matrix? By reversing, I mean computing the matrix of the reverse k-NN graph R. If B[v] holds the indices of the nearest neighbours of v, then R[v] should hold the indices of the rows of B where v appears in one of the columns. Also note that the size of R[v] depends on the data; it ranges from 0 (no node has v as a nearest neighbour) to $n - 1$ (all other nodes have v as a nearest neighbour). I don't see a way to do this efficiently in Nx.
Another way to look at it is the following. Let B be a sparse $n \times n$ matrix where B[u][v] = 1 iff v is among the k nearest neighbours of u; I think this is exactly the approach they took in PyNNDescent. We keep tensors of the rows and columns where B is equal to one ($n \times k$ entries in total). Then R is equal to B transposed, i.e. we just swap the two tensors of rows and columns. I don't see a way to do the other operations (union, sampling, etc.) without making B dense.
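A plain-Elixir sketch of the transposition idea (lists instead of tensors; `ReverseKNN` is a hypothetical name). `B` is stored as `(row, col)` pairs, reversing swaps each pair, and grouping by the new row gives `R`. The example shows why `R[v]` has a data-dependent length, anywhere from 0 up to $n - 1$:

```elixir
defmodule ReverseKNN do
  # b is an n x k list of lists: row u holds the k nearest neighbours of u.
  # Sparse (COO) view of B: one {u, v} pair per entry where B[u][v] = 1.
  def coo(b) do
    for {neighbours, u} <- Enum.with_index(b), v <- neighbours, do: {u, v}
  end

  # R is B transposed: swap each pair and group by the new row v.
  # Row v of R lists every u that has v among its neighbours (possibly none).
  def reverse(b) do
    n = length(b)
    grouped = b |> coo() |> Enum.group_by(fn {_u, v} -> v end, fn {u, _v} -> u end)
    for v <- 0..(n - 1)//1, do: Map.get(grouped, v, [])
  end
end
```

For `b = [[1, 2], [0, 2], [1, 0], [2, 1]]`, `ReverseKNN.reverse(b)` returns `[[1, 2], [0, 2, 3], [0, 1, 3], []]`: every row of `b` has exactly 2 entries, but the rows of `R` have 2, 3, 3 and 0. The swap itself is shape-preserving; the grouping step is the part without a fixed output shape.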
Take `b[n,k]`, a matrix that, for a given `i`, returns the indices of the kNN of `i`. To compute the reverse of a given `v`, where `v` is given by its index, the following pseudocode would answer it:
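def rv(b, v, p, k) do
  size = p * k
  # -1 marks unused slots in the sample
  rpk = Nx.tensor(-1) |> Nx.broadcast({size})
  rpk_i = 0
  i = 0
  # Iterate over the n rows of b (Nx.size/1 would count all n * k entries)
  while i < Nx.axis_size(b, 0) do
    # Row i is a reverse neighbour of v if v appears in b[i];
    # row v itself is skipped instead of stopping the loop there
    if i != v and Nx.any(b[i] == v) and sampling_rate?(...) do
      rpk[rpk_i] = i
      rpk_i = if rpk_i == size - 1, do: 0, else: rpk_i + 1
    end
    i = i + 1
  end
  rpk
end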
`sampling_rate?` is the reservoir method that tells us whether we should store that entry in the sample.
@josevalim Correct me if I'm wrong, but this would mean we have to use a double loop to reverse the entire B. Do we really want that?
You are right and I have been thinking about it. The loop above can be easily vectorized as:
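def rv(b, v, p, k) do
  # Boolean mask of shape {n}: which rows of b contain v (sample_ones/2 is still to be written)
  sample_ones(Nx.any(Nx.equal(b, v), axes: [1]), p * k)
end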
Sampling is still linear. We could most likely batch the outer loop, so the sampling is of size (p*K)*n*batch_size, but more of the work happens in parallel.
I also want to re-read the full algorithm (Algorithm 2), because we may gain some properties if we fold the reversal into the second parallel for loop.
My suggestion is to actually implement a mix of Algorithm 1 and Algorithm 2 from the paper. The Local Join optimization from section 2.3 introduces parallelism with the purpose of increasing data locality in a distributed setting (not a benefit for us) at the cost of synchronization (not an option in XLA).
Here is the general outline of the algorithm:
# B is actually two matrices, one for indexes (s32[n][k]) and
# another for distances (f32[n][k]). Initial distances are infinite.
# We can use positive/negative values in B as the new/old flag
# or a third u8[n][k] matrix.
B[v] <− Sample(V, K) × {⟨∞, true⟩}
loop do
  # We need to compute two matrices containing kNN(v) ∪ rkNN(v).
  # They are called oldU (s32[n,k+pk]) and newU (s32[n,2pk]).
  # We initialize both with negative values to mark unused spots.
  # To compute newU, we need an intermediate "new" matrix s32[n,pk].
  for v ∈ V do
    oldU[v] <− (all items in B[v] with a false flag) ∪ reverseSample(v, B, ρK)
    new[v] <− ρK items in B[v] with a true flag
    Mark sampled items in B[v] as false;
  for v ∈ V do
    newU[v] <- new[v] ∪ reverseSample(v, new, ρK)
  c <− 0 # update counter
  for v ∈ V do
    for (u1 ∈ newU[v], u1 >= 0, u2 ∈ oldU[u1], u2 >= 0, u2 != v)
     or (u1 ∈ oldU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v)
     or (u1 ∈ newU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v) do
      l <− σ(v, u2)
      c <− c + UpdateNN(B[v], ⟨u2, l, true⟩)
  return B if c < δNK
Where `reverseSample` is the `rv` function above. This algorithm should be parallelizable on the GPU (at the cost of increased memory usage).
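The outline relies on `UpdateNN` from the paper. A minimal plain-Elixir sketch of what it does, assuming a sorted list in place of the heap and `{index, distance, flag}` candidates (`UpdateNN.update/3` is a hypothetical name): the candidate is inserted only if it is not already present and is closer than the current worst neighbour, and the returned 1 or 0 is what gets added to the counter `c`:

```elixir
defmodule UpdateNN do
  # Hypothetical sketch of UpdateNN(B[v], <u, l, flag>): `nn` is a list of
  # {distance, index, new?} tuples sorted by distance, with at most k entries
  # (a sorted list stands in for the heap). Returns {nn, 1} if the candidate
  # was inserted and {nn, 0} otherwise, so the results can be summed into c.
  def update(nn, k, {index, distance, new?}) when k > 0 do
    present? = Enum.any?(nn, fn {_d, i, _flag} -> i == index end)
    full? = length(nn) >= k

    cond do
      present? ->
        {nn, 0}

      full? and distance >= worst_distance(nn) ->
        {nn, 0}

      true ->
        updated =
          [{distance, index, new?} | nn]
          |> Enum.sort_by(fn {d, _i, _flag} -> d end)
          |> Enum.take(k)

        {updated, 1}
    end
  end

  # The last entry of the sorted list is the current farthest neighbour.
  defp worst_distance(nn), do: nn |> List.last() |> elem(0)
end
```

A real heap would make each update O(log K) instead of the O(K log K) re-sort used here; the sorted list only keeps the sketch short.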
OK, I have finished my changes. Because the algorithm consists of three independent loops, we can batch each of them at the cost of memory usage.
Here is the implementation of two reservoir sampling algorithms in Elixir:
Mix.install([:benchee])
defmodule RS do
  def r(enumerable, count) when count in 0..128 do
    sample = Tuple.duplicate(nil, count)
    reducer = fn elem, {idx, sample} ->
      jdx = random_index(idx)
      cond do
        idx < count ->
          value = elem(sample, jdx)
          {idx + 1, put_elem(sample, idx, value) |> put_elem(jdx, elem)}
        jdx < count ->
          {idx + 1, put_elem(sample, jdx, elem)}
        true ->
          {idx + 1, sample}
      end
    end
    {size, sample} = Enum.reduce(enumerable, {0, sample}, reducer)
    sample |> Tuple.to_list() |> Enum.take(Kernel.min(count, size))
  end
  defp random_index(0), do: 0
  defp random_index(idx), do: :rand.uniform(idx + 1) - 1
  def l(enumerable, count) when count in 0..128 do
    sample = Tuple.duplicate(nil, count)
    reducer = fn elem, {idx, jdx, w, sample} ->
      cond do
        idx == jdx ->
          w = w * :math.exp(:math.log(:rand.uniform()) / count)
          jdx = idx + floor(:math.log(:rand.uniform()) / :math.log(1 - w)) + 1
          pos = if idx < count, do: idx, else: :rand.uniform(count) - 1
          {idx + 1, jdx, w, put_elem(sample, pos, elem)}
        idx < count ->
          {idx + 1, jdx, w, put_elem(sample, idx, elem)}
        true ->
          {idx + 1, jdx, w, sample}
      end
    end
    {size, _, _, sample} = Enum.reduce(enumerable, {0, count - 1, 1.0, sample}, reducer)
    sample |> Tuple.to_list() |> Enum.take(Kernel.min(count, size))
  end
end
enum = Enum.to_list(1..1_000_000)
IO.inspect(RS.r(enum, 128))
IO.inspect(RS.l(enum, 128))
Benchee.run(
  %{
    "r" => fn -> RS.r(enum, 8) end,
    "l" => fn -> RS.l(enum, 8) end
  },
  time: 5,
  memory_time: 2,
  warmup: 0
)
We use the first one in Elixir, but the second one is consistently faster, so I will improve the implementation in Elixir too. They are Algorithms R and L, respectively, from the Wikipedia page: https://en.wikipedia.org/wiki/Reservoir_sampling
@josevalim I'm still digesting this stuff. 😅 Both the paper and what you wrote.
I can explain my thought process, in case it helps. This is Algorithm 2:
B[v] <- Sample(V, K) × {⟨∞, true⟩}
loop do
  for v ∈ V do
    old[v] <− all items in B[v] with a false flag
    new[v] <− ρK items in B[v] with a true flag
    Mark sampled items in B[v] as false;
  old′ <- Reverse(old), new′ <- Reverse(new)
  c <- 0 //update counter
  for v ∈ V do
    oldU[v] <− old[v] ∪ Sample(old′[v], ρK)
    newU[v] <− new[v] ∪ Sample(new′[v], ρK)
    for (u1,u2 ∈ newU[v], u1<u2) or (u1 ∈ newU[v], u2 ∈ oldU[v]) do
      l <− σ (u1, u2)
      // c and B[.] are synchronized.
      c <− c + UpdateNN(B[u1], ⟨u2, l, true⟩)
      c <− c + UpdateNN(B[u2], ⟨u1, l, true⟩)
  return B if c < δNK
The biggest issue is that the last loop was designed for MapReduce-style distribution. So instead of updating the neighbours of v, we need to go back to updating v itself, as in Algorithm 1. This yields:
B[v] <- Sample(V, K) × {⟨∞, true⟩}
loop do
  for v ∈ V do
    old[v] <− all items in B[v] with a false flag
    new[v] <− ρK items in B[v] with a true flag
    Mark sampled items in B[v] as false;
  old′ <- Reverse(old), new′ <- Reverse(new)
  c <- 0 //update counter
  for v ∈ V do
    oldU[v] <− old[v] ∪ Sample(old′[v], ρK)
    newU[v] <− new[v] ∪ Sample(new′[v], ρK)
    for (u1 ∈ newU[v], u1 >= 0, u2 ∈ oldU[u1], u2 >= 0, u2 != v)
     or (u1 ∈ oldU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v)
     or (u1 ∈ newU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v) do
      l <− σ(v, u2)
      c <− c + UpdateNN(B[v], ⟨u2, l, true⟩)
  return B if c < δNK
In other words, we visit:
The next step is to compute the reverse. As discussed above, we don't want to compute a very large reverse and then sample it; we want to sample it at the same time we reverse it. So I need to move
old[v] <− old[v] ∪ Sample(old′[v], ρK)
new[v] <− new[v] ∪ Sample(new′[v], ρK)
into their own loop, which yields:
B[v] <− Sample(V, K) × {⟨∞, true⟩}
loop
  for v ∈ V do
    old[v] <− all items in B[v] with a false flag
    new[v] <− ρK items in B[v] with a true flag
    Mark sampled items in B[v] as false;
  for v ∈ V do
    oldU[v] <− old[v] ∪ reverseSample(v, old, ρK)
    newU[v] <− new[v] ∪ reverseSample(v, new, ρK)
  c <- 0 //update counter
  for v ∈ V do
    for (u1 ∈ newU[v], u1 >= 0, u2 ∈ oldU[u1], u2 >= 0, u2 != v)
     or (u1 ∈ oldU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v)
     or (u1 ∈ newU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v) do
      l <− σ(v, u2)
      c <− c + UpdateNN(B[v], ⟨u2, l, true⟩)
  return B if c < δNK
The last step is to realize that we don't need to build old[v] in the first loop and then reverse it. old[v] is simply a view of B, which we can compute as we reverse sample, so we skip building old[v] and build oldU[v] directly in the first loop:
B[v] <− Sample(V, K) × {⟨∞, true⟩}
loop
  for v ∈ V do
    oldU[v] <− (all items in B[v] with a false flag) ∪ reverseSample(v, B, ρK)
    new[v] <− ρK items in B[v] with a true flag
    Mark sampled items in B[v] as false;
  for v ∈ V do
    newU[v] <− new[v] ∪ reverseSample(v, new, ρK)
  c <- 0 //update counter
  for v ∈ V do
    for (u1 ∈ newU[v], u1 >= 0, u2 ∈ oldU[u1], u2 >= 0, u2 != v)
     or (u1 ∈ oldU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v)
     or (u1 ∈ newU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v) do
      l <− σ(v, u2)
      c <− c + UpdateNN(B[v], ⟨u2, l, true⟩)
  return B if c < δNK
I have been thinking more about this and I believe we can actually implement trees efficiently as long as the tree is balanced, built at once, and not mutated. My understanding is that most KDTrees fulfill this property. The idea is to lay the tree flat in memory. A tree like this:
      0
    /   \
   1     2
  / \   / \
 3   4 5   6
Would be stored as `0123456` in memory. If the tree is balanced, we can compute the height/depth upfront from the number of elements: it is the position of the highest set bit of the count, i.e. the bit width (32) minus the number of leading zeros:
def depth(count) do
  Nx.subtract(32, Nx.count_leading_zeros(Nx.u32(count)))
end
And from the depth $d$, we allocate a tensor of size $2^d - 1$. In the worst-case scenario, we waste $2^{d-1} - 1$ slots (i.e. the last level has a single element).
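A small plain-Elixir sketch of the sizing and the flat layout, assuming integers instead of Nx tensors and a perfect tree with exactly $2^d - 1$ elements (padding a partial last level is left out); `FlatTree` is a hypothetical name. `depth/1` is the bit length of the count, which matches the `count_leading_zeros` formula, and `layout/1` places a sorted list into level order by walking the slots in order:

```elixir
defmodule FlatTree do
  # Number of levels for `count` elements: the bit length of count,
  # i.e. what 32 - count_leading_zeros(count) computes for a u32.
  def depth(0), do: 0
  def depth(count) when count > 0, do: length(Integer.digits(count, 2))

  # Slots to allocate for `count` elements: 2^depth - 1.
  def slots(count), do: Integer.pow(2, depth(count)) - 1

  # Lays out a sorted list level by level. Assumes a perfect tree
  # (exactly 2^d - 1 elements); padding a partial last level is left out.
  def layout(sorted) do
    size = length(sorted)
    if size != slots(size), do: raise(ArgumentError, "expected 2^d - 1 elements")
    {pairs, []} = fill(0, size, sorted)
    pairs |> Enum.sort() |> Enum.map(fn {_slot, value} -> value end) |> List.to_tuple()
  end

  # In-order walk over slot indices (children of slot i are 2i + 1 and 2i + 2,
  # the same slots as offset + position) hands out values left to right.
  defp fill(i, size, rest) when i >= size, do: {[], rest}

  defp fill(i, size, rest) do
    {left, [value | rest]} = fill(2 * i + 1, size, rest)
    {right, rest} = fill(2 * i + 2, size, rest)
    {left ++ [{i, value}] ++ right, rest}
  end
end
```

For example, `FlatTree.slots(4)` is 7 (3 slots wasted) and `FlatTree.layout(~w(a b c d e f g))` returns `{"d", "b", "f", "a", "c", "e", "g"}`, the level order of the d/b/f tree used in the traversal example.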
Traversing the tree requires three arguments: the current depth, the current offset, and the history of left/right decisions (the position within the current level):
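# Assumes `import Bitwise`. Level `depth` starts at `offset` (2^depth - 1);
# `position` is the index within it: go right (+1 bit) if target is larger.
def traverse(tensor, target, depth, offset, position) do
  index = offset + position
  if index >= Nx.size(tensor) do
    nil # walked past the last level: target is not in the tree
  else
    value = Nx.to_number(tensor[index])
    cond do
      value == target -> index
      value < target -> traverse(tensor, target, depth + 1, offset + 2 ** depth, (position <<< 1) + 1)
      value > target -> traverse(tensor, target, depth + 1, offset + 2 ** depth, position <<< 1)
    end
  end
end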
For example, imagine this tree:
      d
    /   \
   b     f
  / \   / \
 a   c e   g
And we want to get to e:
traverse(depth, offset, position)
traverse(0, 0, 0) #=> access d
traverse(1, 1, 1) #=> access f
traverse(2, 3, 2) #=> access e
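A runnable plain-Elixir version of the traversal, assuming the tree is stored in a tuple in level order (`FlatSearch` is a hypothetical name; this is not the notebook's implementation). Going right appends a 1 bit to `position`, going left appends a 0, and the offset of each level grows by `2^depth`:

```elixir
defmodule FlatSearch do
  import Bitwise

  # Binary search over a balanced tree stored level by level in a tuple.
  # Level `depth` starts at `offset` (= 2^depth - 1); `position` is the index
  # within that level, and its children are 2 * position and 2 * position + 1.
  def traverse(tree, target, depth \\ 0, offset \\ 0, position \\ 0) do
    index = offset + position

    if index >= tuple_size(tree) do
      # walked past the last level: target is not in the tree
      nil
    else
      value = elem(tree, index)
      next_offset = offset + (1 <<< depth)

      cond do
        value == target -> index
        value < target -> traverse(tree, target, depth + 1, next_offset, (position <<< 1) + 1)
        true -> traverse(tree, target, depth + 1, next_offset, position <<< 1)
      end
    end
  end
end
```

`FlatSearch.traverse({"d", "b", "f", "a", "c", "e", "g"}, "e")` visits d (index 0), f (index 2) and e (index 5), matching the trace above, and returns 5.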
This should allow us to build a KDTree and perhaps optimize kNN itself. Barnes-Hut uses vantage-point trees and quadtrees, which I am not sure would be suitable. However, EFANNA builds read-only, balanced, truncated, randomized KDTrees (the truncated part means the last depth contains N entries), so it is a suitable option. For completeness, I am also looking into LSH-based kNN.
After scouring the internet for median algorithms we can trivially implement with Nx, I found LazySelect (section 12.2.1). Here is a numpy implementation. Improving our median would also benefit Affinity Propagation.
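On why LazySelect is linear: the sample has about $n^{3/4}$ elements, so sorting it costs $O(n^{3/4} \log n)$, which is $o(n)$; the remaining work is a couple of linear passes over the data plus a sort of the $O(n^{3/4})$ candidates that fall between two sample-based bounds. A minimal plain-Elixir sketch, assuming the textbook formulation (sample size $n^{3/4}$, bounds $\pm\sqrt{n}$ around $k \cdot n^{-1/4}$, at most $4n^{3/4} + 2$ candidates) rather than the linked numpy code; `LazySelect` is a hypothetical module name and the edge handling is simplified, so extreme ranks may need a few retries:

```elixir
defmodule LazySelect do
  # Sketch of LazySelect: returns the k-th smallest element (0-based) of `list`.
  def select(list, k) when k >= 0 and k < length(list) do
    n = length(list)
    tuple = List.to_tuple(list)
    # 1. Sample s = n^(3/4) elements with replacement and sort only the sample.
    s = max(1, round(:math.pow(n, 0.75)))
    sample = for(_ <- 1..s, do: elem(tuple, :rand.uniform(n) - 1)) |> Enum.sort() |> List.to_tuple()
    # 2. Bracket the expected position of the k-th element in the sample by +/- sqrt(n).
    x = k * s / n
    lo = elem(sample, max(floor(x - :math.sqrt(n)), 0))
    hi = elem(sample, min(ceil(x + :math.sqrt(n)), s - 1))
    # 3. Linear passes: rank of lo, and the candidates in [lo, hi].
    below = Enum.count(list, &(&1 < lo))
    middle = Enum.filter(list, &(&1 >= lo and &1 <= hi))
    m = length(middle)

    if below <= k and k < below + m and m <= 4 * s + 2 do
      # 4. Sort only the O(n^(3/4)) candidates.
      middle |> Enum.sort() |> Enum.at(k - below)
    else
      # Unlucky sample: try again with a fresh one.
      select(list, k)
    end
  end
end
```

The randomness enters only through the sample, which is why an Nx version would need a random key threaded through.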
@josevalim This is kinda like partition used for Quicksort, right?
QuickSelect is the one that partitions but it is not very efficient on Nx. :(
@josevalim I see. I am not exactly sure why this method is linear since it does
B = np.random.choice(A, selection_len)
B = np.sort(B)
and `selection_len = int((length ** (3/4)))`, but lemme have a closer look.
@josevalim Oh, it's actually a well known algorithm (that I wasn't familiar with!). Going through it now!
Update: the main difference is that it's randomised so we would need to include a random key, but I guess that's just a detail.
@krstopro yes, you can see in Scholar how we generate a random key when necessary.
And I will try to implement KDTree today, using our current median, but we can benchmark other versions once it's ready. The only limitation of our KDTree is that it has to be built outside `defn`, and it will require log2(N) compilations of the median algorithm.
Implementing KD-trees using a static array is actually a very good idea. I knew about such implementations for heaps, but I hadn't thought about it in terms of KD-trees.
I think this is something similar to what @josevalim wrote above, but they are not using medians as pivots.
I am looking into Locality Sensitive Hashing via Random Projections (e.g. the sklearn implementation). I might be able to vectorise it in Nx.
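A minimal plain-Elixir sketch of random-projection LSH, assuming the sign-of-projection variant (which targets angular/cosine similarity); `RandomProjectionLSH` is a hypothetical name and this is not sklearn's implementation. Each of `bits` random Gaussian hyperplanes contributes one bit, and points sharing a hash land in the same bucket as candidate neighbours:

```elixir
defmodule RandomProjectionLSH do
  # Draw `bits` random hyperplanes through the origin (Gaussian normal vectors).
  def planes(bits, dim) do
    for _ <- 1..bits, do: for(_ <- 1..dim, do: :rand.normal())
  end

  # One bit per hyperplane: 1 if the vector lies on its non-negative side.
  def hash(planes, vector) do
    Enum.map(planes, fn normal ->
      dot = Enum.zip_with(normal, vector, &(&1 * &2)) |> Enum.sum()
      if dot >= 0, do: 1, else: 0
    end)
  end

  # Group vector indices by hash; indices in the same bucket are candidate neighbours.
  def buckets(planes, vectors) do
    vectors
    |> Enum.with_index()
    |> Enum.group_by(fn {v, _i} -> hash(planes, v) end, fn {_v, i} -> i end)
  end
end
```

In Nx, the hashing step would be a single matrix product between the data and the stacked hyperplanes followed by a sign comparison, which is what makes this approach look vectorisable; the bucketing step is the part with data-dependent sizes.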
I went through GPU-friendly, Parallel, and (Almost-)In-Place Construction of Left-Balanced k-d Trees, and I think Nx has everything needed to implement it. The only thing I am not sure about is the bitwise operators it uses.
Unfortunately we don't have the variadic/parallel sort, which is what he uses to sort by tags and data at once.
We can implement it for Nx, but it would only work on XLA and not Torchx. And XLA's variadic sort requires all inputs to have the same shape, which means we wouldn't be able to sort tag and data together as is. We would need to slice a dimension of data, and sorting only that dimension would not give the whole picture. So we need to parallel sort the tag, the data, and an iota. The iota would give us the indices of the data, and then we would gather those indices to put the data in the new order, which I am not sure is worth it. So I will proceed with a recursive version of this algorithm, so the tree is left-balanced, and we will see how it goes :D
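A plain-Elixir illustration of the iota-then-gather approach (lists instead of tensors; `TagSort` is a hypothetical name, and on XLA the argsort and gather would be tensor operations). The composite `{tag, coordinate}` key sorts by tag first and breaks ties on the coordinate, because Elixir compares tuples element by element:

```elixir
defmodule TagSort do
  # Hypothetical sketch: argsort an index list (the "iota") by (tag, coordinate dim),
  # then gather points and tags in that order.
  def sort(points, tags, dim) do
    pts = List.to_tuple(points)
    tgs = List.to_tuple(tags)

    order =
      Enum.sort_by(0..(tuple_size(pts) - 1)//1, fn i ->
        {elem(tgs, i), pts |> elem(i) |> Enum.at(dim)}
      end)

    {Enum.map(order, &elem(pts, &1)), Enum.map(order, &elem(tgs, &1))}
  end
end
```

For points `[[0.5, 9.0], [0.1, 8.0], [0.9, 7.0], [0.3, 6.0]]` with tags `[1, 0, 0, 1]` and `dim = 0`, the index order is `[1, 2, 3, 0]`, so the tags come out as `[0, 0, 1, 1]` and each tag group is sorted by its first coordinate.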
To keep track, here are all of the things we could still implement (not yet ruled out):
And this may unblock us to do UMAP, TriMap, and PaCMAP.
> Unfortunately we don't have the variadic/parallel sort, which is what he uses to sort by tags and data at once.
@josevalim I don't think it's needed. If `points[.., dim]` are in the $[0, 1)$ interval (1 excluded), then sorting with
bool less(int idx_a, int idx_b, int l) {
  int dim = l % k;
  return (tags[idx_a] < tags[idx_b])
      || ((tags[idx_a] == tags[idx_b])
          && (points[idx_a][dim] < points[idx_b][dim]));
}
is equivalent to argsorting `tags + points[:, dim]`, right?
For general `points[.., dim]` we can sort `(max(points[.., dim]) + 1) * tags + points[:, dim]`.
I might be wrong though.
I thought about that, but we would have corner cases in all scenarios. There is nothing stopping the maximum s64 value from being in the tensor. :(
> I thought about that, but we would have corner cases in all scenarios. There is nothing stopping the maximum s64 value from being in the tensor. :(
Correct, but we can always preprocess the points by normalising them (this is common for t-SNE, I think).
Or, equivalently, sort `tags + points[:, dim] / (max(points[.., dim]) + 1)`. Would this work? Update: No, it wouldn't. :(
Anyway, I am looking into LSH at the moment. Will see what I can do.
@krstopro I guess we could have two functions for building KDTrees. One would expect the amplitude of the tensor. For example, if a tensor goes between -1 and 1, amplitude of 2 would be enough (and I would start the tags from 1 for convenience).
> @krstopro I guess we could have two functions for building KDTrees. One would expect the amplitude of the tensor. For example, if a tensor goes between -1 and 1, amplitude of 2 would be enough (and I would start the tags from 1 for convenience).
@josevalim That sounds like a solution to me.
This notebook implements KDTree: https://gist.github.com/josevalim/555330a5cd4347ffe54b180be5b4a5c5
There are some TODOs but it requires only some quick polishing. I will optimize and test it tomorrow. I will try to implement the amplified version tomorrow as well.
I will open up a new issue with the exploration points of this one.
See #207.
Our t-SNE implementation is $O(n^2)$. There is a Barnes-Hut implementation, which is $O(n \log n)$, but it requires quadtrees, which are not trivial to reproduce with Nx. However, there is a CUDA implementation which we could borrow ideas from (associated paper).