yngvem / group-lasso

Group Lasso implementation following the scikit-learn API
MIT License

Comparison with the R package grplasso and imitation of some functions #38

Open normanius opened 1 year ago

normanius commented 1 year ago

Hi Yngve

Have you ever examined how this package's functionality relates to similar implementations in R (grplasso, gglasso)?

Below, I'm sharing some code I wrote recently. Inspired by the R package grplasso, I've implemented simplified versions of its lambdamax and plot.grplasso functions.
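As a cross-check for the bisection-based lambdamax further down: for the textbook binary logistic group lasso with an unpenalized intercept, mean log-loss and penalty $\lambda \sum_g w_g \lVert\beta_g\rVert_2$, the KKT conditions at $\beta = 0$ give a closed form, $\lambda_{\max} = \max_g \lVert X_g^\top (y - \bar y)\rVert_2 / (n\, w_g)$, because the fitted intercept alone then predicts $\bar y$ for every sample. The helper `lambda_max_logistic` and its arguments are hypothetical, and I have not checked that this scaling matches group-lasso's internal loss or grplasso's (which may use the summed rather than the averaged loss; `mean_loss=False` covers that case), so treat it as a sketch rather than a drop-in replacement.

```python
import numpy as np


def lambda_max_logistic(X, y, groups, group_weight=np.sqrt, mean_loss=True):
    """Smallest lambda at which all penalized groups are zero (sketch).

    Assumes the textbook binary logistic group lasso with an unpenalized
    intercept and penalty lambda * sum_g w_g * ||beta_g||_2, where
    w_g = group_weight(size of group g). Not guaranteed to match the
    exact scaling used by group-lasso or grplasso.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)  # labels coded as 0/1
    groups = np.asarray(groups)
    # With beta = 0 the optimal intercept predicts p_i = mean(y) for all i,
    # so the loss gradient w.r.t. beta is -X^T (y - mean(y)), divided by n
    # for the averaged loss.
    grad = X.T @ (y - y.mean())
    if mean_loss:
        grad /= len(y)
    values = []
    for g in np.unique(groups):
        mask = groups == g
        # Group g stays at zero as long as ||grad_g|| <= lambda * w_g.
        values.append(np.linalg.norm(grad[mask]) / group_weight(mask.sum()))
    return max(values)
```

With `groups = np.repeat(range(1, 21), 5)` as in the script, all groups have the same size, so the choice of `group_weight` only rescales the result by a constant.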

I've restricted myself to the logistic regression case, using the "colon" dataset provided by gglasso. The attached .csv contains the demo data.

I don't know what your plans are for your package, but I'm sharing the code so that you or others can use it. I've noticed that there are some differences between grplasso and group-lasso, but I haven't looked into them thoroughly.

Data: colon.csv

Result: evolution plot. It shows the estimated coefficients for decreasing values of $\lambda$, and hence how many groups are selected at each value.

[Figure: evolution plot]
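To get the number of selected groups explicitly rather than reading it off the plot, the stacked coefficient paths can be reduced per group. A minimal sketch: `count_selected_groups` is a hypothetical helper that expects the `(n_lambdas, n_features)` array built in `plot_evolution` and the `groups` array defined in the script.

```python
import numpy as np


def count_selected_groups(coefs, groups, tol=0.0):
    """Number of groups with at least one nonzero coefficient, per lambda.

    coefs:  array of shape (n_lambdas, n_features), one row per lambda.
    groups: array of shape (n_features,) with a group label per column.
    tol:    coefficients with |value| <= tol are treated as zero.
    """
    coefs = np.atleast_2d(np.asarray(coefs, dtype=float))
    groups = np.asarray(groups)
    labels = np.unique(groups)
    # active[i, j] is True if group labels[j] has a coefficient above tol
    # at the i-th lambda.
    active = np.stack(
        [np.any(np.abs(coefs[:, groups == g]) > tol, axis=1) for g in labels],
        axis=1,
    )
    return active.sum(axis=1)
```

Inside `plot_evolution`, calling `count_selected_groups(coefs, groups)` after `np.vstack` would give one count per entry of `lambdas`, which could be drawn with `ax.step`. A small positive `tol` may help if the solver leaves tiny nonzero values.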


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from group_lasso import LogisticGroupLasso

np.random.seed(0)
LogisticGroupLasso.LOG_LOSSES = True

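def compute_grplr(X, y, param=0.05, **kwargs):
  # Fit a logistic group lasso with group penalty `param`.
  # `groups` is the global array defined further below; extra keyword
  # arguments are forwarded to LogisticGroupLasso.
  gl = LogisticGroupLasso(
      groups=groups,
      group_reg=param,
      n_iter=100000,
      l1_reg=0,
      scale_reg="inverse_group_size",
      subsampling_scheme=None,
      supress_warning=True,  # spelled this way in group-lasso
      **kwargs,
  )
  gl.fit(X, y)
  return gl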

def plot_grplr(gl):
  # Plot the results
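  # Coefficient vector of the binary model (difference of the two class columns)
  coef = gl.coef_[:, 1] - gl.coef_[:, 0]
  plt.figure()
  plt.plot(coef / np.linalg.norm(coef), ".",
           label="Estimated weights")
  plt.legend()

  # Loss history (recorded because LOG_LOSSES is True)
  plt.figure()
  plt.plot(gl.losses_)
  plt.ylabel("loss")

  plt.show()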

def lambdamax(X, y, hi=1, lo=0.01, mode="log", 
              tol=1e-3, iter_max=10, **kwargs):
  """
  Crude imitation of grplasso::lambdamax: find the value of the
  penalty parameter lambda at which the first penalized parameter
  group enters the model.
  Algorithm: bisection.
  """
  def count_nonzero_coefs(X, y, lmbda, **kwargs):
    gl = compute_grplr(X, y, param=lmbda, **kwargs)
    coef = gl.coef_[:, 1] - gl.coef_[:, 0]
    return np.count_nonzero(coef)

  def bisect(func, x_min, x_max, x_func=None,
             tol=1e-7, iter_max=None, **kwargs):
    # Problem: we need to solve a "degenerate" root finding problem
    # https://stackoverflow.com/questions/76168787
    y_min = func(x_min, **kwargs)
    if y_min > 0:
      msg = "Warning: no solution as y_min>0, with x_min=%f."
      print(msg % x_min)
      return x_min
    y_max = func(x_max, **kwargs)
    if y_max <= 0:
      msg = "Warning: no solution as y_max<=0, with x_max=%f."
      print(msg % x_max)
      return x_max
    # Without any stopping rule the loop below would never end.
    if tol is None and iter_max is None:
      tol = 1e-7
    if x_func is None:
      x_func = lambda x0, x1: (x1+x0)/2

    from itertools import count, islice
    x_last = np.inf  # np.infty was removed in NumPy 2.0
    for cnt in islice(count(), iter_max):
      x_new = x_func(x_min, x_max)
      y_new = func(x_new, **kwargs)
      if y_new<=0:
        x_min = x_new
      else:
        x_max = x_new
      if (tol is not None) and abs(x_last - x_new) < tol:
        break
      x_last = x_new
    neg_direction = x_min > x_max
    if neg_direction:
      return x_max if y_new>0 else x_min
    else:
      return x_min if y_new>0 else x_max

  if mode=="log":
    # Geometric mean
    x_func = lambda x0, x1: np.sqrt(x0*x1)
  else:
    # Arithmetic mean
    x_func = lambda x0, x1: (x0+x1)/2

  func = lambda x, **kwargs: count_nonzero_coefs(X, y, lmbda=x, **kwargs)
  lmbda = bisect(func=func, x_min=hi, x_max=lo, x_func=x_func,
                 tol=tol, iter_max=iter_max, **kwargs)
  return lmbda

def plot_evolution(X, y):
  coefs = []
  lambda_start = lambdamax(X, y, hi=1, lo=0.01)
  exp_start = np.log10(lambda_start)
  lambdas = np.logspace(exp_start, -2, 20)
  for i, lmbda in enumerate(lambdas):
    print("Step i=%d, lambda=%.3f" % (i+1, lmbda))
    gl = compute_grplr(X, y, param=lmbda)
    coef = gl.coef_[:, 1] - gl.coef_[:, 0]
    coefs.append(coef)
  coefs = np.vstack(coefs)
  fig, ax = plt.subplots()
  ax.plot(lambdas, coefs)
  ax.set_xscale("log")
  ax.invert_xaxis()
  ax.set_xlabel("lambda (log scale)")
  ax.set_ylabel("coefficient")
  plt.show()

df = pd.read_csv("colon.csv", index_col=[0])
y = df["y"].copy()
X = df.loc[:,df.columns != "y"].copy()
groups = np.repeat(range(1,21), 5)

gl = compute_grplr(X, y, param=0.1)
#gl = compute_grplr(X, y, param=0.05)
plot_grplr(gl)
plot_evolution(X, y)
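The bisection in `lambdamax` searches for the jump of a step function (the number of nonzero coefficients as a function of lambda), which is why it is a "degenerate" root-finding problem. Because the penalized problem is convex, the all-zero solution is optimal exactly when $\lambda \ge \lambda_{\max}$, so the predicate "no active coefficients" switches only once (up to solver tolerance). A plain bracket therefore suffices: keep `lo` with at least one active coefficient and `hi` with none, and halve the bracket in log space until `hi / lo` is within a relative tolerance. A minimal sketch; `find_lambda_max` and `count_active` are hypothetical names:

```python
import math


def find_lambda_max(count_active, lo, hi, rtol=1e-3, max_iter=50):
    """Bisect in log space for the smallest lambda with no active coefficients.

    count_active: callable lambda -> number of nonzero coefficients; assumed
                  to be zero exactly for lambda >= lambda_max.
    lo, hi:       bracket with count_active(lo) > 0 and count_active(hi) == 0.
    Returns the upper end of the final bracket, i.e. a lambda at which the
    model is still empty and which is within rtol of the jump.
    """
    if count_active(hi) > 0:
        raise ValueError("count_active(hi) must be 0; increase hi")
    if count_active(lo) == 0:
        raise ValueError("count_active(lo) must be > 0; decrease lo")
    for _ in range(max_iter):
        if hi / lo <= 1 + rtol:
            break
        mid = math.sqrt(lo * hi)  # geometric mean, like mode="log" above
        if count_active(mid) > 0:
            lo = mid
        else:
            hi = mid
    return hi
```

With the script's helpers this could be called as `find_lambda_max(lambda lam: np.count_nonzero(compute_grplr(X, y, param=lam).coef_), lo=0.01, hi=1)`. Raising an error on a bad bracket, instead of printing a warning and returning an endpoint, makes a wrong starting interval hard to miss.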

For comparison, here is the same analysis with the R package grplasso (same data, similar code). The results are not exactly the same, but I have not investigated why.

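# install.packages("pacman")  # run once if pacman is not installed
library(pacman)
# Install (if needed) and load grplasso, plus gglasso for the colon data
pacman::p_load(grplasso, gglasso)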

# Load the colon data (from gglasso)
data(colon)
# Define group index
group <- rep(1:20, each=5)
# y must take values 0 and 1
colon$y <- (colon$y+1)/2

# Determine the value of the penalty parameter lambda when 
# the first penalized parameter group enters the model.
lambda <- lambdamax(x=colon$x, y=colon$y, model = LogReg(), 
                    index=group, standardize = TRUE)
# Create a sequence of lambda values to sample...
lambda <- lambda * 0.8^seq(0,8,0.1)

# Fit the model along the specified lambda sequence
# (matrix interface: x and y are passed directly).
fit <- grplasso(x=colon$x, y=colon$y, model = LogReg(), index=group,
                lambda = lambda, standardize = TRUE)

# The same fit with some explicit control settings (trace = 0 for quiet output)
fit <- grplasso(x=colon$x, y=colon$y, model = LogReg(), index=group, 
                lambda = lambda, standardize = TRUE,
                control = grpl.control(trace = 0, inner.loops = 10,
                                       update.every = 1, 
                                       update.hess = "lambda"))

# Plot the solution path of the group lasso regression.
plot(fit, log = "x")
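To evaluate the Python fit on the same relative grid as the R code, `lambda * 0.8^seq(0, 8, 0.1)` can be reproduced with NumPy. A small sketch; `r_style_lambda_grid` is a hypothetical helper, and `np.linspace` is used instead of `np.arange` so that exactly 81 points come out without floating-point surprises at the endpoint. The absolute lambda values still need not agree between the packages unless their loss and penalty scaling agree: for example, grplasso's default `penscale = sqrt` weights each group by $\sqrt{|g|}$, whereas the script uses `scale_reg="inverse_group_size"`, although with 20 groups of 5 columns this only rescales lambda by a constant factor.

```python
import numpy as np


def r_style_lambda_grid(lambda_max, base=0.8, stop=8.0, step=0.1):
    """NumPy equivalent of R's lambda_max * base^seq(0, stop, step)."""
    # seq(0, 8, 0.1) has 81 values: 0, 0.1, ..., 8
    n_steps = int(round(stop / step)) + 1
    exponents = np.linspace(0.0, stop, n_steps)
    return lambda_max * base ** exponents
```

In `plot_evolution`, `lambdas = r_style_lambda_grid(lambda_start)` could replace the `np.logspace` call so that both plots use the same number of points and the same ratio between them.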

Evolution plot in R:

[Figure: evolution plot (R)]
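One setting that differs between the two scripts is `standardize = TRUE` in the R calls, while the Python code fits on the raw columns. If I read the grplasso help page correctly, `center = TRUE, standardize = TRUE` centers the columns and then orthonormalizes each group block so that $X_g^\top X_g = n I$. The sketch below imitates that in NumPy with a reduced QR decomposition per group; `blockwise_orthonormalize` is a hypothetical helper that I have not verified against grplasso's source, and it assumes every group has full column rank.

```python
import numpy as np


def blockwise_orthonormalize(X, groups):
    """Center X and orthonormalize each group so that Z_g^T Z_g = n * I.

    Returns the transformed matrix Z and, per group label, a matrix that maps
    coefficients c_g fitted on Z back to coefficients b_g on the centered
    original columns: b_g = back[g] @ c_g.
    """
    X = np.asarray(X, dtype=float)
    groups = np.asarray(groups)
    n = X.shape[0]
    Xc = X - X.mean(axis=0)  # column centering
    Z = np.empty_like(Xc)
    back = {}
    for g in np.unique(groups):
        mask = groups == g
        Q, R = np.linalg.qr(Xc[:, mask])  # reduced QR: Q^T Q = I
        Z[:, mask] = np.sqrt(n) * Q
        # Xc_g @ b = Z_g @ c  <=>  b = sqrt(n) * R^{-1} @ c  (R invertible)
        back[g] = np.sqrt(n) * np.linalg.inv(R)
    return Z, back
```

Fitting `compute_grplr` on `Z` instead of `X` and mapping each group's coefficients back with `back[g]` would bring the Python path closer to the R setup before comparing the two plots.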