jpvert / apg

Accelerated proximal gradient in R

Implementing Ordered Weighted L1-norm proximity operator #3

Closed: JoaoSantinha closed this issue 4 years ago

JoaoSantinha commented 4 years ago

Hi,

Thanks for the really nice package.

I was trying to implement the OWL norm, and I wanted to check whether you see anything wrong with the implementation, because the fitted weights and intercept are always 0. Any help would be appreciated.

glm.apg <- function(x, y, family=c("gaussian", "binomial", "survival"), penalty=c("elasticnet", "isotonic", "boundednondecreasing", "owl"), lambda=1, intercept=TRUE, opts=list()) {
    family <- match.arg(family)
    if (family=="survival") {
        intercept=FALSE
    }
    penalty <- match.arg(penalty)
    y <- drop(y)
    np <- dim(x)
    if (is.null(np) || (np[2] <= 1))
        stop("x should be a matrix with 2 or more columns")
    n = as.integer(np[1])
    p = as.integer(np[2])
    dimy = dim(y)
    nrowy = ifelse(is.null(dimy), length(y), dimy[1])
    if (nrowy != n)
        stop(paste("number of observations in y (", nrowy, ") not equal to the number of rows of x (",
                   n, ")", sep = ""))
    vnames = colnames(x)
    if (is.null(vnames))
        vnames = paste("V", seq(p), sep = "")

    # Gradient of the smooth part
    gradG <- switch(family, gaussian = grad.quad, binomial = grad.logistic, survival = grad.rankinglogistic)

    # Prox of the nonsmooth part
    proxH <- switch(penalty, elasticnet = prox.elasticnet, isotonic = prox.isotonic, boundednondecreasing = prox.boundednondecreasing, owl = prox.owl)

    o <- opts

    # If a non-penalized intercept is added, just add a constant column to x and modify the prox operator of the penalty to not touch the coefficient corresponding to the constant column. We also work on the centered matrix x to speed up convergence.
    if (intercept) {
        centered.x <- scale(x, scale = FALSE)
        o <- append(list(A=cbind(centered.x, rep(1,n))), o)
        myproxH <- proxH
    } else {
        o <- append(list(A=x), o)
        myproxH <- proxH
    }

    if (family=="survival") {
        o$comp <- surv_to_pairs(y[,1],y[,2])
    }
    o$b <- y
    o$lambda <- lambda

    if (penalty=="owl") {
        weights_owl <- seq(ncol(x), 0, -1)
        weights_owl <- weights_owl * o$lambda2
        weights_owl <- weights_owl + o$lambda1
        o$weights <- weights_owl
    }

    # Maximize the penalized log-likelihood
    res <- apg(gradG, myproxH, ncol(o$A), o)
    w <- res[["x"]]

    # Return the model (vector of weight in b, intercept in a0)
    if (intercept) {
        return(list(b=w[-length(w)], a0=w[length(w)] - sum(w[-length(w)] * attr(centered.x, "scaled:center"))))
    } else {
        return(list(b=w, a0=0))
    }
}
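The comment at the top of the intercept branch says the prox should leave the intercept coefficient untouched, but both branches assign myproxH <- proxH, so the intercept is penalized like every other coefficient. A minimal sketch of a wrapper that skips the last coordinate, assuming (as in the code above) that the intercept is the last column of A and that apg calls the prox as prox(x, t, opts). The names skip_last_prox and prox_l1 are hypothetical, for illustration only. With the intercept excluded this way, the OWL weight vector in opts would need length ncol(x) rather than ncol(x) + 1.

```r
# Hypothetical helper: wrap a prox so that the last coordinate (the intercept
# column appended by glm.apg) is passed through unchanged.
skip_last_prox <- function(proxH) {
  function(x, t = 0, opts = list()) {
    n <- length(x)
    c(proxH(x[-n], t, opts), x[n])
  }
}

# Illustrative soft-thresholding prox with threshold t * lambda.
prox_l1 <- function(x, t = 0, opts = list(lambda = 1)) {
  sign(x) * pmax(abs(x) - t * opts$lambda, 0)
}

prox_no_intercept <- skip_last_prox(prox_l1)
out <- prox_no_intercept(c(3, -0.5, 2), t = 1, opts = list(lambda = 1))
print(out)
```

Here the first two coordinates are soft-thresholded to 2 and 0, while the last one (the intercept) stays at 2. In glm.apg the same idea would mean setting myproxH <- skip_last_prox(proxH) in the intercept branch.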

where the prox of owl is:

prox.owl <- function(x, t=0, opts=list()) {
  v <- x

  w <- opts$weights

  v_abs <- abs(v)
  sorting <- sort(v_abs, decreasing = TRUE, index.return= TRUE)
  ix <- sorting$ix
  # dim() is NULL for plain vectors, so print the lengths instead
  print(length(v_abs))
  print(length(w))
  v_abs <- v_abs[ix]
  v_abs <- pava(v_abs - w, decreasing = FALSE)
  v_abs[v_abs < 0] <- 0 

  # undo the sorting
  inv.ix <- integer(length(v))  # base R, no phonTools dependency needed
  inv.ix[ix] <- seq(length(v))
  v_abs <- v_abs[inv.ix]

  return (sign(v) * v_abs)
}
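For comparison, one sketch of an OWL prox following the usual sorted-thresholding recipe for OWL/SLOPE-type penalties: sort |x| in decreasing order, subtract the step-size-scaled weights t * w, project onto non-increasing sequences with the pool-adjacent-violators algorithm, clip at zero, then undo the sort and restore the signs. It differs from prox.owl above in two places: the weights are multiplied by the step size t (prox.owl ignores t, which for step sizes below 1 amounts to a heavier penalty than intended), and the isotonic fit is non-increasing. To stay self-contained it uses a small base-R PAVA instead of an external pava function; pava_nonincreasing and prox_owl_sketch are hypothetical names, and the weights are assumed non-increasing with one entry per penalized coefficient.

```r
# Pool-adjacent-violators for an unweighted, non-increasing least-squares fit.
pava_nonincreasing <- function(y) {
  n <- length(y)
  vals <- numeric(n)   # block means
  sizes <- integer(n)  # block sizes
  k <- 0L              # number of active blocks
  for (i in seq_len(n)) {
    k <- k + 1L
    vals[k] <- y[i]
    sizes[k] <- 1L
    # Merge while the newest block violates the non-increasing order.
    while (k > 1L && vals[k - 1L] < vals[k]) {
      total <- vals[k - 1L] * sizes[k - 1L] + vals[k] * sizes[k]
      sizes[k - 1L] <- sizes[k - 1L] + sizes[k]
      vals[k - 1L] <- total / sizes[k - 1L]
      k <- k - 1L
    }
  }
  rep(vals[seq_len(k)], sizes[seq_len(k)])
}

# Sketch of the prox of t * OWL(x) with non-increasing weights opts$weights.
prox_owl_sketch <- function(x, t = 0, opts = list()) {
  w <- t * opts$weights                      # scale weights by step size t
  stopifnot(length(w) == length(x), !is.unsorted(rev(w)))
  ord <- order(abs(x), decreasing = TRUE)    # sort |x| in decreasing order
  z <- pava_nonincreasing(abs(x)[ord] - w)   # project onto non-increasing cone
  z <- pmax(z, 0)                            # then clip at zero
  out <- numeric(length(x))
  out[ord] <- z                              # undo the sorting
  sign(x) * out                              # restore the signs
}
```

Two examples with known answers: with constant weights OWL reduces to the L1 norm, so prox_owl_sketch(c(3, -1, 0.5, -2), t = 1, opts = list(weights = rep(1, 4))) should equal soft-thresholding, c(2, 0, 0, -1); with x = c(1, 2) and weights c(2, 0), the two coefficients are tied at 0.5, which is the grouping effect OWL is meant to produce.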

The call I made was: glm.apg(X_train, as.numeric(Y_train)-1, family = "binomial", penalty = "owl", opts = list(lambda1=0.0001, lambda2=0.01))

I also tried several other values of lambda1 and lambda2. If you need any other information, just let me know.
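For reference, a sketch of OSCAR-style OWL weights, w_i = lambda1 + lambda2 * (p - i) for i = 1, ..., p, a common choice for OWL and apparently what the owl branch of glm.apg is aiming for. owl_weights is a hypothetical helper. Note that seq(ncol(x), 0, -1) in glm.apg produces ncol(x) + 1 values, so it only matches the coefficient vector when the intercept column is included and passed through the prox, and its largest weight is lambda1 + lambda2 * ncol(x).

```r
# Hypothetical helper: OSCAR-style OWL weights, non-increasing, length p,
# w_i = lambda1 + lambda2 * (p - i) for i = 1, ..., p.
owl_weights <- function(p, lambda1, lambda2) {
  stopifnot(p >= 1, lambda1 >= 0, lambda2 >= 0)
  lambda1 + lambda2 * seq(p - 1, 0, by = -1)
}

# Same lambda1 / lambda2 as in the call above, for a hypothetical p = 5.
w <- owl_weights(5, lambda1 = 0.0001, lambda2 = 0.01)
print(w)
```

With these values the weights are 0.0401, 0.0301, 0.0201, 0.0101 and 0.0001: one per coefficient, largest first, as a sorted-thresholding OWL prox expects.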

Thanks in advance

jpvert commented 4 years ago

Hello, I do not see any obvious error, and I won't have time to debug your code myself. I would recommend checking that the OWL prox is correct (with some examples where you know what the output should be). Then maybe print intermediate results during optimization to understand when and why it converges to zero. Maybe zero is the solution?
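Following the suggestion to print intermediate results, a minimal sketch of a wrapper that logs each prox call, assuming apg calls the prox as prox(x, t, opts). It reports the step size t, the norm of the input and the number of coordinates left nonzero, which can show whether the iterates are zeroed out on the first steps (for example by a threshold that is too large relative to t) or shrink to zero gradually. trace_prox and the soft-thresholding prox used in the demo are hypothetical.

```r
# Hypothetical helper: wrap a prox so that every `every`-th call prints the
# step size, the norm of the input and the number of nonzero outputs.
trace_prox <- function(proxH, every = 1L) {
  calls <- 0L
  function(x, t = 0, opts = list()) {
    calls <<- calls + 1L  # counter kept in the wrapper's closure
    out <- proxH(x, t, opts)
    if (calls %% every == 0L) {
      message(sprintf("prox call %d: t = %.3g, ||x|| = %.3g, nonzero = %d",
                      calls, t, sqrt(sum(x^2)), sum(out != 0)))
    }
    out
  }
}

# Demo with a simple soft-thresholding prox (threshold t), illustration only.
soft <- function(x, t = 0, opts = list()) sign(x) * pmax(abs(x) - t, 0)
traced <- trace_prox(soft)
res <- traced(c(3, -0.5, 2), t = 1)
```

In glm.apg this would amount to passing trace_prox(myproxH) to apg in place of myproxH; the wrapper returns the prox output unchanged, so the optimization itself is not affected.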

JoaoSantinha commented 4 years ago

Thanks for your answer. While writing a full implementation I found one bug (decreasing = FALSE in the pava call), but fixing it has not solved the issue here. I will close this issue and post the solution when I get some time to investigate it again :)