This is a quick fix for a problem that made the following example fail for diagonal MMC:
```python
from sklearn.datasets import fetch_lfw_pairs
from sklearn.model_selection import cross_validate, train_test_split
from metric_learn import MMC

pairs, y_pairs = [fetch_lfw_pairs()[key] for key in ['pairs', 'target']]
# Keep a random subset of the pairs and map the labels from {0, 1} to {-1, 1}
pairs, _, y_pairs, _ = train_test_split(pairs, 2*y_pairs-1)
# Flatten each image so that pairs has shape (n_pairs, 2, n_features)
pairs = pairs.reshape(pairs.shape[0], 2, -1)
mmc = MMC(diagonal=True)
mmc.fit(pairs, y_pairs)
```
What happens is that `obj_previous` (see here: https://github.com/scikit-learn-contrib/metric-learn/blob/899ef47889426cc2a6ffa606ba43b892af7b48da/metric_learn/mmc.py#L207) was such a big number (of the order of 10**20) that adding 1 left it unchanged (see picture). As a result, the strict comparison was false from the start, the loop never ran, and `w_previous` was referenced without ever being assigned. Replacing `obj_previous + 1` with `np.inf` fixed the problem.

This is a quick fix to make it work, but I don't really know if it's the final answer: what if `obj_previous == np.inf` too? The comparison would then fail again. Should we detect this case and throw an error? I guess we can merge this as is for the next release, and maybe investigate later?
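To make the failure mode concrete, here is a minimal Python sketch. It is hypothetical: the function `halving_line_search` and its `init` parameter are made up for illustration, and the loop is only modelled on the step-halving loop linked above, not copied from metric-learn. It shows that `1e20 + 1 == 1e20` in float64 (the spacing between doubles near 1e20 is 16384), that starting with `+ 1` skips the loop and leaves `w_previous` unassigned, and that starting from `np.inf` lets the loop run. The `np.isfinite` check is one possible answer to the open question about infinite objectives, not something this PR implements.

```python
import numpy as np


def halving_line_search(objective, w, step, init="inf", reduction=2.0):
    """Hypothetical, simplified step-halving loop (not the metric-learn code).

    Halves the step size while the objective strictly decreases and returns
    the last iterate that was accepted.
    """
    lambd = 1.0
    w_tmp = np.maximum(0, w - lambd * step)
    obj = objective(w_tmp)
    if not np.isfinite(obj):
        # Possible guard: np.inf < np.inf is False, so an infinite objective
        # would skip the loop again; fail loudly instead.
        raise ValueError("objective is not finite: %r" % (obj,))
    if init == "plus_one":
        obj_previous = obj + 1  # old start: lost to rounding when obj ~ 1e20
    else:
        obj_previous = np.inf  # new start: greater than any finite obj

    while obj < obj_previous:
        obj_previous = obj
        w_previous = w_tmp.copy()
        lambd /= reduction
        w_tmp = np.maximum(0, w - lambd * step)
        obj = objective(w_tmp)
    # If the loop never ran, w_previous was never assigned and this line
    # raises UnboundLocalError.
    return w_previous


def huge(v):
    return 1e20  # constant objective of the magnitude reported above


w = np.array([1.0, 2.0])
step = np.array([0.5, 0.5])

print(1e20 + 1 == 1e20)  # True
print(halving_line_search(huge, w, step, init="inf"))  # [0.5 1.5]
try:
    halving_line_search(huge, w, step, init="plus_one")
except UnboundLocalError as exc:
    print("old initialisation fails:", type(exc).__name__)
```

With `init="inf"` the first strict comparison is always true for a finite objective, so `w_previous` is assigned at least once; the remaining risk is an objective that is itself infinite, which the guard turns into an explicit error.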