JuliaNeighbors / HNSW.jl

Approximate Nearest Neighbor Searches using the HNSW algorithm
MIT License

HNSW failing to find the nearest neighbor #26

Open tverho opened 3 years ago

tverho commented 3 years ago

I'm trying to replace my regular grid search with HNSW, but HNSW seems to fail rather spectacularly at finding the nearest neighbor in some cases. I understand that it's an approximate method, but it's underwhelming that when I ask for 50 nearest neighbors, the closest of them can be 3x farther away than the actual nearest neighbor. What I'd really want is to get, say, 10 neighbors such that I can be fairly certain that at least ~5 of the actual nearest neighbors are included.

Am I doing something wrong or is HNSW the wrong method for my need?

Below is a minimal example with my data (data file attached; the points are nodes of a surface mesh). I've tried playing with the parameters of HierarchicalNSW, but they don't seem to have much effect.

using DelimitedFiles
using HNSW
using Distances
using Formatting
using Random

""" Get grid index for coordinates """
grid_bin(coord, grid, cell) = trunc.(Int, mod.(coord./cell, 1) .* size(grid)) .+ [1,1,1]

""" Get point indices in the given and surrounding bins, with periodic boundaries """
function get_surrounding(bin::Array{Int, 1}, grid)
    idxs = Vector{Int}[]
    for i in -1:1, j in -1:1, k in -1:1
        b = bin .+ [i,j,k]
        b = mod1.(b, size(grid))
        push!(idxs, grid[b...])
    end
    return vcat(idxs...)
end

Random.seed!(1)

data = readdlm("points.csv")
points = collect(eachrow(data))
cell = [200., 200., 400.]

# Using normal Euclidean metric doesn't remove the problem
metric = PeriodicEuclidean(cell)

# HNSW init
hnsw = HierarchicalNSW(points, metric=metric)
add_to_graph!(hnsw)

# Grid search init
maxh = prod(cell)^(1/3) * 0.05
nx, ny, nz = trunc.(Int, cell*2/maxh) 
grid = reshape([Int[] for i in 1:nx*ny*nz], (nx,ny,nz))
for i in eachindex(points)
    bin = grid_bin(points[i], grid, cell)
    push!(grid[bin...], i)
end

maxiter = 500_000
for iter in 1:maxiter
    point = rand(3) .* cell

    # k=10 gives much worse results, k=100 gives somewhat better results
    k = 50
    idxs, dists = knn_search(hnsw, point, k)
    mindist = minimum(dists)

    bin = grid_bin(point, grid, cell)
    idxs2 = get_surrounding(bin, grid)
    dists2 = [metric(point, points[i]) for i in idxs2]
    # Grid search won't find neighbors if they are far away
    mindist2 = isempty(dists2) ? Inf : minimum(dists2)

    if mindist > mindist2 + 1e-10
        printfmtln("HNSW dist {:.3f}, grid search dist {:.3f}", mindist, mindist2)
    end
end

Output:

HNSW dist 25.863, grid search dist 16.112
HNSW dist 27.483, grid search dist 15.636
HNSW dist 22.306, grid search dist 18.034
HNSW dist 26.240, grid search dist 17.305
HNSW dist 28.344, grid search dist 14.457
HNSW dist 28.016, grid search dist 15.871
HNSW dist 27.789, grid search dist 15.590
HNSW dist 43.979, grid search dist 15.961
HNSW dist 24.301, grid search dist 15.381
HNSW dist 27.855, grid search dist 15.560
HNSW dist 27.523, grid search dist 15.758
HNSW dist 28.050, grid search dist 15.398
HNSW dist 24.899, grid search dist 17.999
HNSW dist 29.895, grid search dist 16.212
HNSW dist 25.515, grid search dist 14.696
HNSW dist 27.586, grid search dist 15.141
HNSW dist 25.312, grid search dist 15.652
HNSW dist 22.540, grid search dist 17.517
HNSW dist 23.867, grid search dist 18.954

points.csv
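For a ground-truth comparison that doesn't depend on the grid (which misses neighbors outside the surrounding bins), a brute-force linear scan is a useful cross-check. Here is a minimal, dependency-free sketch; the names `periodic_dist` and `exact_knn` are illustrative, not part of HNSW.jl:

```julia
using LinearAlgebra
using Random

# Per-coordinate wrapped Euclidean distance; same idea as
# Distances.PeriodicEuclidean but dependency-free for this sketch.
# Assumes both points already lie inside the cell.
periodic_dist(a, b, cell) = norm(min.(abs.(a .- b), cell .- abs.(a .- b)))

# Exact k-NN by linear scan: O(n) per query, but gives ground truth
# against which HNSW recall can be measured.
function exact_knn(points, query, k, cell)
    dists = [periodic_dist(query, p, cell) for p in points]
    idxs = partialsortperm(dists, 1:k)
    return idxs, dists[idxs]
end

Random.seed!(1)
cell = [200.0, 200.0, 400.0]
pts = [rand(3) .* cell for _ in 1:1_000]
query = rand(3) .* cell
idxs, dists = exact_knn(pts, query, 5, cell)
```

Comparing `minimum(dists)` from this scan against the HNSW result removes any doubt about whether the grid search itself is at fault.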

zgornel commented 3 years ago

Have you tried varying the values for M and ef? Edit: just saw that you already did.

How does classic kNN perform on the dataset? HNSW.jl has a compatibility requirement in its tests, i.e. the result order should match.

tverho commented 3 years ago

I modified the comparison test case found in the tests folder to use my data set. The k=1 tests fail, but the k=10 and k=20 tests pass, meaning that on average, 90% of the actual nearest neighbors are found. However, when I add a condition checking that at least one actual NN is found for every query, it fails. The same condition passes in the original test with randomized data.

I guess the challenge in my data set is that the points are not uniformly distributed, but there are large empty areas.
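For readers without the attached points.csv, a hypothetical stand-in with a similar character (points on a closed surface inside the box, leaving large empty regions) can be generated like this; `sphere_point` and the chosen radius are illustrative, not the actual mesh:

```julia
using LinearAlgebra
using Random

# Hypothetical stand-in for points.csv: points sampled on a spherical
# shell inside the box. Like a surface mesh, this leaves large voids,
# which is the suspected difficulty for HNSW here.
Random.seed!(2)
cell = [200.0, 200.0, 400.0]
center = cell ./ 2
radius = 80.0

# Uniform direction via a normalized Gaussian vector, scaled to the shell.
function sphere_point(center, radius)
    v = randn(3)
    return center .+ radius .* (v ./ norm(v))
end

surface_points = [sphere_point(center, radius) for _ in 1:10_000]
```

Queries drawn uniformly from the box then land mostly in empty space, far from all data points, which is exactly the regime where the failures appear above.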

Here's the test case

using NearestNeighbors
using HNSW
using StaticArrays
using Statistics
using DelimitedFiles
using Test

@testset "Compare To NearestNeighbors.jl with nonuniform data" begin
    dim = 3
    num_queries = 1000
    data = readdlm("points.csv")
    data = [SVector{dim}(data[i,:]) for i in 1:size(data,1)]
    cell = Float32[200, 200, 400]
    num_elements = length(data)
    tree = KDTree(data)
    queries = [SVector{dim}(rand(Float32, dim).*cell) for n ∈ 1:num_queries]
    @testset "M=$M, K=1" for M ∈ [5, 10]
        k = 1
        efConstruction = 20
        ef = 20
        realidxs, realdists = knn(tree, queries, k)

        hnsw = HierarchicalNSW(data; efConstruction=efConstruction, M=M, ef=ef)
        add_to_graph!(hnsw)
        idxs, dists = knn_search(hnsw, queries, k)

        ratio = mean(map(idxs, realidxs) do i,j
                        length(i ∩ j) / k
                     end)
        @test ratio > 0.99
    end

    @testset "Large K, low M=$M" for M ∈ [5,10]
        efConstruction = 100
        ef = 50
        hnsw = HierarchicalNSW(data; efConstruction=efConstruction, M=M, ef=ef)
        add_to_graph!(hnsw)
        @testset "K=$K" for K ∈ [10,20]
            realidxs, realdists = knn(tree, queries, K)
            idxs, dists = knn_search(hnsw, queries, K)

            ratios = map(idxs, realidxs) do i,j
                            length(i ∩ j) / K
                         end
            ratio = mean(ratios)
            @test ratio > 0.9
            @test all(ratios .> 0.0)
        end
    end
    @testset "Low Recall Test" begin
        k = 1
        efConstruction = 20
        M = 5
        hnsw = HierarchicalNSW(data; efConstruction=efConstruction, M=M)
        check_counter = 0
        add_to_graph!(hnsw) do i
            check_counter += i
        end
        @test check_counter == (1 + num_elements) * num_elements ÷ 2

        set_ef!(hnsw, 2)
        realidxs, realdists = knn(tree, queries, k)
        idxs, dists = knn_search(hnsw, queries, k)

        recall = mean(map(idxs, realidxs) do i,j
                        length(i ∩ j) / k
                     end)
        @test recall > 0.6
    end
end

deahhh commented 1 year ago

I replaced PeriodicEuclidean(cell) with Euclidean() and nothing was printed (no mismatches). Actually, PeriodicEuclidean is not a true distance.
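One quick way to probe that last claim is to sample random triples and count triangle-inequality violations for a dependency-free wrapped distance; `wrapdist` below is an illustrative stand-in for PeriodicEuclidean, and this is only a sampled sanity check, not a proof:

```julia
using LinearAlgebra
using Random

# Per-coordinate wrapped distance (same idea as Distances.PeriodicEuclidean).
wrapdist(a, b, cell) = norm(min.(abs.(a .- b), cell .- abs.(a .- b)))

Random.seed!(3)
cell = [200.0, 200.0, 400.0]

# Count triangle-inequality violations over random triples inside the cell,
# with a small tolerance for floating-point rounding.
violations = count(1:100_000) do _
    x, y, z = (rand(3) .* cell for _ in 1:3)
    wrapdist(x, z, cell) > wrapdist(x, y, cell) + wrapdist(y, z, cell) + 1e-9
end
println("triangle-inequality violations: ", violations)
```

On the standard torus distance no violations are expected, which would suggest the misses come from the graph search (e.g. disconnected regions in sparse areas) rather than from the metric itself.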