scikit-learn-contrib / metric-learn

Metric learning algorithms in Python
http://contrib.scikit-learn.org/metric-learn/
MIT License

Adjustment of the validation of the number of target neighbors #353

Open JanekBlankenburg opened 1 year ago

JanekBlankenburg commented 1 year ago

Before the optimization starts, the parameters are validated. Lines 175-177 check whether the chosen k is valid for the training data. According to the definition of LMNN by Weinberger et al., each class must have at least k+1 elements so that every data point has at least k target neighbors. The implementation, however, only requires self.n_neighbors <= required_k (the code actually tests the opposite condition in order to raise an error), where required_k is the number of elements in the smallest class. This check therefore accepts k for a class that has exactly k elements, which shouldn't be the case, and it leads to each point in such a small class being selected as its own target neighbor.

To determine the target neighbors, a distance matrix of all points within the class is computed. To prevent a point from being recognized as its own nearest neighbor, the diagonal of this matrix is set to infinity. If a class has only k elements, however, all elements of the class are chosen as target neighbors, including the current point itself (even though the distance matrix gives it a distance of infinity to itself). As a result, each point of such a class effectively has one target neighbor fewer than points in larger classes, which can have unintended effects on the final transformation, depending on the dataset.
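The effect can be reproduced outside the library. The following is a minimal sketch, assuming that target neighbors are picked by sorting each row of the within-class distance matrix and keeping the first k indices; `select_target_neighbors` is a hypothetical helper written for illustration, not metric-learn's actual code.

```python
import numpy as np


def select_target_neighbors(X_class, k):
    """Hypothetical helper: indices of the k nearest same-class points.

    Assumes neighbors are chosen by sorting each row of the
    within-class distance matrix and keeping the first k columns.
    """
    # Pairwise squared Euclidean distances between all points of the class.
    diff = X_class[:, None, :] - X_class[None, :, :]
    dd = (diff ** 2).sum(axis=-1)
    # Mask the diagonal so a point is never its own *nearest* neighbor.
    np.fill_diagonal(dd, np.inf)
    # Keeping the first k columns also keeps the masked entry when the
    # class has only k points, because the row has only k entries.
    return np.argsort(dd, axis=1)[:, :k]


# A class with exactly 3 points and k = 3.
X_small = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
nn = select_target_neighbors(X_small, k=3)

# True for every point: each point ends up in its own neighbor list.
self_selected = [i in nn[i] for i in range(len(X_small))]
```

With k = 2 on the same three points, no point appears in its own neighbor list, which is the behavior the k+1 requirement is meant to guarantee.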

To prevent this, it is sufficient to adjust the validation so that self.n_neighbors < required_k must hold.
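A rough sketch of the stricter check, written as a standalone function: `check_n_neighbors` and the error message are hypothetical and only mirror the names used in this issue, so the actual code in the library may look different.

```python
import numpy as np


def check_n_neighbors(y, n_neighbors):
    """Hypothetical validation: require n_neighbors < size of smallest class."""
    # Number of training points in each class.
    _, class_counts = np.unique(y, return_counts=True)
    required_k = class_counts.min()
    # Reject n_neighbors == required_k as well, so that every point has
    # at least n_neighbors *other* points in its class.
    if n_neighbors >= required_k:
        raise ValueError(
            "n_neighbors (%d) must be smaller than the size of the "
            "smallest class (%d)" % (n_neighbors, required_k)
        )
    return required_k
```

For labels `[0, 0, 0, 1, 1, 1]`, this accepts `n_neighbors=2` and raises for `n_neighbors=3`, whereas a `<=` check would accept both.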

perimosocordiae commented 1 year ago

Thanks for sending a PR, @JanekBlankenburg! The test suite is showing some errors, some of which appear to be simple matters of changing the expected exception message, but some look more substantial. Can you take a look?