microsoft / DiskANN

Graph-structured Indices for Scalable, Fast, Fresh and Filtered Approximate Nearest Neighbor Search

[BUG] Distance return for inner_product metric is not expected #539

Open wdongyu opened 4 months ago

wdongyu commented 4 months ago

Expected Behavior

When inner_product is selected as the metric, the index compares vectors by L2 distance internally, then correctly converts the result back to an inner_product distance before returning it.

Actual Behavior

When inner_product is selected as the metric, the index compares vectors by L2 distance internally after adding an extra dimension. Let U and V be two input vectors, and Q the query vector.

The distance between U and V is computed like this (where m is the max_norm over all data): [equation image]

and the distance between U and Q is computed like this (where m is the max_norm over all data, and m_q is the norm of the query): [equation image]
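
The equation images are not available in this text, so the following is only a reconstruction of what the two distances would look like. It assumes the usual MIPS-to-L2 reduction: each input vector is scaled by m and extended with one extra coordinate, and the query is scaled by m_q and extended with a zero. The primed symbols U', V', Q' are notation introduced here, not taken from the issue.

```latex
% Assumed preprocessing (not confirmed by the issue text):
%   input vector U -> U' = (U / m, \sqrt{1 - \|U\|^2 / m^2})
%   query vector Q -> Q' = (Q / m_q, 0)
\begin{align*}
\|U' - V'\|^2 &= 2 - \frac{2\, U \cdot V}{m^2}
  - 2 \sqrt{1 - \frac{\|U\|^2}{m^2}} \, \sqrt{1 - \frac{\|V\|^2}{m^2}}, \\
\|U' - Q'\|^2 &= 2 - \frac{2\, U \cdot Q}{m \, m_q}.
\end{align*}
```

Under that assumption every augmented vector has unit norm, which is where the constant 2 comes from in both lines.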

Error

There are two problems:

  1. To return the negative inner_product distance when constructing the graph, we should add twice the product of the extra dimensions for unnormalized data, like this: [equation image]

    For normalized data, the extra dimension is zero for all input data, so it can be ignored.

  2. To return the negative inner_product distance when querying, the code currently flips the sign with distances[i] = (-distances[i]). Instead, it should convert the distance with distances[i] = (distances[i] - 2) / 2, as follows: [equation image]
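
Both points can be checked numerically with plain NumPy. The sketch below is not DiskANN code: it assumes the preprocessing is "scale by m and append sqrt(1 - ||x||^2 / m^2)" for input vectors and "scale by m_q and append 0" for the query, and that the internal distance is the squared L2 distance of those augmented vectors. The final rescaling by m * m_q is also an assumption, used only to compare against the raw inner products.

```python
import numpy as np

# Data and query from the issue's example code.
base = np.array([[10, 0], [1, 0.1]], dtype=np.float64)
query = np.array([1, 1], dtype=np.float64)

# Assumed preprocessing: scale input vectors by the max norm m and append
# sqrt(1 - ||x||^2 / m^2); scale the query by its own norm m_q and append 0.
norms = np.linalg.norm(base, axis=1)
m = norms.max()
extra = np.sqrt(np.clip(1.0 - (norms / m) ** 2, 0.0, None))
base_aug = np.hstack([base / m, extra[:, None]])
m_q = np.linalg.norm(query)
query_aug = np.append(query / m_q, 0.0)

# Problem 1: squared L2 distance between two input vectors, plus the
# correction term 2 * extra_u * extra_v.
d_uv = np.sum((base_aug[0] - base_aug[1]) ** 2)
d_uv_corrected = d_uv + 2.0 * extra[0] * extra[1]
neg_ip_uv = -(base[0] @ base[1]) / m**2  # negative inner product after scaling by m

# Problem 2: squared L2 distance to the query and the two conversions.
l2_sq = np.sum((base_aug - query_aug) ** 2, axis=1)
flipped = -l2_sq                 # sign flip described as the current behavior
converted = (l2_sq - 2.0) / 2.0  # conversion proposed in the issue
neg_ip = -(base @ query)         # true negative inner products

print("input-to-input:", (d_uv_corrected - 2.0) / 2.0, "expected:", neg_ip_uv)
print("flipped:  ", flipped)
print("converted:", converted * m * m_q)  # rescaled by m * m_q (assumption)
print("expected: ", neg_ip)
```

With these assumptions, the converted and rescaled values agree with the true negative inner products, while the sign-flipped values, sorted ascending, rank [1, 0.1] ahead of [10, 0].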

Example Code

from tempfile import TemporaryDirectory

import diskannpy
import numpy as np

query = [1, 1]
search_space = [
    [10, 0],   # inner product with query: 10.0
    [1, 0.1],  # inner product with query: 1.1
]

with TemporaryDirectory() as tmpdir:
    diskannpy.build_disk_index(
        np.array(search_space, dtype=np.float32),
        "mips",
        tmpdir,
        75,
        60,
        1,
        1,
        2,
    )
    index = diskannpy.StaticDiskIndex(tmpdir, 2, 0)
    res = index.search(np.array(query, dtype=np.float32), 2, 2)
    assert res[0].shape == (2,)
    print(res)
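
For comparison with the printed result, a brute-force reference for the same two points can be computed directly. This is a minimal sketch using only NumPy; `brute_force_mips` is a hypothetical helper written for this illustration, not part of diskannpy, and it reports the negative inner product as the distance, which is the convention the issue argues for.

```python
import numpy as np


def brute_force_mips(data, query, k):
    """Return the ids and negative inner products of the top-k matches."""
    data = np.asarray(data, dtype=np.float32)
    query = np.asarray(query, dtype=np.float32)
    scores = data @ query  # inner product of every point with the query
    order = np.argsort(-scores, kind="stable")[:k]  # largest inner product first
    return order, -scores[order]


ids, dists = brute_force_mips([[10, 0], [1, 0.1]], [1, 1], k=2)
print(ids, dists)
```

Point 0 should come first because its inner product with the query (10.0) is larger than that of point 1 (1.1).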

Your Environment