yngvem / group-lasso

Group Lasso implementation following the scikit-learn API
MIT License

``LogisticGroupLasso`` is not equivalent to ``scikit-learn`` ``LogisticRegression`` with groups of size 1 #36

Open Badr-MOUFAD opened 1 year ago

Badr-MOUFAD commented 1 year ago

Problem

LogisticGroupLasso doesn't return the same regression coefficients as scikit-learn's LogisticRegression when fitted with groups of size 1.

Expected behavior

Group logistic regression with groups of size one should be equivalent to L1-penalized logistic regression.
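Why the two should agree: assuming a binary model with a single coefficient vector w, a logistic loss averaged over the n samples, and the common sqrt(p_g) group weighting (these are assumptions about LogisticGroupLasso, not confirmed in this thread; with p_g = 1 the weighting does not matter anyway), singleton groups turn the group penalty into a plain L1 norm:

```latex
\min_{w\in\mathbb{R}^{p}}\;\frac{1}{n}\sum_{i=1}^{n}\log\bigl(1+e^{-y_i x_i^\top w}\bigr)
+\alpha\sum_{g}\sqrt{p_g}\,\lVert w_g\rVert_2,
\qquad
g=\{j\},\;p_g=1
\;\Longrightarrow\;
\sum_{g}\sqrt{p_g}\,\lVert w_g\rVert_2=\sum_{j=1}^{p}\lvert w_j\rvert=\lVert w\rVert_1
```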

Script to reproduce

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from group_lasso import LogisticGroupLasso

n_samples, n_features = 20, 60

# generate dummy data
rng = np.random.RandomState(123)
X = rng.randn(n_samples, n_features)
y = np.sign(rng.randn(n_samples))

# max regularization parameter
alpha_max = np.linalg.norm(X.T @ y, ord=np.inf) / (2 * n_samples)
alpha = 0.1 * alpha_max

# fit scikit-learn log reg
sk_model = LogisticRegression(
    penalty='l1',
    C=1/(n_samples * alpha),  # scikit-learn uses an un-normalized loss
    fit_intercept=False,
    tol=1e-9,
    solver='liblinear'
).fit(X, y)
sk_coef = sk_model.coef_.flatten()

# group log reg
yn_model = LogisticGroupLasso(
    group_reg=alpha,
    groups=np.arange(n_features),
    fit_intercept=False,
    subsampling_scheme=None,
    tol=1e-9,
    l1_reg=0.,
).fit(X, y)
coef_ = yn_model.coef_
yn_coef = coef_[:, 1]

np.testing.assert_allclose(yn_coef, sk_coef)
```
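The two parameter choices in the script can be derived as follows (a sketch, assuming `group_reg` multiplies a loss averaged over samples, as the script's comment implies). scikit-learn's liblinear L1 objective with no intercept is rescaled version of the normalized one, and `alpha_max` is the smallest α for which w = 0 satisfies the optimality conditions (labels in {-1, +1}, as produced by `np.sign`):

```latex
% scikit-learn (liblinear, penalty='l1', fit_intercept=False):
\min_{w}\;\lVert w\rVert_1 + C\sum_{i=1}^{n}\log\bigl(1+e^{-y_i x_i^\top w}\bigr)

% Dividing the normalized objective by \alpha gives the same form:
\frac{1}{\alpha}\Bigl[\frac{1}{n}\sum_{i=1}^{n}\log\bigl(1+e^{-y_i x_i^\top w}\bigr)+\alpha\lVert w\rVert_1\Bigr]
=\lVert w\rVert_1+\frac{1}{n\alpha}\sum_{i=1}^{n}\log\bigl(1+e^{-y_i x_i^\top w}\bigr)
\;\Longrightarrow\; C=\frac{1}{n\alpha}

% At w = 0 the gradient of the normalized loss is -X^\top y/(2n),
% and w = 0 is optimal iff its infinity norm is at most \alpha:
\alpha_{\max}=\frac{\lVert X^\top y\rVert_\infty}{2n}
```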

Output

```shell
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 10 / 60 (16.7%)
Max absolute difference: 0.96593495
Max relative difference: 1.
 x: array([ 0.      , -0.      , -0.010529,  0.      ,  0.      ,  0.43878 ,
        0.      ,  0.337345,  0.15179 , -0.      ,  0.066626,  0.      ,
       -0.      ,  0.      , -0.      ,  0.      ,  0.      , -0.      ,...
 y: array([ 0.      ,  0.      , -0.034694,  0.      ,  0.      ,  1.111946,
        0.      ,  0.840564,  0.309066,  0.      ,  0.15613 ,  0.      ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,...
```
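To tell which of the two vectors above actually solves the normalized problem, one can check its optimality (KKT) conditions directly instead of comparing the two solvers against each other. The sketch below is numpy-only and assumes the normalized objective (1/n) * sum_i log(1 + exp(-y_i * x_i @ w)) + alpha * ||w||_1 with labels in {-1, +1}. The helper names (`grad_logistic`, `kkt_violation`, `ista_l1_logreg`) are hypothetical and not part of `group_lasso` or scikit-learn; `ista_l1_logreg` is a plain proximal-gradient (ISTA) solver meant only as an independent reference.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def grad_logistic(X, y, w):
    """Gradient of (1/n) * sum_i log(1 + exp(-y_i * x_i @ w)) for y in {-1, +1}."""
    n = X.shape[0]
    return -(X.T @ (y * sigmoid(-y * (X @ w)))) / n


def kkt_violation(X, y, w, alpha):
    """Largest violation of the optimality conditions of
    (1/n) * logistic loss + alpha * ||w||_1 (0.0 means w is optimal)."""
    g = grad_logistic(X, y, w)
    nz = w != 0
    viol = np.empty_like(w, dtype=float)
    # Nonzero coefficients: the gradient must cancel alpha * sign(w_j).
    viol[nz] = np.abs(g[nz] + alpha * np.sign(w[nz]))
    # Zero coefficients: the gradient must lie inside [-alpha, alpha].
    viol[~nz] = np.maximum(np.abs(g[~nz]) - alpha, 0.0)
    return viol.max()


def ista_l1_logreg(X, y, alpha, n_iter=20000):
    """Reference solution by proximal gradient descent (ISTA)."""
    n, p = X.shape
    # Step size 1/L, where L = ||X||_2^2 / (4n) bounds the loss curvature.
    L = np.linalg.norm(X, ord=2) ** 2 / (4 * n)
    w = np.zeros(p)
    for _ in range(n_iter):
        z = w - grad_logistic(X, y, w) / L
        # Soft-thresholding is the proximal operator of (alpha / L) * ||.||_1.
        w = np.sign(z) * np.maximum(np.abs(z) - alpha / L, 0.0)
    return w
```

With the script's variables in scope, `kkt_violation(X, y, sk_coef, alpha)` and `kkt_violation(X, y, yn_coef, alpha)` can be compared directly; a value close to zero indicates a (near-)optimal vector. Checking `yn_coef` this way assumes, as the script does, that `coef_[:, 1]` is the right slice to compare. ISTA converges slowly when `n_features > n_samples`, as in this issue, so the violation of its own output should be checked rather than assumed small.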
mathurinm commented 1 year ago

Hi @yngvem, do you have any feedback on this?