jolars / sgdnet

Fast Sparse Linear Models for Big Data with SAGA
https://jolars.github.io/sgdnet/
GNU General Public License v3.0

Erroneous results with non-standardized features #12

Open jolars opened 6 years ago

jolars commented 6 years ago

The package is producing strange results when features aren't standardized.

library(glmnet)
library(sgdnet)

set.seed(1)

x <- with(trees, cbind(Girth, Height))
y <- trees$Volume

sfit <- sgdnet(x, y, standardize = TRUE)
gfit <- glmnet(x, y, standardize = TRUE)

coef(sfit, s = 0)
# 3 x 1 sparse Matrix of class "dgCMatrix"
#                       1
# (Intercept) -57.9724027
# Girth         4.7078901
# Height        0.3390993
coef(gfit, s = 0)
# 3 x 1 sparse Matrix of class "dgCMatrix"
#                       1
# (Intercept) -56.9741742
# Girth         4.6882466
# Height        0.3293873

sfit <- sgdnet(x, y, standardize = FALSE)
gfit <- glmnet(x, y, standardize = FALSE)

coef(sfit, s = 0)
# 3 x 1 sparse Matrix of class "dgCMatrix"
#                      1
# (Intercept) 27.7431997
# Girth        4.7665048
# Height      -0.7912432
coef(gfit, s = 0)
# 3 x 1 sparse Matrix of class "dgCMatrix"
#                       1
# (Intercept) -57.5678250
# Girth         4.6724709
# Height        0.3399486

# for reference
coef(lm(Volume ~ ., data = trees))
# (Intercept)       Girth      Height 
# -57.9876589   4.7081605   0.3392512 

As you can see, the coefficients are way off. The issue seems to be related to the absence of centering. I have gone over the rescaling multiple times (https://github.com/jolars/sgdnet/blob/5f22c994b4e71796feb2440e98000265f03eadab/src/utils.h#L442-L472) but I haven't found any issues there, so I am guessing that there's something wrong inside the algorithm.
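For reference, the rescaling linked above is the standard back-transformation from a fit on centered and scaled data; a minimal sketch of the idea (illustrative names, not sgdnet's actual code):

# Map a solution of the standardized problem,
#   (y - y_center)/y_scale ~ a0_std + sum_j b_j * (x_j - x_center_j)/x_scale_j,
# back to the original scale. Illustrative sketch only.
unstandardize <- function(a0_std, beta_std, x_center, x_scale,
                          y_center, y_scale) {
  beta <- beta_std * y_scale / x_scale
  a0 <- y_center + y_scale * a0_std - sum(beta * x_center)
  list(a0 = a0, beta = beta)
}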

Do you have any idea what's going on, @michaelweylandt @tdhock?

Interestingly, the scikit-learn implementation unconditionally centers its features (except for the sparse implementation).

jolars commented 6 years ago

If I am reading this correctly, it seems as if glmnet always centers the predictors when the intercept is fit and never when it is not.

(https://github.com/cran/glmnet/blob/ff79c8962562e278abfaafb98641a7aecf140b23/inst/mortran/glmnet5dp.m#L766-L800)

intr is the intercept flag.

Could you please check that I am correct? It's not the most user-friendly code around.
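For what it's worth, centering when an intercept is fit can't change the fitted slopes, only how the intercept is parameterized internally, so glmnet is free to center in that case. A quick check of that invariance (reusing the trees data from above):

library(glmnet)

x <- with(trees, cbind(Girth, Height))
y <- trees$Volume

f1 <- glmnet(x, y, standardize = FALSE)
f2 <- glmnet(scale(x, scale = FALSE), y, standardize = FALSE)

# The slope paths agree; only the intercept absorbs the centering shift.
max(abs(coef(f1)[-1, ] - coef(f2)[-1, ]))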

michaelweylandt commented 6 years ago

Ha, no, it's not the easiest to read. You are correct that glmnet doesn't treat intercepts and standardization orthogonally (I'd argue it should), but I don't think that's relevant here. The glmnet intercept looks approximately right (taking the lm result as truth), while yours is off.

I don't quite follow what is happening here:

https://github.com/jolars/sgdnet/blob/5f22c994b4e71796feb2440e98000265f03eadab/src/saga.h#L433

It looks like this code:

https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/sag_fast.pyx#L442

but they scale by the number of samples seen so far, not the total number of samples. Not sure if that's super important though.
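For reference, that line is the aggregated-gradient term of the SAGA step; schematically, the two versions differ only in the denominator (made-up variable names, not either codebase's actual code):

# Schematic SAGA update for one weight: g_new and g_old are the current and
# previously stored gradients for the sampled observation, g_sum the running
# sum of all stored gradients.
saga_step <- function(w, step, g_new, g_old, g_sum, denom) {
  w - step * (g_new - g_old + g_sum / denom)
}

# sgdnet uses denom = n_samples (the total sample count), while scikit-learn
# uses denom = n_seen (the number of samples seen so far).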

jolars commented 6 years ago

Right. I never really understood why they did that. I figured it was something related to stabilizing the solution in the early inner loops, but it didn't seem to make any real difference when I compared with and without (which was a while back), so I dropped it.

michaelweylandt commented 6 years ago

Haven't figured it out yet, but here's a smaller reproducible example which suggests the problem is possibly somewhere in the centering calculations and not the scaling:

x <-  structure(c(0, 1, 0, 0, 0, 1), .Dim = 3:2)
y <- c(3, 5, 7)

giving (correctly)

 > set.seed(1); lm(y ~ x); sgdnet(x, y, lambda=0, standardize=FALSE, thresh=1e-10)$a0

 Call:
 lm(formula = y ~ x)

 Coefficients:
 (Intercept)           x1           x2
           3            2            4

 int_old = -1.22474
 x_scale_prod = 0
 y_center = 5
 y_scale = 1.63299
 int_new = 3
 s0
  3

Now add scaling:

> set.seed(1); lm(y ~ I(2 * x)); sgdnet(x * 2, y, lambda=0, standardize=FALSE, thresh=1e-10)$a0

 Call:
 lm(formula = y ~ I(2 * x))

 Coefficients:
 (Intercept)    I(2 * x)1    I(2 * x)2
           3            1            2

 int_old = -1.22474
 x_scale_prod = 0
 y_center = 5
 y_scale = 1.63299
 int_new = 3
 s0
  3

Now add a need for centering and things get weird:

> set.seed(1); lm(y ~ I(2 + x)); sgdnet(x + 2, y, lambda=0, standardize=FALSE, thresh=1e-10)$a0

 Call:
 lm(formula = y ~ I(2 + x))

 Coefficients:
 (Intercept)    I(2 + x)1    I(2 + x)2
          -9            2            4

 int_old = -5.40318
 x_scale_prod = 0
 y_center = 5
 y_scale = 1.63299
 int_new = -3.82336
        s0
 -3.823363

The int_old (raw scale intercept) from the working cases appears to be essentially:

> lm((y - mean(y))/sd2(y) ~ apply(x, 2, function(x) x/sd2(x)))

 Call:
 lm(formula = (y - mean(y))/sd2(y) ~ apply(x, 2, function(x) x/sd2(x)))

 Coefficients:
                        (Intercept)  apply(x, 2, function(x) x/sd2(x))1
                            -1.2247                              0.5774
 apply(x, 2, function(x) x/sd2(x))2
                             1.1547

but I don't get int_old = -5.4 when I do the same calculation with x + 2 on the RHS.
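(For completeness: sd2 above is the 1/n-denominator standard deviation, which matches y_scale = 1.63299 in the output, rather than R's 1/(n - 1) sd():)

sd2 <- function(x) sqrt(mean((x - mean(x))^2))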

jolars commented 6 years ago

Hm, I now think this was something of a false alarm.

Your example actually works fine if the entire path is fit:

x <-  structure(c(0, 1, 0, 0, 0, 1), .Dim = 3:2)
y <- c(3, 5, 7)
set.seed(1)
lm(y ~ I(2 + x))
f <- sgdnet(x + 2, y, standardize=FALSE, thresh=1e-10)
coef(f, s = 0)
# 3 x 1 sparse Matrix of class "dgCMatrix"
#                     1
# (Intercept) -8.997029
# V1           1.999364
# V2           3.999364

Hm, perhaps this is related to the step sizes when standardization is off. This is what it looks like for lasso least squares with and without standardization:

set.seed(1)

x <- with(trees, cbind(Girth, Height))

L <- (max(rowSums(x^2)) + 1)
1/(2*L)
#> [1] 6.254409e-05

xs <- scale(x, scale = FALSE)

L <- (max(rowSums(xs^2)) + 1)
1/(2*L)
#> [1] 0.002634516

But I'm not sure whether that's enough to actually make the algorithm stall; the uncentered step size here is roughly 40 times smaller. Lowering the tolerance doesn't seem to do anything in this example. (edit: actually, I was wrong; see below)

jolars commented 6 years ago

It appears that the scikit-learn implementation suffers from the same issue. (I had to switch to logistic regression, since Ridge() in scikit-learn always centers the features.)

> set.seed(1)
> library(reticulate)
> sk <- import("sklearn")
> 
> x <- with(infert, cbind(age, parity))
> y <- infert$case
> 
> mod <- sk$linear_model$LogisticRegression(penalty = "l2",
+                                           C = 1e10,
+                                           max_iter = 1000,
+                                           tol = 1e-3,
+                                           solver = "saga")
> mod$fit(x, y)
LogisticRegression(C=10000000000.0, class_weight=None, dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=1000.0,
          multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
          solver='saga', tol=0.001, verbose=0, warm_start=False)
> cbind(mod$coef_, mod$intercept_)
            [,1]        [,2]       [,3]
[1,] -0.00599717 0.008634496 -0.5101588
> 
> f <- sgdnet(x, y, family = "binomial", standardize = FALSE, lambda = 0)
> coef(f)
3 x 1 sparse Matrix of class "dgCMatrix"
                      s0
(Intercept) -0.503653435
age         -0.006188175
parity       0.008475562
> 
> g <- glm(y ~ x, family = "binomial")
> coef(g)
 (Intercept)         xage      xparity 
-0.753682743  0.001137342  0.014661742 

Setting the tolerance low enough and upping the maximum number of iterations, however, we eventually get there (with sgdnet too):

> mod$set_params(tol = 1e-8, max_iter = 1e5)
LogisticRegression(C=10000000000.0, class_weight=None, dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=100000.0,
          multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
          solver='saga', tol=1e-08, verbose=0, warm_start=False)
> mod$fit(x, y)
LogisticRegression(C=10000000000.0, class_weight=None, dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=100000.0,
          multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
          solver='saga', tol=1e-08, verbose=0, warm_start=False)
> cbind(mod$intercept_, mod$coef_)
           [,1]        [,2]       [,3]
[1,] -0.7536788 0.001137225 0.01466164
> 
> coef(sgdnet(x, y, family = "binomial", 
+             standardize = FALSE, lambda = 0, maxit = 1e5, thresh = 1e-8))
3 x 1 sparse Matrix of class "dgCMatrix"
                      s0
(Intercept) -0.753678389
age          0.001137214
parity       0.014661634

And actually, I was wrong before: the trees example does converge in the end.

> x <- with(trees, cbind(Girth, Height))
> y <- trees$Volume
> 
> coef(sgdnet(x, y, standardize = FALSE, thresh = 1e-15, maxit = 1e9, lambda = 0))
3 x 1 sparse Matrix of class "dgCMatrix"
                     s0
(Intercept) -57.9876589
Girth         4.7081605
Height        0.3392512

So to summarize, I don't think this is a bug, just an issue of slow convergence when the variables are not standardized. There are a few options as I see it:

  1. We could force standardization, or at least centering, possibly by first running a check on whether the input is standardized (when standardize = FALSE) and otherwise doing it for the user.
  2. Lower the default tolerance and increase the maximum number of iterations, with the side effect that we would run the algorithm longer than strictly necessary.
  3. Revisit the stopping criteria. I know that lightning does this differently.
  4. Do nothing and let the user take responsibility for standardizing before running the algorithm when standardize = FALSE.

I don't really see the issue with going route 1, even if it is heavy-handed.
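For option 1, a minimal sketch of what such a check could look like (a hypothetical helper, not part of sgdnet's API):

# Warn when the input looks unstandardized even though standardize = FALSE,
# using the 1/n-denominator scale to match the internal standardization.
check_standardized <- function(x, tol = 1e-8) {
  centers <- colMeans(x)
  scales <- apply(x, 2, function(xj) sqrt(mean((xj - mean(xj))^2)))
  if (any(abs(centers) > tol) || any(abs(scales - 1) > tol))
    warning("features do not appear to be standardized; convergence may be slow")
  invisible(abs(centers) <= tol & abs(scales - 1) <= tol)
}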

michaelweylandt commented 6 years ago

I don't like 1 -- if the user supplies standardize=FALSE, I think we should respect that.

That said, we only need to give the appearance of respecting it: we could standardize internally and adjust the penalty weights so that the net effect is the same as not standardizing. Would this work?
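At least the scaling part of the algebra checks out: dividing feature j by its scale s_j and weighting its l1 penalty by 1/s_j leaves the objective unchanged once the coefficients are mapped back. A quick numeric check of that identity (a generic lasso objective, not sgdnet's internals):

set.seed(1)
n <- 50; p <- 3
X <- matrix(rnorm(n * p), n, p)
y <- rnorm(n)

s <- apply(X, 2, sd)          # feature scales
Xs <- sweep(X, 2, s, "/")     # internally scaled features

lambda <- 0.5
beta <- rnorm(p)              # an arbitrary coefficient vector
beta_s <- beta * s            # the same coefficients on the scaled problem

obj_raw <- sum((y - X %*% beta)^2) / (2 * n) + lambda * sum(abs(beta))
obj_std <- sum((y - Xs %*% beta_s)^2) / (2 * n) + lambda * sum(abs(beta_s) / s)
all.equal(obj_raw, obj_std)   # TRUE: the two problems are equivalent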

jolars commented 6 years ago

Hm, yeah. I'll have to fiddle with it a bit, and it would mean that we'd have to keep a vector of these scales around, but at the moment I don't see why it shouldn't work. And as far as centering is concerned, this is really what glmnet, for instance, is already doing, i.e. centering its predictors even when the user specifies standardize = FALSE.