PABannier closed this 3 years ago.
Thanks for the reproducible example; it was easy to run.
Just to mention, you could make it even better for me by making it more minimal, i.e. by hardcoding the value of alpha and the CV fold for which it breaks, which removes one function:
```python
import numpy as np
from numpy.linalg import norm
from sklearn.model_selection import KFold
from celer import MultiTaskLasso

N_SAMPLES = 10
N_FEATURES = 15
N_TASKS = 5


def compute_alpha_max(X, Y):
    # Smallest alpha for which the MultiTaskLasso solution is identically zero
    return np.max(norm(X.T @ Y, axis=1)) / X.shape[0]


def fit_reweighted_lasso(X, Y, alpha, n_iter, warm_start, tol):
    def penalty(u):
        # One weight per feature, computed from the row norms of the coefficients
        return 1 / (2 * np.sqrt(np.linalg.norm(u, axis=1)) + np.finfo(float).eps)

    regressor = MultiTaskLasso(
        alpha, fit_intercept=False, warm_start=warm_start, tol=tol, verbose=1)
    w = np.ones(N_FEATURES)
    for k in range(n_iter):
        print(f"Reweighting number {k}")
        # Rescale the columns of X by the current weights
        X_w = X / w[np.newaxis, :]
        regressor.fit(X_w, Y)
        # Back to the original scale, shape (n_features, n_tasks)
        coef = (regressor.coef_ / w).T
        w = penalty(coef)
    return coef


if __name__ == "__main__":
    rng = np.random.default_rng(42)
    X = rng.random((N_SAMPLES, N_FEATURES))
    Y = rng.random((N_SAMPLES, N_TASKS))

    max_alpha = compute_alpha_max(X, Y)
    alphas = np.geomspace(max_alpha, max_alpha / 10, 30)
    alpha = alphas[5]  # value of alpha for which it breaks

    kf = KFold()
    trn_index, val_index = list(kf.split(X, Y))[3]  # CV fold for which it breaks
    X_train, Y_train = X[trn_index, :], Y[trn_index, :]
    fit_reweighted_lasso(X_train, Y_train, alpha, 3, True, 1e-6)
```
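A reading of what the loop in `fit_reweighted_lasso` computes (my interpretation, not stated elsewhere in the thread): dividing column j of `X` by `w[j]`, fitting, then dividing the coefficients by `w` should be equivalent to a MultiTaskLasso whose penalty on feature j is weighted by `w[j]`. The weights `1 / (2 * sqrt(||W_j||))` are the derivative of `sqrt(t)` at `t = ||W_j||_2`, i.e. the usual reweighting for a `sum_j sqrt(||W_j||_2)` penalty. The NumPy sketch below checks both facts on random data; the helpers `weighted_objective`, `plain_objective_on_rescaled` and `reweight` are hypothetical, and the `1/(2n)` data-fit scaling is an assumption matching the scikit-learn-style objective.

```python
import numpy as np


def weighted_objective(X, Y, W, alpha, weights):
    """1/(2n) ||Y - X W||_F^2 + alpha * sum_j weights[j] * ||W_j||_2.

    W has shape (n_features, n_tasks); the 1/(2n) scaling is an assumption.
    """
    n_samples = X.shape[0]
    residual = Y - X @ W
    return (0.5 / n_samples * np.sum(residual ** 2)
            + alpha * np.sum(weights * np.linalg.norm(W, axis=1)))


def plain_objective_on_rescaled(X, Y, V, alpha, weights):
    """Unweighted objective on the rescaled design X / weights, evaluated at V."""
    X_w = X / weights[np.newaxis, :]
    return weighted_objective(X_w, Y, V, alpha, np.ones_like(weights))


def reweight(W, eps=np.finfo(float).eps):
    """Same formula as `penalty` in the snippet: d/dt sqrt(t) at t = ||W_j||_2, plus eps."""
    return 1.0 / (2.0 * np.sqrt(np.linalg.norm(W, axis=1)) + eps)


rng = np.random.default_rng(0)
X = rng.random((10, 15))
Y = rng.random((10, 5))
weights = rng.random(15) + 0.5       # arbitrary positive feature weights
W = rng.standard_normal((15, 5))     # coefficients in the original scale
V = W * weights[:, np.newaxis]       # same model in the rescaled space: V_j = weights[j] * W_j

# Weighted objective at W equals the unweighted objective on X / weights at V
print(np.isclose(weighted_objective(X, Y, W, 0.1, weights),
                 plain_objective_on_rescaled(X, Y, V, 0.1, weights)))  # True
```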
Fixed in #202.
Hello @mathurinm!
After the patch you made to fix the `warm_start` issue (#199), `MultiTaskLasso` no longer raises an error.
However, after many trials, I found that with `warm_start=True` in `MultiTaskLasso`, the solver fails to converge and issues the following warning:

> ConvergenceWarning: Objective did not converge: duality gap: 0.0055864603018013215, tolerance: 1.8021110008703545e-06. Increasing `tol` may make the solver faster without affecting the results much. Fitting data with very small alpha causes precision issues.
I tried setting the tolerance to 1e-3, but the problem persists. The warning is issued when fitting with `alpha_max * 0.8`.
Note: with `warm_start=False`, the solver converges normally.
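For context on the warning: the solver stops once the duality gap falls below the tolerance and warns otherwise. Below is a minimal NumPy sketch of the standard duality gap for the objective `1/(2n) ||Y - XW||_F^2 + alpha * sum_j ||W_j||_2`, using the rescaled residual as the dual point. It is not celer's implementation: the name `mtl_duality_gap` is hypothetical, and celer may rescale `tol` internally, so the printed tolerance need not equal the `tol` argument. It also illustrates why `alpha_max` is the threshold: at `alpha = alpha_max`, `W = 0` already has a zero gap.

```python
import numpy as np


def mtl_duality_gap(X, Y, W, alpha):
    """Duality gap of 1/(2n) ||Y - XW||_F^2 + alpha * sum_j ||W_j||_2 (sketch).

    W has shape (n_features, n_tasks). The dual point is the residual rescaled
    so that max_j ||X_j^T theta||_2 <= 1 (dual feasibility).
    """
    n = X.shape[0]
    lam = n * alpha  # penalty level once the objective is multiplied by n
    R = Y - X @ W
    scale = max(lam, np.max(np.linalg.norm(X.T @ R, axis=1)))
    theta = R / scale
    primal = 0.5 * np.sum(R ** 2) + lam * np.sum(np.linalg.norm(W, axis=1))
    dual = 0.5 * np.sum(Y ** 2) - 0.5 * lam ** 2 * np.sum((theta - Y / lam) ** 2)
    return (primal - dual) / n  # back to the 1/(2n) scaling


rng = np.random.default_rng(42)
X = rng.random((8, 15))
Y = rng.random((8, 5))
alpha_max = np.max(np.linalg.norm(X.T @ Y, axis=1)) / X.shape[0]
W0 = np.zeros((15, 5))
print(mtl_duality_gap(X, Y, W0, alpha_max))      # ~0: W = 0 is optimal
print(mtl_duality_gap(X, Y, W0, alpha_max / 2))  # > 0: W = 0 is no longer optimal
```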
Below is a snippet to reproduce the error: