facebookresearch / faiss

A library for efficient similarity search and clustering of dense vectors.
https://faiss.ai
MIT License

inquiry related to DistanceComputer #3415

Status: Open. anupsingh15 opened this issue 5 months ago.

anupsingh15 commented 5 months ago

Is there a way to compute the asymmetric distance between a query and database items that have been encoded with PQ? I figured out that I need to use the index's get_FlatCodesDistanceComputer method. Here is my attempt:

import numpy as np
import faiss

x = np.random.randn(10000, 128)
x_norm = x/np.linalg.norm(x, axis=-1).reshape(-1,1)
index_pq = faiss.IndexPQ(128,8,7)
index_pq.train(x_norm)
index_pq.add(x_norm)
codes = index_pq.sa_encode(x_norm)

dist_comp = index_pq.get_FlatCodesDistanceComputer()
dist_comp.set_query = x_norm[0]
dist_comp.distance_to_code(codes[:5]) # distances to the first 5 encoded database items

However, I get the following error:

TypeError: in method 'FlatCodesDistanceComputer_distance_to_code', argument 2 of type 'uint8_t const *'

Could you please let me know what I am missing?

Thanks.
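For context, asymmetric distance computation (ADC) keeps the query unencoded. For each of the M sub-vectors, it precomputes a table of squared L2 distances from the query sub-vector to every sub-quantizer centroid, and the distance to a code is then the sum of M table lookups. The NumPy sketch below illustrates only the idea. It uses random, hypothetical centroids and unpacked codes (one centroid id per sub-quantizer, whereas Faiss bit-packs the 7-bit ids), so it is not Faiss's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
d, M, ksub = 128, 8, 128          # dimension, sub-quantizers, centroids per sub-quantizer (2**7)
dsub = d // M

# Hypothetical stand-ins for trained PQ centroids and 5 encoded database items.
centroids = rng.standard_normal((M, ksub, dsub)).astype(np.float32)
codes = rng.integers(0, ksub, size=(5, M))       # one centroid id per sub-vector
query = rng.standard_normal(d).astype(np.float32)

# 1) Distance table: squared L2 from each query sub-vector to each centroid.
q_sub = query.reshape(M, 1, dsub)
table = ((q_sub - centroids) ** 2).sum(axis=-1)  # shape (M, ksub)

# 2) ADC: for each code, sum the M table entries it selects.
adc = table[np.arange(M), codes].sum(axis=1)     # shape (5,)

# Sanity check: equals the distance to the explicitly reconstructed vectors.
recon = centroids[np.arange(M), codes].reshape(5, d)
direct = ((recon - query) ** 2).sum(axis=1)
print(adc)
print(np.allclose(adc, direct, rtol=1e-4))
```

The sanity check shows why the result is called asymmetric: the query side is exact, and only the database side is replaced by its PQ reconstruction.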

mdouze commented 5 months ago

This is because distance_to_code has no simplified Python wrapper that accepts NumPy arrays; it expects a raw pointer. It can still be called by passing pointers obtained with swig_ptr:

dist_comp = index_pq.get_FlatCodesDistanceComputer()
dist_comp.set_query(faiss.swig_ptr(x_norm[0]))
dist_comp.distance_to_code(faiss.swig_ptr(codes[:5])) 

This computes one distance per call (because it's a low-level function), so the call above returns only the distance to the first of the five codes.

anupsingh15 commented 5 months ago

Thanks for your reply.

Executing dist_comp.set_query(faiss.swig_ptr(x_norm[0])) gives the following error:

TypeError: in method 'DistanceComputer_set_query', argument 2 of type 'float const *'

I also tried the following, which runs without an error, but I am unsure why it outputs only 0. I expected it to output 5 distances, since I am comparing a query with 5 different encodings:

dist_comp = index_pq.get_FlatCodesDistanceComputer()
dist_comp.set_query = faiss.swig_ptr(x_norm[0])
dist_comp.distance_to_code(faiss.swig_ptr(codes[:5]))