elixir-nx / scholar

Traditional machine learning on top of Nx
Apache License 2.0

Optimize t-sne #189

Closed josevalim closed 1 year ago

josevalim commented 1 year ago

Our t-SNE implementation is O(n²). There is a Barnes-Hut implementation, which is O(n log n), but it requires quad-trees, which are not trivial to reproduce with Nx. However, there is a CUDA implementation which we could borrow ideas from (associated paper).

josevalim commented 1 year ago

In case the CUDA version is not feasible (for example, one of its optimizations uses sparse matrices, which we don't currently have), we can consider implementing FIt-SNE. It may even be a better starting point.

krstopro commented 1 year ago

@josevalim I've been trying to implement this, but I am afraid it might be very hard. The main reason is computing the pairwise distances needed for the conditional and joint probability distributions. To avoid the $O(n^2)$ complexity (as currently done), all of the popular t-SNE methods use (approximate) k-nearest-neighbour (k-NN) search to compute the $n \times k$ distances between every data point and its k nearest neighbours. However, the (exact) k-NN method implemented in Scholar also computes all $n^2$ pairwise distances (before using top-k). I don't see a way to implement this in $O(n \log n)$ at the moment.
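
(For reference, the O(n²) path being discussed boils down to something like the sketch below. The module and function names are made up for illustration, not Scholar's actual API, and it keeps each point as its own nearest neighbour.)

defmodule ExactKNNSketch do
  import Nx.Defn

  # All pairwise squared Euclidean distances ({n, n}), then the k smallest
  # per row. Both time and memory are O(n^2), which is the bottleneck here.
  defn query(x, opts \\ []) do
    k = opts[:k]
    norms = Nx.sum(x * x, axes: [1], keep_axes: true)
    dists = norms - 2 * Nx.dot(x, Nx.transpose(x)) + Nx.transpose(norms)
    # negate so top_k returns the k *smallest* distances
    {neg, idx} = Nx.top_k(-dists, k: k)
    {-neg, idx}
  end
end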

josevalim commented 1 year ago

Thanks for the initial feedback @krstopro! Follow up questions:

  1. Can we compute the approximate kNN or does that require trees?

  2. Also, to be clear, do you mean both FIt-SNE and Barnes-Hut use approximate k-NN?

krstopro commented 1 year ago

Can we compute the approximate kNN or does that require trees?

Most likely requires trees. Popular approximate k-NN libraries such as Faiss or Annoy work with trees. NNDescent doesn't use trees (it uses a heap for storing neighbours; maybe this can be avoided by sacrificing some speed), but again it is probably not easy to implement in Nx. Its Python implementation uses trees for the initialisation (I didn't check the details).

Also, to be clear, do you mean both FIt-SNE and Barnes-Hut use approximate k-NN?

Both of them. The original Barnes-Hut t-SNE implementation uses vantage-point trees to calculate the exact k-NN, but I guess nowadays everyone uses the approximate one.

Let me add that the recent alternatives to t-SNE (TriMap, PaCMAP) rely on approximate k-NN calculation. It is also useful for spectral clustering (currently an open issue).

josevalim commented 1 year ago

@krstopro UMAP, TriMap, and PaCMAP all sound very promising. I will take a more in-depth look at those too.

krstopro commented 1 year ago

@krstopro UMAP, TriMap, and PaCMAP all sound very promising. I will take a more in-depth look at those too.

I think the issue of efficiently computing the k-nearest neighbours still remains: UMAP uses sklearn.neighbors, TriMap uses pynndescent, and PaCMAP uses Annoy.

Perhaps an important thing I forgot to mention: JAX (and probably XLA, I can't see it in the official code) has an approx_min_k function, but it's been optimized for TPUs (probably based on this paper).

josevalim commented 1 year ago

@krstopro graph algorithms for kNN may not necessarily be a bad idea because we can represent the graph as a triangular NxN matrix. Or would that be a bad idea?

krstopro commented 1 year ago

@krstopro graph algorithms for kNN may not necessarily be a bad idea because we can represent the graph as a triangular NxN matrix. Or would that be a bad idea?

I think it would be, since we are trying to avoid $O(N^2)$ complexity. It would also take a lot of memory: here $N$ is the dataset size, so even a triangular f32 matrix for $N = 10^6$ points is on the order of 2 TB.

josevalim commented 1 year ago

approx_max_k/approx_min_k seem to use approximate top-k, which is an MLIR operation: https://www.tensorflow.org/mlir/tf_ops (which we are transitioning to).

@krstopro looking at the Efficient K-Nearest Neighbor Graph Construction paper, we should be able to avoid the quadratic time complexity altogether. I suspect XLA would be able to optimize the matrix updates so we only pay for the additional space allocation once (which is when we allocate the graph matrix). The matrix should never really be copied at any other time.

krstopro commented 1 year ago

@krstopro looking at the Efficient K-Nearest Neighbor Graph Construction paper, we should be able to avoid the quadratic time complexity altogether. I suspect XLA would be able to optimize the matrix updates so we only pay for the additional space allocation once (which is when we allocate the graph matrix). The matrix should never really be copied at any other time.

Sorry, I made a mistake. The k-nearest neighbor graph is an $N \times k$ matrix, which is completely acceptable (even to store).

NNDescent might be doable in Nx. Will have a closer look.

krstopro commented 1 year ago

@josevalim Alright, had a look. I think the problem with doing this in Nx would be sampling the reverse graph (whose node degrees depend on the data; they are not equal to k) and performing unions between candidate sets. In PyNNDescent they use sparse matrices and operations to do some of these things (e.g. the maximum of two sparse matrices). One could use dense $n \times n$ matrices, but I think that would be similar to the current implementation of exact k-NN in Scholar.

Note, though, that I didn't try using Scholar.Neighbors.KNearestNeighbors for t-SNE. My main concern was computing the $n \times n$ affinity matrix (i.e. pairwise distances between all n points).

josevalim commented 1 year ago

My notes so far:

I may be completely wrong here, but it seems the next step is to define a sampling algorithm which, given:

  1. an output size n
  2. a tensor t of shape {w, ...}
  3. a function f that receives each element of t across w and returns 0 or 1

will return a vector of dimension n pointing to random indexes of t across w where f returned true.
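
A rough, hedged sketch of how this could look in Nx (assuming f has already been applied to produce a {w}-shaped 0/1 mask; the module and function names are invented):

defmodule RandomWhereSketch do
  import Nx.Defn

  # Give every matching row a random key and keep the n largest keys.
  # Non-matching rows get key -1, so they only show up when fewer than
  # n rows match (callers would need to re-check the mask in that case).
  defn sample_true_indices(mask, key, opts \\ []) do
    n = opts[:n]
    {keys, _key} = Nx.Random.uniform(key, shape: Nx.shape(mask))
    scores = Nx.select(mask == 1, keys, -1.0)
    {_scores, idx} = Nx.top_k(scores, k: n)
    idx
  end
end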

krstopro commented 1 year ago

B[v] is organized as a heap, where updates cost O(log K), so I am assuming B[v] only keeps K elements. This would be a u32[n,k] matrix that stores indexes that point to the f32[n,n] table. We could use positive/negative signs on the indices as the flag used by incremental search or have a separate u8[k,n] matrix for that.

@josevalim The problem is, how do you reverse such a matrix? By reversing, I mean computing the matrix of the reverse k-NN graph R. If B[v] holds the indices of the nearest neighbours of v, then R[v] should hold the indices of the rows of B where v appears in one of the columns. Also note that the size of R[v] depends on the data; it ranges from 0 (no node has v as a nearest neighbour) to $n - 1$ (all other nodes have v as a nearest neighbour). I don't see a way to do this efficiently in Nx.

Another way to look at it is the following. Let B be a sparse $n \times n$ matrix where B[u][v] = 1 iff v is among the k nearest neighbours of u; I think this is exactly the approach they took in PyNNDescent. We keep tensors of the rows and columns where B is equal to one ($n \times k$ entries in total). Then R is equal to B transposed, i.e. we just swap the two tensors of rows and columns. I don't see a way to do the other operations (union, sampling, etc.) without making B dense.
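
For illustration, that coordinate (COO-style) view could be held as two flat index tensors, and the transpose really is just swapping them (a toy sketch, not an existing Scholar API):

# Toy example: n = 4 points, k = 2 neighbours each.
neighbor_indices = Nx.tensor([[1, 2], [0, 3], [0, 1], [1, 2]])
{n, k} = Nx.shape(neighbor_indices)

# Coordinates of the nonzero entries of the sparse n x n adjacency matrix B.
rows = Nx.iota({n, k}, axis: 0) |> Nx.reshape({n * k})
cols = Nx.reshape(neighbor_indices, {n * k})

# The reverse k-NN graph R = transpose(B) is just the swapped coordinate pair.
{r_rows, r_cols} = {cols, rows}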

josevalim commented 1 year ago

Take b[n,k] to be a matrix that, for a given i, returns the indices of the kNN of i. To compute the reverse of a given v, where v is given by its index, the following pseudocode would answer it:

def rv(b, v, p, k) do
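  # p is the sampling rate ρ; rpk collects up to p*k reverse-neighbour
  # indices of v, i.e. rows i of b in which v appears.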
  size = p*k
  rpk = Nx.tensor(-1) |> Nx.broadcast({size})
  rpk_i = 0
  i = 0

  while i <- 0..Nx.axis_size(b, 0), i != v do
    if Nx.any(b[i] == v) and sampling_rate(...) do
      rpk[rpk_i] = i
      rpk_i = if rpk_i == size - 1, do: 0, else: rpk_i + 1
    end
    i = i + 1
  end
end

The sampling_rate(...) check is the reservoir method that tells us if we should store that entry in the sample.

krstopro commented 1 year ago

@josevalim Correct me if I'm wrong, but this would mean we have to use a double loop to reverse the entire B. Do we really want that?

josevalim commented 1 year ago

You are right and I have been thinking about it. The loop above can be easily vectorized as:

def rv(b, v, p, k) do
  sample_ones(Nx.any(b == v, axes: [1]), p * k)
end

Sampling is still linear. And we could most likely batch the outer loop so the sampling is of size (p*K)*n*batch_size but we do more in parallel.

But I also want to re-read the full algorithm (Algorithm 2), because we may gain properties if we fold the reversal into the second parallel for loop.

josevalim commented 1 year ago

My suggestion is to actually implement a mix of Algorithm 1 and Algorithm 2 from the paper. The Local Join optimization from section 2.3 introduces parallelism with the purpose of increasing data locality in a distributed setting (not a benefit for us) at the cost of synchronization (not an option in XLA).

Here is the general outline of the algorithm:

# B is actually two matrices, one for indexes (s32[n][k]) and
# another for distances (f32[n][k]). Initial distances are infinite.
# We can use positive/negative values in B as the new/old flag 
# or a third u8[n][k] matrix.
B[v] <− Sample(V, K) × {⟨∞, true⟩}

loop do
  # We need to compute two matrices containing the kNN(v) ∪ rkNN(V).
  # They are called oldU (s32[n,k+pk]) and newU (s32[n,2pk]).
  # We initialize both with negative values to mark unused spots.
  # To compute newU, we need an intermediate "new" matrix s32[n,pk].
  for v ∈ V do
    oldU[v] <− (all items in B[v] with a false flag) ∪ reverseSample(v, B, ρK)
    new[v] <− ρK items in B[v] with a true flag
    Mark sampled items in B[v] as false;

  for v ∈ V do
    newU[v] <- new[v] ∪ reverseSample(v, new, ρK)

  c <− 0 # update counter

  for v ∈ V do
    for (u1 ∈ newU[v], u1 >= 0, u2 ∈ oldU[u1], u2 >= 0, u2 != v)
        or (u1 ∈ oldU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v)
        or (u1 ∈ newU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v) do
      l <− σ(v, u2)
      c <− c + UpdateNN(B[v], ⟨u2, l, true⟩)

  return B if c < δNK

Where reverseSample is the rv function above. This algorithm should be parallelizable on the GPU (at the cost of a memory increase).
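
As a hedged sketch (not the final implementation), the UpdateNN step could be batched in Nx by concatenating candidate neighbours onto B and re-taking the k smallest distances per row; the update counter c, the new/old flags, and duplicate handling are omitted here:

defmodule UpdateNNSketch do
  import Nx.Defn

  # idx/dist are the s32[n, k] and f32[n, k] halves of B;
  # cand_idx/cand_dist are s32[n, c] and f32[n, c] candidates.
  defn merge(idx, dist, cand_idx, cand_dist) do
    k = Nx.axis_size(idx, 1)
    all_idx = Nx.concatenate([idx, cand_idx], axis: 1)
    all_dist = Nx.concatenate([dist, cand_dist], axis: 1)

    # keep the k smallest distances per row
    order = Nx.argsort(all_dist, axis: 1)
    top = Nx.slice_along_axis(order, 0, k, axis: 1)

    {Nx.take_along_axis(all_idx, top, axis: 1),
     Nx.take_along_axis(all_dist, top, axis: 1)}
  end
end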

josevalim commented 1 year ago

Ok, I have finished my changes. Because the algorithm is 3 independent loops, we can batch at the cost of memory usage.

josevalim commented 1 year ago

Here is the implementation of two reservoir sampling algorithms in Elixir:

Mix.install([:benchee])

defmodule RS do
  def r(enumerable, count) when count in 0..128 do
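    # Algorithm R: fill the reservoir with the first `count` elements (shuffling
    # as we go), then replace a uniformly random slot with probability count/(idx + 1).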
    sample = Tuple.duplicate(nil, count)

    reducer = fn elem, {idx, sample} ->
      jdx = random_index(idx)

      cond do
        idx < count ->
          value = elem(sample, jdx)
          {idx + 1, put_elem(sample, idx, value) |> put_elem(jdx, elem)}

        jdx < count ->
          {idx + 1, put_elem(sample, jdx, elem)}

        true ->
          {idx + 1, sample}
      end
    end

    {size, sample} = Enum.reduce(enumerable, {0, sample}, reducer)
    sample |> Tuple.to_list() |> Enum.take(Kernel.min(count, size))
  end

  defp random_index(0), do: 0
  defp random_index(idx), do: :rand.uniform(idx + 1) - 1

  def l(enumerable, count) when count in 0..128 do
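    # Algorithm L: fill the first `count` slots, then jump straight to the next
    # accepted index using geometric skips, replacing a uniformly random slot.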
    sample = Tuple.duplicate(nil, count)

    reducer = fn elem, {idx, jdx, w, sample} ->
      cond do
        idx == jdx ->
          w = w * :math.exp(:math.log(:rand.uniform()) / count)
          jdx = idx + floor(:math.log(:rand.uniform()) / :math.log(1 - w)) + 1
          pos = if idx < count, do: idx, else: :rand.uniform(count) - 1
          {idx + 1, jdx, w, put_elem(sample, pos, elem)}

        idx < count ->
          {idx + 1, jdx, w, put_elem(sample, idx, elem)}

        true ->
          {idx + 1, jdx, w, sample}
      end
    end

    {size, _, _, sample} = Enum.reduce(enumerable, {0, count - 1, 1.0, sample}, reducer)
    sample |> Tuple.to_list() |> Enum.take(Kernel.min(count, size))
  end
end

enum = Enum.to_list(1..1_000_000)
IO.inspect(RS.r(enum, 128))
IO.inspect(RS.l(enum, 128))

# list = Enum.to_list(1..10_000)
# map_fun = fn i -> [i, i * i] end

Benchee.run(
  %{
    "r" => fn -> RS.r(enum, 8) end,
    "l" => fn -> RS.l(enum, 8) end
  },
  time: 5,
  memory_time: 2,
  warmup: 0
)

We use the first one in Elixir, but the second one is consistently faster, so I will improve the implementation in Elixir too. They are algorithms R and L respectively from the Wikipedia page: https://en.wikipedia.org/wiki/Reservoir_sampling

krstopro commented 1 year ago

@josevalim I'm still digesting this stuff. 😅 Both the paper and what you wrote.

josevalim commented 1 year ago

I can explain my thought process, in case it helps. This is Algorithm 2:

B[v] <- Sample(V, K) × {⟨∞, true⟩}

loop do
  for v ∈ V do
    old[v] <− all items in B[v] with a false flag
    new[v] <− ρK items in B[v] with a true flag
    Mark sampled items in B[v] as false;

  old′ <- Reverse(old), new′ <- Reverse(new)
  c <- 0 //update counter

  for v ∈ V do
    oldU[v] <− old[v] ∪ Sample(old′[v], ρK)
    newU[v] <− new[v] ∪ Sample(new′[v], ρK)

    for (u1,u2 ∈ newU[v], u1<u2) or (u1 ∈ newU[v], u2 ∈ oldU[v]) do
      l <− σ (u1, u2)
      // c and B[.] are synchronized.
      c <− c + UpdateNN(B[u1], ⟨u2, l, true⟩)
      c <− c + UpdateNN(B[u2], ⟨u1, l, true⟩)

  return B if c < δNK

The biggest issue is that the last loop was designed for MapReduce-style distribution. So instead of updating the neighbours of v, we need to roll back to updating v itself, as in Algorithm 1, which yields:

B[v] <- Sample(V, K) × {⟨∞, true⟩}

loop do
  for v ∈ V do
    old[v] <− all items in B[v] with a false flag
    new[v] <− ρK items in B[v] with a true flag
    Mark sampled items in B[v] as false;

  old′ <- Reverse(old), new′ <- Reverse(new)
  c <- 0 //update counter

  for v ∈ V do
    oldU[v] <− old[v] ∪ Sample(old′[v], ρK)
    newU[v] <− new[v] ∪ Sample(new′[v], ρK)

    for (u1 ∈ newU[v], u1 >= 0, u2 ∈ oldU[u1], u2 >= 0, u2 != v)
        or (u1 ∈ oldU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v)
        or (u1 ∈ newU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v) do
      l <− σ(v, u2)
      c <− c + UpdateNN(B[v], ⟨u2, l, true⟩)

  return B if c < δNK

In other words, we visit:

  1. all new neighbours of v and all of its old neighbours
  2. all old neighbours of v and all of its new neighbours
  3. all new neighbours of v and all of its new neighbours

The next step is to compute the reverse. As per above, we don't want to compute a very large reverse and then sample it. We want to sample it at the same time we reverse it. So I need to move

    old[v] <− old[v] ∪ Sample(old′[v], ρK)
    new[v] <− new[v] ∪ Sample(new′[v], ρK)

to its own loop, which yields:

B[v] ←− Sample(V, K) × {⟨∞, true⟩}
loop
  for v ∈ V do
    old[v] <− all items in B[v] with a false flag
    new[v] <− ρK items in B[v] with a true flag
    Mark sampled items in B[v] as false;

  for v ∈ V do
    oldU[v] <− old[v] ∪ reverseSample(v, old,ρK)
    newU[v] <− new[v] ∪ reverseSample(v, new, ρK)

  c <- 0 //update counter

  for v ∈ V do
    for (u1 ∈ newU[v], u1 >= 0, u2 ∈ oldU[u1], u2 >= 0, u2 != v)
        or (u1 ∈ oldU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v)
        or (u1 ∈ newU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v) do
      l <− σ(v, u2)
      c <− c + UpdateNN(B[v], ⟨u2, l, true⟩)

  return B if c < δNK

The last step is to realize that we don't need to build old[v] on the first loop and then reverse it. old[v] is simply a view of B, which we can compute as we reverse sample, so we skip building old[v] and build oldU[v] directly in the first loop:

B[v] ←− Sample(V, K) × {⟨∞, true⟩}
loop
  for v ∈ V do
    oldU[v] <− (all items in B[v] with a false flag) ∪ reverseSample(v, B, ρK)
    new[v] <− ρK items in B[v] with a true flag
    Mark sampled items in B[v] as false;

  for v ∈ V do
    newU[v] <− new[v] ∪ reverseSample(v, new, ρK)

  c <- 0 //update counter

  for v ∈ V do
    for (u1 ∈ newU[v], u1 >= 0, u2 ∈ oldU[u1], u2 >= 0, u2 != v)
        or (u1 ∈ oldU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v)
        or (u1 ∈ newU[v], u1 >= 0, u2 ∈ newU[u1], u2 >= 0, u2 != v) do
      l <− σ(v, u2)
      c <− c + UpdateNN(B[v], ⟨u2, l, true⟩)

  return B if c < δNK

josevalim commented 1 year ago

I have been thinking more about this and I believe we can actually implement trees efficiently as long as the tree is balanced, built at once, and not mutated. My understanding is that most KDTrees fulfill this property. The idea is to lay the tree flat in memory. A tree like this:

        0
      /   \
    1      2
  /  \    /  \
3     4  5    6

Would be stored as 0123456 in memory. If the tree is balanced, we can compute the height/depth upfront from the number of elements: it is the bit width (32) minus the number of leading zeros of the count, i.e. the position of the highest set bit:

def depth(count) do
  Nx.subtract(32, Nx.count_leading_zeros(Nx.u32(count)))
end

And from the depth, we allocate a tensor of size (2^depth)-1. In the worst-case scenario, we waste 2^(depth-1)-1 entries (i.e. the last layer has a single element). For example, 7 elements give depth 3 and a tensor of size 7, while 8 elements give depth 4 and a tensor of size 15.

Traversing the tree requires three arguments: the current depth, the current offset, and the position, which encodes the history of left/right turns:

def traverse(tensor, target, depth, offset, position) do
  value = tensor[offset + position]

  cond do
    value == target -> offset + position
    # target is smaller than the current value: go to the left child
    value > target -> traverse(tensor, target, depth + 1, offset + 2**depth, position <<< 1)
    # target is larger: go to the right child
    value < target -> traverse(tensor, target, depth + 1, offset + 2**depth, (position <<< 1) + 1)
  end
end

For example, imagine this tree:

        d
      /   \
    b      f
  /  \    /  \
a     c  e    g

And we want to get to e:

traverse(depth, offset, position)
traverse(0, 0, 0) #=> access d
traverse(1, 1, 1) #=> access f
traverse(2, 3, 2) #=> access e

This should allow us to build a KDTree and perhaps optimize kNN itself. Barnes-Hut uses vantage-point trees and quadtrees, which I am not sure would be suitable. However, EFANNA builds read-only balanced truncated randomized KDTrees (the truncated part means the last depth contains N entries), so it is a suitable option. For completeness, I am also looking into LSH-based kNN.

josevalim commented 1 year ago

Of course, in order to balance the trees we need to compute the median, and currently we don't have a fast implementation. The one we have is O(n log n), so we would need to solve that first.
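
For context, a sort-based median (which is where the O(n log n) comes from) would look roughly like this minimal sketch, assuming a 1-D input; the module name is illustrative:

defmodule MedianSketch do
  import Nx.Defn

  # Full sort, then average the two middle elements
  # (for odd n both indices coincide).
  defn median(t) do
    sorted = Nx.sort(t)
    n = Nx.axis_size(t, 0)
    (sorted[div(n - 1, 2)] + sorted[div(n, 2)]) / 2
  end
end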

josevalim commented 1 year ago

After scouring the internet for median algorithms we can trivially implement with Nx, I found LazySelect (section 12.2.1). Here is a numpy implementation. Improving our median would also have benefits for Affinity Propagation.

krstopro commented 1 year ago

@josevalim This is kinda like partition used for Quicksort, right?

josevalim commented 1 year ago

QuickSelect is the one that partitions but it is not very efficient on Nx. :(

krstopro commented 1 year ago

@josevalim I see. I am not exactly sure why this method is linear since it does

B = np.random.choice(A, selection_len)
B = np.sort(B)

and selection_len = int((length ** (3/4))), but lemme have a closer look.

krstopro commented 1 year ago

@josevalim Oh, it's actually a well known algorithm (that I wasn't familiar with!). Going through it now!

Update: the main difference is that it's randomised so we would need to include a random key, but I guess that's just a detail.

josevalim commented 1 year ago

@krstopro yes, you can see in Scholar how we generate a random key when necessary.

And I will try to implement KDTree today, using our current median, but we can benchmark other versions once ready. The only limitation with our KDTree is that it has to be done outside defn and it will require log2(N) compilations of the median algorithm.

msluszniak commented 1 year ago

Implementing KD-trees using a static array is actually a very good idea. I knew about such implementations for heaps, but I hadn't thought about it in terms of KD-trees.

krstopro commented 1 year ago

Implementing KD-trees using a static array is actually a very good idea. I knew about such implementations for heaps, but I hadn't thought about it in terms of KD-trees.

I think this is similar to what @josevalim wrote above, but they are not using medians for pivots.

I am looking into Locality Sensitive Hashing via Random Projections (e.g. the sklearn implementation). I might be able to vectorise it in Nx.
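
A rough sketch of the random-projection idea in Nx (module and function names are assumptions, not an existing API): hash each point to a bucket given by the sign pattern of its projections onto random hyperplanes.

defmodule LSHSketch do
  import Nx.Defn

  # x: {n, d} points; planes: {d, nbits} random hyperplane normals
  # (e.g. drawn from a standard normal). Returns {n} integer bucket ids.
  defn hash(x, planes) do
    bits = Nx.dot(x, planes) > 0
    weights = Nx.pow(2, Nx.iota({Nx.axis_size(planes, 1)}))
    Nx.sum(bits * weights, axes: [1])
  end
end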

krstopro commented 1 year ago

I went through GPU-friendly, Parallel, and (Almost-)In-Place Construction of Left-Balanced k-d Trees, and I think everything needed to implement it is available in Nx. The only thing I am not sure about is the bitwise operators it uses.

josevalim commented 1 year ago

Unfortunately we don't have the variadic/parallel sort, which is what he uses to sort by tags and data at once.

We can implement it for Nx, but it would only work on XLA and not Torchx. Also, XLA's variadic sort requires all inputs to have the same shape, which means we wouldn't be able to sort the tags and the data together as is. We would need to slice a dimension of the data, and sorting that dimension alone would not give the whole picture. So we would need to parallel sort the tags, the data, and an iota. The iota would give us the indices of the data, and we would then gather those indices to put the data in the new order, which I am not very sure is worth it. So I will proceed with a recursive version of this algorithm, so the tree is left-balanced, and we will see how it goes :D
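
For reference, the argsort-and-gather alternative described above might look like this minimal sketch (assuming tags is {n} and data is {n, d}; names are illustrative):

defmodule SortByTagsSketch do
  import Nx.Defn

  # Stand-in for a variadic sort: argsort the tags once, then gather both
  # the tags and the data rows into that order.
  defn sort_by_tags(tags, data) do
    order = Nx.argsort(tags)
    {Nx.take(tags, order), Nx.take(data, order, axis: 0)}
  end
end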

josevalim commented 1 year ago

To keep track, here are all of the things we could still implement (not yet ruled out):

And this may unblock us to implement UMAP, TriMap, and PaCMAP.

krstopro commented 1 year ago

Unfortunately we don't have the variadic/parallel sort, which is what he uses to sort by tags and data at once.

@josevalim I don't think it's needed. If points[.., dim] are in the $[0, 1)$ interval (1 excluded), then sorting with

bool less(int idx_a, int idx_b, int l) {
  int dim = l % k;
  return
    (tags[idx_a] < tags[idx_b])
    || ((tags[idx_a] == tags[idx_b])
        && (points[idx_a][dim] < points[idx_b][dim]));
}

is equivalent to argsorting tags + points[:, dim], right? For general points[.., dim] we can sort (max(points[.., dim]) + 1) * tags + points[:, dim]. I might be wrong though.

josevalim commented 1 year ago

I thought about that, but we would have corner cases in all scenarios. There is nothing stopping the maximum value for s64 from being in the tensor. :(

krstopro commented 1 year ago

I thought about that, but we would have corner cases in all scenarios. There is nothing stopping the maximum value for s64 from being in the tensor. :(

Correct, but we can always preprocess the points by normalising them (this is common for t-SNE, I think). Or, equivalently, sort tags + points[:, dim] / (max(points[.., dim]) + 1). Would this work? Update: No, it wouldn't. :(

krstopro commented 1 year ago

Anyway, I am looking into LSH at the moment. Will see what I can do.

josevalim commented 1 year ago

@krstopro I guess we could have two functions for building KDTrees. One would expect the amplitude of the tensor. For example, if a tensor goes between -1 and 1, an amplitude of 2 would be enough (and I would start the tags from 1 for convenience).

krstopro commented 1 year ago

@krstopro I guess we could have two functions for building KDTrees. One would expect the amplitude of the tensor. For example, if a tensor goes between -1 and 1, an amplitude of 2 would be enough (and I would start the tags from 1 for convenience).

@josevalim That sounds like a solution to me.

josevalim commented 1 year ago

This notebook implements KDTree: https://gist.github.com/josevalim/555330a5cd4347ffe54b180be5b4a5c5

There are some TODOs but it requires only some quick polishing. I will optimize and test it tomorrow. I will try to implement the amplified version tomorrow as well.

josevalim commented 1 year ago

I will open up a new issue with the exploration points of this one.

josevalim commented 1 year ago

See #207.