lmcinnes / umap

Uniform Manifold Approximation and Projection
BSD 3-Clause "New" or "Revised" License

More efficient approach to sort the distance matrix in transform of UMAP #114

Closed johannfaouzi closed 6 years ago

johannfaouzi commented 6 years ago

Introduction

As you pointed out in your code, the distance matrix is sorted twice: https://github.com/lmcinnes/umap/blob/d286224a86dfe730c76f082c6f523b69b57bd0f8/umap/umap_.py#L1557-L1564

Since np.argsort() already returns, for each row, the indices that sort it along the last axis, these indices can be used to index the distance matrix directly. Indices for the first axis are also needed, but they are trivial: a matrix M where M[i, j] = i.

Test

import numpy as np
import numba
from sklearn.metrics import pairwise_distances

dmat = np.arange(100 * 40).reshape(100, 40)
indices = np.argsort(dmat)

dists = np.sort(dmat)

indices_row = np.zeros((100, 40), dtype=np.int64)
for i in range(100):
    for j in range(40):
        indices_row[i, j] = i
dists_new = dmat[indices_row, indices]

np.all(dists == dists_new)

which returns True
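Note that the `np.arange` matrix above is already sorted along each row, so `np.argsort` returns the identity permutation and the check passes trivially. A slightly stronger check uses random values. The sketch below also swaps the explicit `M[i, j] = i` matrix for a broadcast `(n, 1)` row index: NumPy broadcasts it against the `(n, m)` column indices, so the full row-index matrix never has to be materialized (this variant was not benchmarked in this thread).

```python
import numpy as np

rng = np.random.RandomState(0)
# Random values: rows are not pre-sorted, so argsort is a non-trivial permutation.
dmat = rng.rand(100, 40)

indices = np.argsort(dmat)  # sorts along the last axis by default
dists = np.sort(dmat)

# Row index of shape (100, 1); broadcasting it against `indices` (100, 40)
# has the same effect as the explicit matrix M[i, j] = i.
rows = np.arange(dmat.shape[0])[:, None]
dists_new = dmat[rows, indices]

print(np.array_equal(dists, dists_new))  # True
```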

Possible solutions

I had several ideas to create the indices for the first axis:

def method_one(dmat, indices_col):
    # Row-index matrix M[i, j] = i, built by repeating each row index
    indices_row = np.repeat(np.arange(dmat.shape[0]), dmat.shape[1]).reshape(dmat.shape)
    return dmat[indices_row, indices_col]

def method_two(dmat, indices_col):
    # Same matrix, built by broadcasting a column of row indices against a row of ones
    indices_row = np.multiply(np.ones(dmat.shape[1], dtype=np.int64),
                              np.arange(dmat.shape[0], dtype=np.int64)[:, None])
    return dmat[indices_row, indices_col]

@numba.njit(parallel=True)
def numba_one(dmat, indices_col):
    n_samples_transform, n_samples_fit = dmat.shape
    res = np.zeros((n_samples_transform, n_samples_fit), dtype=np.int64)
    for i in numba.prange(n_samples_transform):
        for j in numba.prange(n_samples_fit):
            res[i, j] = i
    return res

@numba.njit(parallel=True)
def numba_two(dmat, indices_col):
    n_samples_transform, n_samples_fit = dmat.shape
    res = np.zeros((n_samples_transform, n_samples_fit), dtype=np.int64)
    for i in numba.prange(n_samples_transform):
        res[i] = i * np.ones(n_samples_fit, dtype=np.int64)
    return res

@numba.njit(parallel=True)
def numba_three(dmat, indices_col):
    n_samples_transform, n_samples_fit = dmat.shape
    res = np.zeros((n_samples_transform, n_samples_fit), dtype=np.int64)
    for i in numba.prange(n_samples_fit):
        res[:, i] = np.arange(n_samples_transform)
    return res
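Two further pure-NumPy ways to build the same row-index matrix, sketched here as untested alternatives (they are not part of the benchmarks below; the helper names are hypothetical). `np.broadcast_to` returns a read-only view, so it does not allocate `n_samples_transform * n_samples_fit` integers, while `np.indices` returns the full grid of row and column indices.

```python
import numpy as np

def rows_broadcast(dmat):
    # Read-only view of shape dmat.shape with rows[i, j] == i; no full copy is made.
    return np.broadcast_to(np.arange(dmat.shape[0])[:, None], dmat.shape)

def rows_indices(dmat):
    # np.indices returns a (2, n_rows, n_cols) array; element 0 holds the row indices.
    return np.indices(dmat.shape)[0]

dmat = np.random.RandomState(1).rand(5, 3)
indices = np.argsort(dmat)
reference = np.repeat(np.arange(5), 3).reshape(5, 3)  # same construction as method_one

print(np.array_equal(rows_broadcast(dmat), reference))  # True
print(np.array_equal(rows_indices(dmat), reference))    # True
print(np.array_equal(dmat[rows_broadcast(dmat), indices], np.sort(dmat)))  # True
```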

Performance evaluation

I tested the performance of these functions in several settings.

Small sample size (n_samples_fit >> n_samples_transform)

rng = np.random.RandomState(123)

X_fit = rng.randn(400, 32)
X_transform = rng.randn(20, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")
indices = np.argsort(dmat)

%timeit -n 10000 np.sort(dmat)
312 µs ± 5.23 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit -n 10000 method_one(dmat, indices)
77 µs ± 522 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit -n 10000 method_two(dmat, indices)
59.5 µs ± 807 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit -n 10000
indices_row = numba_one(dmat, indices)
dmat[indices_row, indices]
121 µs ± 34.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit -n 10000 
indices_row = numba_two(dmat, indices)
dmat[indices_row, indices]
123 µs ± 28.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit -n 10000 
indices_row = numba_three(dmat, indices)
dmat[indices_row, indices]
170 µs ± 32.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Small sample size (n_samples_fit << n_samples_transform)

rng = np.random.RandomState(123)

X_fit = rng.randn(20, 32)
X_transform = rng.randn(400, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")
indices = np.argsort(dmat)

%timeit -n 10000 np.sort(dmat)
139 µs ± 1.31 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit -n 10000 method_one(dmat, indices)
79 µs ± 789 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit -n 10000 method_two(dmat, indices)
62.8 µs ± 1.18 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit -n 10000
indices_row = numba_one(dmat, indices)
dmat[indices_row, indices]
113 µs ± 8.92 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit -n 10000
indices_row = numba_two(dmat, indices)
dmat[indices_row, indices]
199 µs ± 5.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit -n 10000
indices_row = numba_three(dmat, indices)
dmat[indices_row, indices]
140 µs ± 6.43 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Large sample size (n_samples_fit >> n_samples_transform)

rng = np.random.RandomState(123)

X_fit = rng.randn(4000, 32)
X_transform = rng.randn(600, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")
indices = np.argsort(dmat)

%timeit -n 5 np.sort(dmat)
136 ms ± 1.71 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%timeit -n 5 method_one(dmat, indices)
37.7 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%timeit -n 5 method_two(dmat, indices)
32.3 ms ± 1.37 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%%timeit -n 5
indices_row = numba_one(dmat, indices)
dmat[indices_row, indices]
36.2 ms ± 3.82 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%%timeit -n 5
indices_row = numba_two(dmat, indices)
dmat[indices_row, indices]
35.6 ms ± 2.46 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%%timeit -n 5
indices_row = numba_three(dmat, indices)
dmat[indices_row, indices]
34.9 ms ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

Large sample size (n_samples_fit ~ n_samples_transform)

rng = np.random.RandomState(123)

X_fit = rng.randn(4000, 32)
X_transform = rng.randn(6000, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")
indices = np.argsort(dmat)

%timeit -n 5 np.sort(dmat)
1.41 s ± 6.46 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%timeit -n 5 method_one(dmat, indices)
492 ms ± 7.74 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%timeit -n 5 method_two(dmat, indices)
427 ms ± 5.2 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%%timeit -n 5
indices_row = numba_one(dmat, indices)
dmat[indices_row, indices]
434 ms ± 36.6 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%%timeit -n 5
indices_row = numba_two(dmat, indices)
dmat[indices_row, indices]
427 ms ± 32.1 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%%timeit -n 5
indices_row = numba_three(dmat, indices)
dmat[indices_row, indices]
597 ms ± 33.7 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

Large sample size (n_samples_fit << n_samples_transform)

rng = np.random.RandomState(123)

X_fit = rng.randn(4000, 32)
X_transform = rng.randn(60000, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")
indices = np.argsort(dmat)

%timeit -n 1 np.sort(dmat)
14.5 s ± 661 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit -n 1 method_one(dmat, indices)
6.8 s ± 2.77 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit -n 1 method_two(dmat, indices)
4.45 s ± 237 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit -n 1
indices_row = numba_one(dmat, indices)
dmat[indices_row, indices]
4.29 s ± 184 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit -n 1
indices_row = numba_two(dmat, indices)
dmat[indices_row, indices]
4.15 s ± 413 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit -n 1
indices_row = numba_three(dmat, indices)
dmat[indices_row, indices]
8.18 s ± 460 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Conclusion

It looks like method_two and numba_one are slightly better. I expected a larger performance gain, but in the last setting most of the time is spent indexing the big array. The results are a bit underwhelming. Using a mask would probably speed up the indexing part, but I don't know whether creating the mask would cost more time. I'm going to try more things.

I kept n_samples_fit below 4096 because this piece of code only runs when self._small_data is True. I don't know whether there is a similar issue otherwise (I didn't look at deheap_sort).

johannfaouzi commented 6 years ago

Doing everything in a single numba-jitted function seems to be a bit better:

@numba.njit(parallel=True)
def numba_four(dmat, indices_col):
    n_samples_transform, n_samples_fit = dmat.shape
    res = np.zeros((n_samples_transform, n_samples_fit), dtype=dmat.dtype)
    for i in numba.prange(n_samples_transform):
        for j in numba.prange(n_samples_fit):
            res[i, j] = dmat[i, indices_col[i, j]]
    return res
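For reference, NumPy 1.15 and later also provide `np.take_along_axis`, which performs the same per-row gather as `numba_four` (`res[i, j] = dmat[i, indices_col[i, j]]`) without an explicit row-index matrix. It was not included in the benchmarks in this thread, so its relative speed is unknown here; the wrapper name below is hypothetical.

```python
import numpy as np

def take_along_rows(dmat, indices_col):
    # res[i, j] = dmat[i, indices_col[i, j]], the same contract as numba_four
    return np.take_along_axis(dmat, indices_col, axis=1)

rng = np.random.RandomState(123)
dmat = rng.rand(20, 400)
indices = np.argsort(dmat)
res = take_along_rows(dmat, indices)
print(np.array_equal(res, np.sort(dmat)))  # True
```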

Which gives the following results (the timings of the other functions are the same as above):

Small sample size (n_samples_fit >> n_samples_transform)

rng = np.random.RandomState(123)

X_fit = rng.randn(400, 32)
X_transform = rng.randn(20, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")
indices = np.argsort(dmat)

%timeit -n 10000 numba_four(dmat, indices)
72.4 µs ± 22.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Small sample size (n_samples_fit << n_samples_transform)

rng = np.random.RandomState(123)

X_fit = rng.randn(20, 32)
X_transform = rng.randn(400, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")
indices = np.argsort(dmat)

%timeit -n 10000 numba_four(dmat, indices)
76.3 µs ± 4.31 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Large sample size (n_samples_fit >> n_samples_transform)

rng = np.random.RandomState(123)

X_fit = rng.randn(4000, 32)
X_transform = rng.randn(600, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")
indices = np.argsort(dmat)

%timeit -n 5 numba_four(dmat, indices)
14.5 ms ± 1.68 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

Large sample size (n_samples_fit ~ n_samples_transform)

rng = np.random.RandomState(123)

X_fit = rng.randn(4000, 32)
X_transform = rng.randn(6000, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")
indices = np.argsort(dmat)

%timeit -n 5 numba_four(dmat, indices)
145 ms ± 11.5 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

Large sample size (n_samples_fit << n_samples_transform)

rng = np.random.RandomState(123)

X_fit = rng.randn(4000, 32)
X_transform = rng.randn(60000, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")
indices = np.argsort(dmat)

%timeit -n 1 numba_four(dmat, indices)
1.33 s ± 59.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

lmcinnes commented 6 years ago

Thanks, this looks like a thorough analysis. Realistically I mostly expect n_samples_fit >> n_samples_transform, so it's probably best to focus on that case. In general, option numba_four seems like a pretty clear winner. If you would like to submit a PR for that I would be happy to accept it. In fact, feel free to submit PRs in general: you seem to diagnose the issues very well and often have fixes ready. I appreciate all the work you've been putting in, so you should get credit for the fixes as well!

johannfaouzi commented 6 years ago

Usually I prefer to wait for your reply first, for several reasons:

I feel it's better to submit a PR afterwards, so that there's no confusion. Or maybe it should be done after the PR is first submitted; I don't know what the proper way to do this is.

For this issue, I have several questions about the fix:

johannfaouzi commented 6 years ago

There is also something else I would like to point out. After sorting the distance matrix, you only keep the first self.n_neighbors values: https://github.com/lmcinnes/umap/blob/d286224a86dfe730c76f082c6f523b69b57bd0f8/umap/umap_.py#L1563-L1564 Instead of sorting the whole array, it may be worthwhile in some scenarios to use numpy.argpartition first and then sort only the selected columns along the last axis. This should pay off when self.n_neighbors is small (i.e. much smaller than the number of samples in X_fit).

See this post on StackOverflow too.
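A pure-NumPy sketch of the idea, assuming `n_neighbors <= n_samples_fit`; the helper name `knn_from_dmat` is hypothetical. `np.argpartition(a, kth)` only guarantees that the element at position `kth` is in its sorted place and that nothing before it is larger, so `kth = n_neighbors - 1` is enough to bring the `n_neighbors` smallest distances to the front of each row (in arbitrary order). Only those columns are then sorted.

```python
import numpy as np

def knn_from_dmat(dmat, n_neighbors):
    # Hypothetical helper: indices and distances of the n_neighbors smallest
    # entries of each row, ordered by increasing distance.
    # Partition: the first n_neighbors columns hold the smallest values, unordered.
    part = np.argpartition(dmat, n_neighbors - 1, axis=1)[:, :n_neighbors]
    part_dists = np.take_along_axis(dmat, part, axis=1)
    # Sort only the n_neighbors selected columns of each row.
    order = np.argsort(part_dists, axis=1)
    indices = np.take_along_axis(part, order, axis=1)
    dists = np.take_along_axis(part_dists, order, axis=1)
    return indices, dists

rng = np.random.RandomState(123)
dmat = rng.rand(20, 400)
indices, dists = knn_from_dmat(dmat, 4)
print(np.array_equal(indices, np.argsort(dmat)[:, :4]))  # True
print(np.array_equal(dists, np.sort(dmat)[:, :4]))       # True
```

Using `n_neighbors - 1` rather than `n_neighbors` as `kth` also keeps the call valid when `n_neighbors == n_samples_fit`.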

johannfaouzi commented 6 years ago

I came up with this:

# Current solution
def current(dmat, n_neighbors):
    indices = np.argsort(dmat) 
    dists = np.sort(dmat)
    indices = indices[:, :n_neighbors] 
    dists = dists[:, :n_neighbors]
    return indices, dists

# Former proposed solution
@numba.njit(parallel=True)
def former_sorting_matrix(dmat, indices_col):
    n_samples_transform, n_samples_fit = dmat.shape
    res = np.zeros((n_samples_transform, n_samples_fit), dtype=dmat.dtype)
    for i in numba.prange(n_samples_transform):
        for j in numba.prange(n_samples_fit):
            res[i, j] = dmat[i, indices_col[i, j]]
    return res

def former_proposed_solution(dmat, n_neighbors):
    indices = np.argsort(dmat)
    dists = former_sorting_matrix(dmat, indices)
    indices = indices[:, :n_neighbors] 
    dists = dists[:, :n_neighbors]
    return indices, dists

# Proposed solution
@numba.njit(parallel=True)
def sorting_matrix(dmat, indices_col, n_neighbors):
    n_samples_transform, n_samples_fit = dmat.shape
    res = np.zeros((n_samples_transform, n_neighbors), dtype=dmat.dtype)
    for i in numba.prange(n_samples_transform):
        for j in numba.prange(n_neighbors):
            res[i, j] = dmat[i, indices_col[i, j]]
    return res

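def proposed_solution(dmat, n_neighbors):
    # Move the n_neighbors smallest distances of each row to the front (unordered).
    # kth = n_neighbors - 1 is sufficient and stays valid when n_neighbors == n_samples_fit.
    indices = np.argpartition(dmat, n_neighbors - 1)[:, :n_neighbors]
    # Gather the corresponding distances, then sort only these n_neighbors columns.
    dmat_shortened = sorting_matrix(dmat, indices, n_neighbors)
    indices_sorted = np.argsort(dmat_shortened)
    dists = sorting_matrix(dmat_shortened, indices_sorted, n_neighbors)
    # Reorder the original column indices consistently with the sorted distances.
    indices_final = sorting_matrix(indices, indices_sorted, n_neighbors)
    return indices_final, dists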

I tested the three approaches with n_neighbors = 4 in different settings. One substantial difference from the previous benchmarks is that the cost of computing the sorted indices is now included in the timings. With more neighbors, the gain should be smaller (or even negative if n_neighbors is close to n_samples_fit). I assume that n_neighbors << n_samples_fit is the most common scenario.
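To check how the gain shrinks as n_neighbors grows, a small timing harness along these lines could be used. It is a sketch on an illustrative random matrix (the sizes are assumptions, not the settings used in this thread), and it times distances only, using `np.partition` followed by a sort of the first k columns versus a full `np.sort`. No timings are claimed here since they depend on the machine.

```python
import timeit
import numpy as np

rng = np.random.RandomState(123)
dmat = rng.rand(200, 2000)  # illustrative size: 200 transform rows, 2000 fit samples

def full_sort(k):
    # Sort every row completely, then keep the first k columns.
    return np.sort(dmat, axis=1)[:, :k]

def partial_sort(k):
    # Partition so the k smallest values come first, then sort only those k columns.
    return np.sort(np.partition(dmat, k - 1, axis=1)[:, :k], axis=1)

for k in (4, 64, 512, 2000):
    assert np.array_equal(full_sort(k), partial_sort(k))
    t_full = timeit.timeit(lambda: full_sort(k), number=5)
    t_part = timeit.timeit(lambda: partial_sort(k), number=5)
    print(f"n_neighbors={k:5d}  full sort {t_full:.4f} s  partition+sort {t_part:.4f} s")
```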

Small sample size (n_samples_fit >> n_samples_transform)

rng = np.random.RandomState(123)

n_neighbors = 4

X_fit = rng.randn(400, 32)
X_transform = rng.randn(20, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")

indices_current, dists_current = current(dmat, n_neighbors)
indices_old, dists_old = former_proposed_solution(dmat, n_neighbors)
indices_new, dists_new = proposed_solution(dmat, n_neighbors)

print(np.all(indices_current == indices_old))
True
print(np.allclose(dists_current, dists_old))
True
print(np.all(indices_old == indices_new))
True
print(np.allclose(dists_old, dists_new))
True

%timeit -n 100 current(dmat, n_neighbors)
694 µs ± 24.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit -n 100 former_proposed_solution(dmat, n_neighbors)
422 µs ± 25.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit -n 100 proposed_solution(dmat, n_neighbors)
312 µs ± 71.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Large sample size (n_samples_fit >> n_samples_transform)

rng = np.random.RandomState(123)

n_neighbors = 4

X_fit = rng.randn(4000, 32)
X_transform = rng.randn(600, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")

%timeit -n 1 current(dmat, n_neighbors)
289 ms ± 4.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit -n 1 former_proposed_solution(dmat, n_neighbors)
168 ms ± 3.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit -n 1 proposed_solution(dmat, n_neighbors)
35.9 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Large sample size (n_samples_fit ~ n_samples_transform)

rng = np.random.RandomState(123)

n_neighbors = 4

X_fit = rng.randn(4000, 32)
X_transform = rng.randn(6000, 32)

dmat = pairwise_distances(X_transform, X_fit, metric="euclidean")

%timeit -n 1 current(dmat, n_neighbors)
3.02 s ± 45.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit -n 1 former_proposed_solution(dmat, n_neighbors)
1.73 s ± 34.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit -n 1 proposed_solution(dmat, n_neighbors)
440 ms ± 19.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

lmcinnes commented 6 years ago

Looks good to me. I believe I had a slight aversion to argpartition because it caused issues in an earlier project, but I now recall that the problem was that it was not supported by older versions of numpy (which the scikit-learn version of the time claimed to support). The current umap requirements are for a very recent numpy, so all should be fine; I don't mind not supporting old code in this case.

As for where to put it: utils.py seems the sensible place, since it's short and generic enough that it may get used again elsewhere. Please put together a PR whenever you are ready.

lmcinnes commented 6 years ago

Merged. Thanks!