jolars / slopecd

4 stars 2 forks source link

Oracle not converging on rcv1 #41

Closed jolars closed 2 years ago

jolars commented 2 years ago

The oracle method is not converging on the rcv1 data set for some reason, at least not with reg = 0.02. Try this example:

import matplotlib.pyplot as plt
import numpy as np
from benchopt.datasets import make_correlated_data
from scipy import stats

from slope.data import get_data
from slope.solvers import hybrid_cd, oracle_cd
from slope.utils import dual_norm_slope

# Reproduction script: run hybrid_cd and oracle_cd on the same problem and
# plot each solver's primal suboptimality against wall-clock time.

# --- data -------------------------------------------------------------------
dataset = "rcv1.binary"
if dataset != "simulated":
    X, y = get_data(dataset)
else:
    X, y, _ = make_correlated_data(n_samples=10, n_features=10, random_state=0)

fit_intercept = False

# --- regularization sequence --------------------------------------------------
randnorm = stats.norm(loc=0, scale=1)
q = 0.1
reg = 0.02

# Standard-normal quantiles evaluated at 1 - i * q / (2 * p) for i = 1..p,
# where p is the number of columns of X.
n_features = X.shape[1]
ranks = np.arange(1, n_features + 1)
alphas_seq = randnorm.ppf(1 - ranks * q / (2 * n_features))

# y is centered only when an intercept is fitted (fit_intercept acts as a
# 0/1 multiplier on the mean).
scaled_response = (y - fit_intercept * np.mean(y)) / len(y)
alpha_max = dual_norm_slope(X, scaled_response, alphas_seq)

# NOTE(review): alpha_max presumably is the scale at which the solution
# becomes all-zero, so `reg` is a fraction of it -- confirm against
# slope.utils.dual_norm_slope.
alphas = alpha_max * alphas_seq * reg

# --- solver settings ----------------------------------------------------------
max_epochs = 10000
max_time = np.inf
tol = 1e-4

# Keyword arguments passed identically to both solvers.
shared_solver_kwargs = dict(
    fit_intercept=fit_intercept,
    max_epochs=max_epochs,
    verbose=True,
    tol=tol,
    max_time=max_time,
)

beta_cd, intercept_cd, primals_cd, gaps_cd, time_cd = hybrid_cd(
    X, y, alphas, cluster_updates=True, **shared_solver_kwargs
)

beta_oracle, intercept_oracle, primals_oracle, gaps_oracle, time_oracle = oracle_cd(
    X, y, alphas, **shared_solver_kwargs
)

# --- suboptimality plot -------------------------------------------------------
# The smallest primal value reached by either solver is the reference optimum.
primals_star = np.min(np.hstack([primals_cd, primals_oracle]))

subopt_cd = np.asarray(primals_cd) - primals_star
subopt_oracle = np.asarray(primals_oracle) - primals_star

plt.clf()

for times, subopt, label in (
    (time_cd, subopt_cd, "cd"),
    (time_oracle, subopt_oracle, "cd_oracle"),
):
    plt.semilogy(times, subopt, label=label)

plt.xlabel("Time (s)")

# Alternative view (disabled): duality gap per epoch instead of time.
# plt.semilogy(np.arange(len(gaps_cd))*10, gaps_cd, label="cd")
# plt.xlabel("Epoch")

plt.ylabel("suboptimality")
plt.legend()
plt.title(dataset)
plt.show(block=False)

Gets me:

Epoch: 9771, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9781, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9791, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9801, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9811, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9821, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9831, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9841, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9851, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9861, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9871, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9881, loss: 0.18407451740050448, gap: 1.06e-01
Epoch: 9891, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9901, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9911, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9921, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9931, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9941, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9951, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9961, loss: 0.18407451740050443, gap: 1.06e-01
Epoch: 9971, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9981, loss: 0.18407451740050446, gap: 1.06e-01
Epoch: 9991, loss: 0.18407451740050446, gap: 1.06e-01

…and so on, for the oracle method.