Closed jolars closed 2 years ago
The oracle method is not converging on the rcv1 data set for some reason, at least not with `reg = 0.02`. Try this example:
```python
import matplotlib.pyplot as plt
import numpy as np
from benchopt.datasets import make_correlated_data
from scipy import stats

from slope.data import get_data
from slope.solvers import hybrid_cd, oracle_cd
from slope.utils import dual_norm_slope

dataset = "rcv1.binary"
if dataset == "simulated":
    X, y, _ = make_correlated_data(n_samples=10, n_features=10, random_state=0)
else:
    X, y = get_data(dataset)

fit_intercept = False

randnorm = stats.norm(loc=0, scale=1)
q = 0.1
reg = 0.02

alphas_seq = randnorm.ppf(1 - np.arange(1, X.shape[1] + 1) * q / (2 * X.shape[1]))
alpha_max = dual_norm_slope(X, (y - fit_intercept * np.mean(y)) / len(y), alphas_seq)
alphas = alpha_max * alphas_seq * reg

max_epochs = 10000
max_time = np.inf
tol = 1e-4

beta_cd, intercept_cd, primals_cd, gaps_cd, time_cd = hybrid_cd(
    X,
    y,
    alphas,
    fit_intercept=fit_intercept,
    max_epochs=max_epochs,
    verbose=True,
    tol=tol,
    max_time=max_time,
    cluster_updates=True,
)

beta_oracle, intercept_oracle, primals_oracle, gaps_oracle, time_oracle = oracle_cd(
    X,
    y,
    alphas,
    fit_intercept=fit_intercept,
    max_epochs=max_epochs,
    verbose=True,
    tol=tol,
    max_time=max_time,
)

primals_star = np.min(np.hstack((np.array(primals_cd), np.array(primals_oracle))))

plt.clf()
plt.semilogy(time_cd, primals_cd - primals_star, label="cd")
plt.semilogy(time_oracle, primals_oracle - primals_star, label="cd_oracle")
plt.xlabel("Time (s)")
# plt.semilogy(np.arange(len(gaps_cd))*10, gaps_cd, label="cd")
# plt.xlabel("Epoch")
plt.ylabel("suboptimality")
plt.legend()
plt.title(dataset)
plt.show(block=False)
```
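For anyone unfamiliar with how the penalty sequence in the example is built: `alphas_seq` is a Benjamini-Hochberg-type sequence, and `alpha_max` rescales it so that `reg = 1` would be the smallest penalty giving an all-zero solution. Here is a minimal dense-array sketch of that idea, assuming the standard dual of the sorted-L1 norm. `bh_sequence` and `sorted_l1_dual_norm` are hypothetical helpers written for illustration, not necessarily how `slope.utils.dual_norm_slope` is implemented, and the data is random rather than rcv1.

```python
import numpy as np
from scipy import stats


def bh_sequence(n_features, q):
    # Benjamini-Hochberg-type sequence: positive and decreasing for 0 < q < 1.
    k = np.arange(1, n_features + 1)
    return stats.norm.ppf(1 - k * q / (2 * n_features))


def sorted_l1_dual_norm(v, alphas):
    # Hypothetical helper: dual of the sorted-L1 norm, i.e. the largest ratio
    # between the cumulative sums of the sorted |v| and of the alphas.
    abs_sorted = np.sort(np.abs(v))[::-1]
    return np.max(np.cumsum(abs_sorted) / np.cumsum(alphas))


# Small random problem (an assumption for illustration, not the rcv1 data).
rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5))
y = rng.standard_normal(20)

reg = 0.02
alphas_seq = bh_sequence(X.shape[1], q=0.1)
alpha_max = sorted_l1_dual_norm(X.T @ y / len(y), alphas_seq)
alphas = alpha_max * alphas_seq * reg
```

With `reg = 0.02` the penalty is 2% of the value at which the solution becomes all zeros, so the problem in this report is only lightly regularized.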
This gets me:
```
Epoch: 9771, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9781, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9791, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9801, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9811, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9821, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9831, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9841, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9851, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9861, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9871, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9881, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9891, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9901, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9911, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9921, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9931, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9941, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9951, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9961, loss: 0.18407451740050443, gap: 1.06e-01
Epoch: 9971, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9981, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9991, loss: 0.18407451740050446, gap: 1.06e-01
```
and so on for the oracle method: the loss is flat at about 0.184 and the gap stays at 1.06e-01, far above `tol = 1e-4`.
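A stall like this could also be flagged programmatically instead of by reading the log. The following is a minimal sketch, assuming the gaps are available as a sequence of floats (like `gaps_oracle` in the example). `is_stalled`, `window`, and `rel_change` are hypothetical names chosen for illustration and are not part of the `slope` package.

```python
import numpy as np


def is_stalled(gaps, tol, window=20, rel_change=1e-3):
    # Hypothetical helper: True if the last `window` gaps are all above `tol`
    # and differ from each other by at most `rel_change` (relative).
    gaps = np.asarray(gaps, dtype=float)
    if gaps.size < window:
        return False  # not enough history to judge
    tail = gaps[-window:]
    if tail.min() <= tol:
        return False  # the solver reached the tolerance
    spread = tail.max() - tail.min()
    return bool(spread <= rel_change * tail.max())


# A gap frozen at 1.06e-01, as in the log above, counts as stalled.
stalled = is_stalled([1.06e-1] * 30, tol=1e-4)
# A gap that keeps shrinking below the tolerance does not.
converged = is_stalled(np.geomspace(1.0, 1e-6, 30), tol=1e-4)
```

The window and threshold are arbitrary defaults; the point is only to separate "still making progress" from "stuck above the tolerance".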