scikit-learn-contrib / scikit-matter

A collection of scikit-learn compatible utilities that implement methods born out of the materials science and chemistry communities
https://scikit-matter.readthedocs.io/en/v0.2.0/
BSD 3-Clause "New" or "Revised" License

Zero scores result in repeated selection and wrong scores at least for FPS #206

Open agoscinski opened 1 year ago


Detected by @PicoCentauri

Problem

import numpy as np
from skmatter.feature_selection import FPS

np.random.seed(0)
n_samples = 10
n_features = 15
X = np.random.rand(n_samples, n_features)
# make features 3 and 4 numerically almost zero (values of order 1e-13)
X[:, 3] = np.random.rand(n_samples) * 1e-13
X[:, 4] = np.random.rand(n_samples) * 1e-13

# select all features, so the last steps run into (numerically) zero scores
selector_problem = FPS(n_to_select=n_features).fit(X)
print(selector_problem.selected_idx_)
print(selector_problem.get_select_distance())
print()

# this selector does not have the problem because the search terminates once the score threshold is reached
selector = FPS(n_to_select=n_features, score_threshold=1e-9).fit(X)
print(selector.selected_idx_)
print(selector.get_select_distance())

Out:

[ 0  8  3  6 14  2 13  9  7 11  1 10 12  5  8]
[           inf 1.77635684e-15 2.16390745e+00 1.62400552e+00
 1.43445978e+00 1.23482177e+00 1.03370164e+00 9.21863706e-01
 7.95155761e-01 7.87817521e-01 7.37837489e-01 6.52674372e-01
 6.11845170e-01 5.65607255e-01 1.77635684e-15]
/home/alexgo/code/scikit-matter/src/skmatter/_selection.py:210: UserWarning: Score threshold of 1e-09 reached.Terminating search at 14 / 15.
  warnings.warn(
[ 0  8  3  6 14  2 13  9  7 11  1 10 12]
[       inf 2.75832232 2.16390745 1.62400552 1.43445978 1.23482177
 1.03370164 0.92186371 0.79515576 0.78781752 0.73783749 0.65267437
 0.61184517]

In the first selector, index 8 is selected a second time in the last step, and its score is set to the wrong value: both of its entries report 1.77635684e-15, whereas the second selector reports 2.75832232 for it. This happens because the GreedySelector base class does not restrict the choice of the next point to points that have not been selected yet; the argmax runs over all points: https://github.com/scikit-learn-contrib/scikit-matter/blob/d56ccbd4648ad90299b27cb5c23ecd3b39e4d12a/src/skmatter/_selection.py#L371 So when all scores are (numerically) zero, a point that has already been selected can be selected again.
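To illustrate the mechanism outside of scikit-matter, here is a minimal sketch of greedy farthest point sampling in plain NumPy whose argmax, like the one linked above, runs over all points. The helper `naive_fps` is hypothetical and not skmatter's implementation; it treats rows as points and uses squared Euclidean distances. With exactly duplicated points the scores become exactly zero rather than the ~1e-15 noise seen in the output above.

```python
import numpy as np


def naive_fps(points, n_to_select, first=0):
    """Greedy farthest point sampling over the rows of `points`.

    Hypothetical helper, not skmatter code. The argmax runs over ALL points,
    including already selected ones, which reproduces the reselection issue.
    """
    selected = [first]
    # squared distance of every point to its nearest selected point
    d = np.sum((points - points[first]) ** 2, axis=1)
    for _ in range(n_to_select - 1):
        # no mask: selected points (score 0) are still candidates
        nxt = int(np.argmax(d))
        selected.append(nxt)
        d = np.minimum(d, np.sum((points - points[nxt]) ** 2, axis=1))
    return selected


# rows 1 and 2 are identical, so after two picks every score is exactly zero
points = np.array([[0.0], [1.0], [1.0]])
print(naive_fps(points, 3))
```

Once every score is zero, `np.argmax` returns the first index, 0, so the already selected point 0 is picked again and the result is `[0, 1, 0]`.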

Solution

One could track selected_idx_ in the GreedySelector base class and change the argmax in the function linked above so that it only considers indices that have not been selected yet.
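A minimal sketch of the proposed fix, again in plain NumPy and not taken from skmatter: the hypothetical helper `masked_fps` keeps a boolean mask of selected points, hides them from the argmax by setting their scores to `-inf`, and records each score in selection order so that a later pick cannot overwrite an earlier one. Treating rows as points and using squared Euclidean distances as scores are assumptions of the sketch.

```python
import numpy as np


def masked_fps(points, n_to_select, first=0):
    """Greedy farthest point sampling that only considers unselected points.

    Hypothetical helper sketching the proposed fix, not skmatter code.
    Returns the selected row indices and the score of each pick in order.
    """
    n = len(points)
    if not 1 <= n_to_select <= n:
        raise ValueError("n_to_select must be between 1 and the number of points")
    selected = [first]
    scores = [np.inf]  # the first pick has no reference point
    is_selected = np.zeros(n, dtype=bool)
    is_selected[first] = True
    # squared distance of every point to its nearest selected point
    d = np.sum((points - points[first]) ** 2, axis=1)
    for _ in range(n_to_select - 1):
        # exclude already selected points from the argmax
        candidates = np.where(is_selected, -np.inf, d)
        nxt = int(np.argmax(candidates))
        selected.append(nxt)
        scores.append(float(d[nxt]))  # stored per step, never overwritten
        is_selected[nxt] = True
        d = np.minimum(d, np.sum((points - points[nxt]) ** 2, axis=1))
    return selected, scores


points = np.array([[0.0], [1.0], [1.0]])
print(masked_fps(points, 3))
```

On the same three points this returns `[0, 1, 2]` with scores `[inf, 1.0, 0.0]`: the zero score is still reported, but it belongs to a new point. For the feature-selection setting of the issue the points are the columns of `X`, so the sketch would be called on `X.T`.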