scikit-learn-contrib / metric-learn

Metric learning algorithms in Python
http://contrib.scikit-learn.org/metric-learn/
MIT License

[MRG] replace +1 by inf in MMC diag #297

Closed wdevazelhes closed 4 years ago

wdevazelhes commented 4 years ago

This is a quick fix for a problem that caused the following diagonal MMC example to fail:

In [9]:  from sklearn.datasets import fetch_lfw_pairs 
   ...:  from sklearn.model_selection import cross_validate, train_test_split 
   ...:  from metric_learn import MMC 
   ...:  pairs, y_pairs = [fetch_lfw_pairs()[key] for key in ['pairs', 'target']] 
   ...:  pairs, _, y_pairs, _ = train_test_split(pairs, 2*y_pairs-1) 
   ...:  pairs = pairs.reshape(pairs.shape[0], 2, -1) 
   ...:  mmc = MMC(diagonal=True)
   ...:  mmc.fit(pairs, y_pairs)

What happens is that obj_previous (see here: https://github.com/scikit-learn-contrib/metric-learn/blob/899ef47889426cc2a6ffa606ba43b892af7b48da/metric_learn/mmc.py#L207) was such a large number (of order 10**20) that adding 1 left it unchanged in floating point (see picture). The strict comparison therefore failed, the loop never started, and w_previous was referenced before being assigned. Replacing obj_previous + 1 with np.inf fixes the problem. It's a quick fix to make this work, but I don't really know if it's the final answer: what if obj_previous == np.inf too? The comparison would then fail again. Should we detect this case and throw an error? I guess for the next release we can merge this as is, though, and maybe investigate later?
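To illustrate the rounding issue in isolation, here is a minimal sketch. It does not reproduce the actual code in mmc.py: the helper `enters_loop` and the exact form of the initialization (`obj + 1` compared with a strict `<`) are assumptions made for this example, and the value 1e20 is only picked to match the order of magnitude reported above. At that magnitude the gap between adjacent float64 values is 16384, so adding 1 is rounded away.

```python
import math


def enters_loop(obj, obj_previous):
    """Hypothetical stand-in for a strict `while obj < obj_previous:` guard."""
    return obj < obj_previous


obj = 1e20  # same order of magnitude as the objective reported in this PR

# Old-style initialization (assumed form): the +1 is lost to rounding.
old_previous = obj + 1
absorbed = old_previous == obj

# New-style initialization: any finite objective is strictly below inf.
new_previous = math.inf

old_enters = enters_loop(obj, old_previous)
new_enters = enters_loop(obj, new_previous)

# Remaining edge case raised above: the objective itself is inf.
inf_enters = enters_loop(math.inf, new_previous)

print(absorbed, old_enters, new_enters, inf_enters)
# prints: True False True False
```

The last value shows the open question: if the objective is itself infinite, the strict comparison against inf is still false, so a loop guarded this way would again be skipped.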

image

perimosocordiae commented 4 years ago

Thanks, merged.