yngvem / group-lasso

Group Lasso implementation following the scikit-learn API
MIT License
105 stars 32 forks source link

Why are the three numbers in 'yhat' the same? #32

Closed SikangSHU closed 2 years ago

SikangSHU commented 2 years ago

Hello! I don't know why the three numbers in 'yhat' are the same and why 'w_hat' is all zeros. I'd like to get a 'w_hat' with a few non-zero entries so that I can tell which groups of 'X' are used. Can you help me with that? Thank you very much!

###############################################################################
# Setup
# -----

import matplotlib.pyplot as plt
import numpy as np
from numpy import linalg
from skimage import io
from skimage.color.colorconv import _prepare_colorarray
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error

from group_lasso import GroupLasso

# Seed NumPy's global random generator so that the shuffles and random draws
# made later in this script are repeatable from run to run.
np.random.seed(0)
# Class-level switch on GroupLasso; presumably enables recording of the loss
# during fitting -- TODO confirm against the group_lasso documentation.
GroupLasso.LOG_LOSSES = True

###############################################################################
# Set dataset parameters
# ----------------------
# Five coefficient groups; each entry is the number of columns in that group.
group_sizes = [3, 3, 2, 2, 2]

# Per-group multiplier for the random coefficients drawn further down: three
# groups get a factor of 2 and two get a factor of 1, in shuffled order.
active_groups = np.array([2.0, 2.0, 2.0, 1.0, 1.0])
# Shuffled twice in a row, so the global RNG is advanced by two shuffles.
np.random.shuffle(active_groups)
np.random.shuffle(active_groups)

# Column vector holding, for every coefficient, the index of its group.
groups = np.repeat(np.arange(len(group_sizes)), group_sizes).reshape(-1, 1)
num_coeffs = sum(group_sizes)
num_datapoints = 3

print("group_sizes:", group_sizes)
print("active_groups:", active_groups)
print("groups:", groups.shape)
print("num_coeffs:", num_coeffs)
print("____________________________________________")

###############################################################################
# Generate data matrix
# --------------------
# Hard-coded design matrix with 3 rows and 12 columns.
# NOTE(review): only four distinct columns appear in this literal; columns
# 0/6, 1/3/8, 2/5/7/9/11 and 4/10 are exact duplicates of one another.
X = np.array([[0.571, 0.095, 0.767, 0.095, 0.105, 0.767, 0.571, 0.767, 0.095, 0.767, 0.105, 0.767],
              [0.584, 0.258, 0.576, 0.258, 0.758, 0.576, 0.584, 0.576, 0.258, 0.576, 0.758, 0.576],
              [0.577, 0.961, 0.284, 0.961, 0.644, 0.284, 0.577, 0.284, 0.961, 0.284, 0.644, 0.284]])

print("X:", X.shape)
print("____________________________________________")

###############################################################################
# Generate coefficients
# ---------------------
# One standard-normal draw per coefficient, taken group by group and scaled by
# that group's entry in ``active_groups``; the result is a column vector.
w = np.concatenate(
    [
        multiplier * np.random.standard_normal(n_in_group)
        for n_in_group, multiplier in zip(group_sizes, active_groups)
    ]
).reshape(-1, 1)
# Boolean column marking the coefficients that are not exactly zero.
true_coefficient_mask = np.not_equal(w, 0)
intercept = 2

print("w:", w.shape)
print("true_coefficient_mask:", true_coefficient_mask.sum())
print("____________________________________________")

###############################################################################
# Generate regression targets
# ---------------------------

# Fixed target column with three entries, one per row of ``X``.
y = np.array([[-0.17997138],
              [-0.15219182],
              [-0.17062552]])
# Response implied by the random coefficients (no intercept, no noise). It is
# computed once and only used for the MSE comparison printed below.
y_true = X @ w
print("y:", y)
MSE1 = mean_squared_error(y, y_true)
print("MSE_yt_y:", MSE1)
print("____________________________________________")

###############################################################################
# Generate estimator and train it
# -------------------------------
# Configure the estimator with the per-coefficient group indices and fit it to
# the hard-coded X / y.
# NOTE(review): the parameter meanings noted below are inferred from their
# names -- confirm against the group_lasso documentation for the installed
# version.
# NOTE(review): group_reg=5 and l1_reg=2 may be too strong for targets whose
# magnitude is below 0.2; if every coefficient comes back zero, try lowering
# them.
gl = GroupLasso(
    groups=groups,
    group_reg=5,  # presumably the group-penalty strength
    l1_reg=2,  # presumably the element-wise L1 penalty strength
    frobenius_lipschitz=True,
    scale_reg="inverse_group_size",
    subsampling_scheme=1,
    supress_warning=True,  # keyword spelled as the call site has it
    n_iter=1000,
    tol=1e-3,
)
gl.fit(X, y)

###############################################################################
# Extract results and compute performance metrics
# -----------------------------------------------

# Pull the fitted quantities off the estimator
yhat = gl.predict(X)
sparsity_mask = gl.sparsity_mask_
w_hat = gl.coef_

print("yhat:", yhat)
# This prints the sum of the coefficients, not the coefficient vector itself.
print("w_hat:", w_hat.sum())
print("sparsity_mask:", sparsity_mask)
print("____________________________________________")

# Score the in-sample predictions against the targets
# (R2 is computed here but not printed).
R2 = r2_score(y_true=y, y_pred=yhat)
MSE_y_yh = mean_squared_error(y_true=y, y_pred=yhat)
print("MSE_y_yh:", MSE_y_yh)
print("____________________________________________")

The output of the program is as follows.

group_sizes: [3, 3, 2, 2, 2]
active_groups: [2. 2. 2. 1. 1.]
groups: (12, 1)
num_coeffs: 12
____________________________________________
X: (3, 12)
____________________________________________
w: (12, 1)
true_coefficient_mask: 12
____________________________________________
y: [[-0.17997138]
 [-0.15219182]
 [-0.17062552]]
MSE_yt_y: 55.67436677644974
____________________________________________
yhat: [-0.16644355 -0.16644355 -0.16644355]
w_hat: 0.0
sparsity_mask: [False False False False False False False False False False False False]
____________________________________________
MSE_y_yh: 0.0001345342966391801
____________________________________________
yngvem commented 2 years ago

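You probably have too much regularisation. With `group_reg=5` and `l1_reg=2` the penalty dominates targets this small, so every coefficient is shrunk to zero and only the intercept is fitted, which is why all three predictions are identical. Try lowering `group_reg` and `l1_reg`.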