lmcinnes / pynndescent

A Python nearest neighbor descent for approximate nearest neighbors
BSD 2-Clause "Simplified" License

Speed-up Indexing with custom metrics #245

Open · akurniawan opened this issue 1 month ago

akurniawan commented 1 month ago

Hi, thanks for the great package! I'm currently trying to build a kNN index with a custom DTW metric, but building the index for just 90k data points takes around 5 hours. Do you have any suggestions for speeding this up?

Below is the code I used to build the index and perform a query:

import numpy as np
from numba import jit

@jit(nopython=True, fastmath=True)
def dtw_numba(x, y):
    """
    Compute the Dynamic Time Warping (DTW) distance between two sequences.

    Parameters:
    x : array-like
        First sequence.
    y : array-like
        Second sequence.

    Returns:
    float
        The DTW distance between sequences x and y.
    """
    n, m = len(x), len(y)
    dtw_matrix = np.full((n + 1, m + 1), np.inf)
    dtw_matrix[0, 0] = 0

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = (x[i - 1] - y[j - 1]) ** 2
            dtw_matrix[i, j] = cost + min(dtw_matrix[i - 1, j],    # insertion
                                          dtw_matrix[i, j - 1],    # deletion
                                          dtw_matrix[i - 1, j - 1]) # match

    return np.sqrt(dtw_matrix[n, m])

import pynndescent

index = pynndescent.NNDescent(flat_inj_vecs, metric=dtw_numba)

index.query([flat_vecs[10]], k=100)

Thank you!

lmcinnes commented 1 month ago

I think the real catch here is that DTW is a pretty expensive metric to compute, so you are going to face a bit of an uphill battle no matter what you do. One good option might be ANNchor, an ANN approach designed specifically for expensive metrics (essentially by finding cheap approximations of the metric tailored to the training data); it would probably be a good fit for your needs.
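
For reference, a minimal, untested sketch of what that might look like with the dtw_numba metric above, based on the usage shown in the ANNchor README (https://github.com/gchq/annchor). Passing a callable as the metric, the parameter names (n_anchors, n_neighbors, p_work), and the illustrative values are all taken from that README and should be checked against the current ANNchor documentation:

from annchor import Annchor

# flat_inj_vecs and dtw_numba are the data array and numba-compiled
# metric from the snippet earlier in this thread.
ann = Annchor(
    flat_inj_vecs,
    dtw_numba,        # expensive custom metric that ANNchor will approximate
    n_anchors=25,     # number of anchor points used to build cheap distance bounds
    n_neighbors=100,  # size of the k-NN graph to construct
    p_work=0.1,       # budget: fraction of the brute-force pairwise distance calls
)
ann.fit()

# Assumed (as with pynndescent's neighbor_graph) to be an
# (indices, distances) pair describing the approximate k-NN graph.
neighbor_indices, neighbor_distances = ann.neighbor_graph

The p_work knob is the key trade-off here: it caps how many exact DTW evaluations ANNchor is allowed to perform, so lowering it trades accuracy for speed.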