facebookresearch / faiss

A library for efficient similarity search and clustering of dense vectors.
https://faiss.ai
MIT License

JensenShannon distance seems to actually be JensenShannon divergence #2947

Open dleviminzi opened 1 year ago

dleviminzi commented 1 year ago

Summary

Platform

Faiss version: 43d86e30736ede853c384b24667fc3ab897d6ba9

Interface:

Currently, in DistanceUtils.cuh the JensenShannon distance reduce step is:

    __host__ __device__ float reduce() {
        return 0.5 * dist;
    }

This just gives the divergence. To calculate the JensenShannon distance, I think it would need to be:

    __host__ __device__ float reduce() {
        return sqrt(0.5 * dist);
    }

I'm not particularly familiar with this subject so I don't know if it is just a matter of the name being inaccurate or an actual problem with the implementation. I noticed this because given the same probability distributions, the output from scipy's JensenShannon distance was different.
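To make the discrepancy concrete, here is a minimal pure-Python sketch of the two quantities. It assumes natural logarithms and inputs that are already normalized probability vectors; the function names `js_divergence` and `js_distance` are illustrative and are not part of Faiss or scipy. The divergence is the value that a `0.5 * dist` reduce step corresponds to, while the distance (what scipy's `jensenshannon` reports) is its square root.

```python
import math

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) of two discrete distributions."""
    total = 0.0
    for x, y in zip(p, q):
        m = 0.5 * (x + y)  # pointwise mixture of the two distributions
        # A zero probability contributes nothing (0 * log 0 is taken as 0).
        if x > 0:
            total += x * math.log(x / m)
        if y > 0:
            total += y * math.log(y / m)
    return 0.5 * total

def js_distance(p, q):
    """Jensen-Shannon distance: the square root of the divergence."""
    return math.sqrt(js_divergence(p, q))

# Example inputs (made up for illustration).
p = [0.5, 0.5, 0.0]
q = [0.0, 0.5, 0.5]

print("divergence:", js_divergence(p, q))
print("distance:  ", js_distance(p, q))
```

For any pair of distinct distributions the two printed numbers differ, which would explain a mismatch against a library that returns the distance.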

edit: I thought about making the change and seeing if tests broke, but for some reason my laptop keeps crashing during compilation. I will test later on a different computer.

mlomeli1 commented 1 year ago

hi @dleviminzi , yes, that is correct. Thank you for flagging, we will make this name change so it's consistent with the definition of the Jensen-Shannon distance

dleviminzi commented 1 year ago

> hi @dleviminzi , yes, that is correct. Thank you for flagging, we will make this name change so it's consistent with the definition of the Jensen-Shannon distance

Will it remain as the divergence and just change name or will you add the sqrt?

mlomeli1 commented 1 year ago

Oops sorry, I meant to say that we will either change the name to divergence or add the square root to the method so it corresponds to the distance @dleviminzi . I think the latter is preferable but I will discuss it with the core developers to see if they agree.

dleviminzi commented 1 year ago

Sounds good!

mdouze commented 11 months ago

We leave out the sqrt since it is an expensive monotonic function that does not change the ordering of results. This is also the case with the L2 metric, which strictly speaking is the squared L2. Since in the code it's just called METRIC_JensenShannon, I think it's fine to clearly specify in the docs that it's the caller's job to do the sqrt.
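The ordering argument can be checked with a small sketch. The values below are made-up divergences standing in for whatever a search returns; nothing here calls Faiss itself. Because sqrt is monotonically increasing on non-negative inputs, ranking by divergence and ranking by distance give the same order, so a caller who needs true Jensen-Shannon distances can apply the square root to the returned values afterwards.

```python
import math

# Hypothetical divergence values, standing in for the results of a search.
divergences = [0.20, 0.05, 0.35, 0.00, 0.11]

# Rank candidates by the raw divergence.
order_by_divergence = sorted(range(len(divergences)), key=lambda i: divergences[i])

# Caller-side post-processing: convert divergences to distances.
distances = [math.sqrt(d) for d in divergences]

# Rank candidates by the distance instead.
order_by_distance = sorted(range(len(distances)), key=lambda i: distances[i])

print(order_by_divergence == order_by_distance)
```

The same reasoning applies to the squared L2 values mentioned above: taking the square root after the search changes the reported numbers but not which neighbors come first.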