topepo / caret

caret (Classification And Regression Training) R package that contains misc functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

Predict on dummyVars cannot return a sparse matrix #671

Open jeffwong-nflx opened 7 years ago

jeffwong-nflx commented 7 years ago

Hi, I want to construct a dummyVars matrix in sparse format. Because my factor variables have a large number of levels, a sparse result would really help. Here is a reproducible example:

require(data.table)
require(Matrix)
require(caret)

n = 1000
X = data.table(a = sample(c('a', 'b', 'c'), n, replace = T),
               b = sample(c('d', 'e', 'f'), n, replace = T),
               x = 1:n)
X$a = as.factor(X$a)
X$b = as.factor(X$b)

predict(foo <- dummyVars(~ a + b, X, sparse = TRUE), 
             X, sparse = TRUE)

I believe it would be a simple modification here https://github.com/topepo/caret/blob/master/pkg/caret/R/dummyVar.R#L227

This line generates a model.matrix, and we would simply need to allow passing sparse = TRUE and use sparse.model.matrix
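To illustrate why the switch matters, here is a minimal sketch (not caret code) comparing the two constructors on a factor with many levels. The data frame `X_demo`, the level names, and the seed are made up for the example; only `model.matrix` and `Matrix::sparse.model.matrix` come from the discussion above.

```r
library(Matrix)

set.seed(1)
n <- 1000
# Hypothetical factor with 200 levels; levels are fixed up front so that
# both constructors see the same set even if a level is never sampled.
lv <- sprintf("lvl%03d", 1:200)
X_demo <- data.frame(f = factor(sample(lv, n, replace = TRUE), levels = lv))

dense_mm  <- model.matrix(~ f, X_demo)         # base R dense matrix
sparse_mm <- sparse.model.matrix(~ f, X_demo)  # sparse Matrix object

class(sparse_mm)
# Same values, different storage
all.equal(as.matrix(sparse_mm), dense_mm, check.attributes = FALSE)
# The sparse object should need far less memory than the dense one
object.size(dense_mm)
object.size(sparse_mm)
```

Both calls use R's default treatment contrasts here, so one level is dropped; that is separate from the `fullRank = FALSE` behaviour of dummyVars.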

jeffwong-nflx commented 7 years ago

I have attempted to write a solution for this in myfunc where I rely on sparse.model.matrix

require(Matrix)
require(caret)

myfunc <- function(object, newdata, na.action = na.pass, return_sparse = FALSE, ...) {
  if(is.null(newdata)) stop("newdata must be supplied")
  if(!is.data.frame(newdata)) newdata <- as.data.frame(newdata)
  if(!all(object$vars %in% names(newdata))) stop(
    paste("Variable(s)",
          # list the variables that are missing from newdata
          paste("'", object$vars[!(object$vars %in% names(newdata))],
                "'", sep = "",
                collapse = ", "),
          "are not in newdata"))
  Terms <- object$terms
  Terms <- delete.response(Terms)
  if(!object$fullRank) {
    oldContr <- options("contrasts")$contrasts
    newContr <- oldContr
    newContr["unordered"] <- "contr.ltfr"
    options(contrasts = newContr)
    on.exit({
      options(contrasts = oldContr)
    })
  }
  m <- model.frame(Terms, newdata, na.action = na.action, xlev = object$lvls)

  if (return_sparse) {
    x = sparse.model.matrix(Terms, m)
  } else {
    x = model.matrix(Terms, m)  
  }

  if(object$levelsOnly) {
    for(i in object$facVars) {
      for(j in object$lvls[[i]]) {
        from_text <- paste0(i, j)
        colnames(x) <- gsub(from_text, j, colnames(x), fixed = TRUE)
      }
    }
  }
  if(!is.null(object$sep) && !object$levelsOnly) {
    for(i in object$facVars[order(-nchar(object$facVars))]) {
      for(j in object$lvls[[i]]) {
        from_text <- paste0(i, j)
        to_text <- paste(i, j, sep = object$sep)
        colnames(x) <- gsub(from_text, to_text, colnames(x), fixed = TRUE)
      }
    }
  }
  x[, colnames(x) != "(Intercept)", drop = FALSE]
}

n = 1000
X = data.frame(a = sample(c('a', 'b', 'c'), n, replace = T),
               b = sample(c('d', 'e', 'f'), n, replace = T),
               x = 1:n)
X$a = as.factor(X$a)
X$b = as.factor(X$b)

sparse.model.matrix(~ a, X)
foo <- dummyVars(~ ., X, sparse = TRUE, fullRank = FALSE)
bar <- myfunc(foo, X, return_sparse = TRUE, sparse = TRUE)

However I get this error

Error in model.spmatrix(t, data, transpose = transpose, drop.unused.levels = drop.unused.levels,  : 
  no slot of name "i" for this object of class "dgeMatrix"

I believe it may be related to switching the contrasts option to contr.ltfr, which may not be compatible with sparse.model.matrix. If I comment that block out, the code executes and returns a sparse matrix, but one level of each factor is dropped (the R default).
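One possible way to keep every level without touching the global contrasts option is to pass identity contrast matrices through `contrasts.arg`. This is a different technique from the contr.ltfr approach in the function above and is only a sketch under that assumption; the names `X_demo` and `full_contrasts` are made up for the example.

```r
library(Matrix)

set.seed(1)
n <- 1000
X_demo <- data.frame(
  a = factor(sample(c("a", "b", "c"), n, replace = TRUE)),
  b = factor(sample(c("d", "e", "f"), n, replace = TRUE))
)

# contrasts(f, contrasts = FALSE) returns an identity matrix over the levels
# of f, so no level is dropped when it is used as the contrast.
full_contrasts <- lapply(X_demo[c("a", "b")], contrasts, contrasts = FALSE)

# "- 1" removes the intercept so only indicator columns remain.
sp <- sparse.model.matrix(~ a + b - 1, data = X_demo,
                          contrasts.arg = full_contrasts)

class(sp)
colnames(sp)
```

Column names follow the `model.matrix` convention (variable name followed by level), so the renaming steps for `sep` and `levelsOnly` would still be needed to match dummyVars output.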

topepo commented 7 years ago

I have the same (nebulous) issue. Would an object generated by model.Matrix work for you?

jeffwong-nflx commented 7 years ago

Yes, I believe that would work. Did you have a workaround?

topepo commented 7 years ago

x <- model.matrix(Terms, m)

could easily be changed to

x <- if (sparse)
  model.matrix(Terms, m)
else
  sparse.model.matrix(Terms, m)

More testing would be needed though.

jeffwong-nflx commented 7 years ago

invoking model.Matrix from MatrixModels does not work either

topepo commented 7 years ago

Your example worked for me:

> require(data.table)
> require(Matrix)
> require(caret)
> 
> n = 1000
> X = data.table(a = sample(c('a', 'b', 'c'), n, replace = T),
+                b = sample(c('d', 'e', 'f'), n, replace = T),
+                x = 1:n)
> X$a = as.factor(X$a)
> X$b = as.factor(X$b)
> 
> predict(foo <- dummyVars(~ a + b, X, sparse = TRUE), 
+              head(X), sparse = TRUE)
  a.a a.b a.c b.d b.e b.f
1   1   0   0   0   0   1
2   1   0   0   1   0   0
3   0   1   0   0   0   1
4   0   1   0   1   0   0
5   0   1   0   0   1   0
6   0   1   0   0   0   1
> 
> library(sessioninfo)
> session_info()
─ Session info ───────────────────────────────────────────────────────────────────────────────────────────────────
 setting  value                       
 version  R version 3.3.3 (2017-03-06)
 os       macOS Sierra 10.12.6        
 system   x86_64, darwin13.4.0        
 ui       RStudio                     
 language (EN)                        
 collate  en_US.UTF-8                 
 tz       America/New_York            
 date     2017-08-16                  

─ Packages ───────────────────────────────────────────────────────────────────────────────────────────────────────
 package      * version  date       source                        
 assertthat     0.2.0    2017-04-11 CRAN (R 3.3.2)                
 bindr          0.1      2016-11-13 CRAN (R 3.3.2)                
 bindrcpp       0.2      2017-06-17 cran (@0.2)                   
 caret        * 6.0-77   2017-08-16 local (@6.0-77)               
 class          7.3-14   2015-08-30 CRAN (R 3.3.3)                
 clisymbols     1.2.0    2017-05-21 CRAN (R 3.3.2)                
 codetools      0.2-15   2016-10-05 CRAN (R 3.3.3)                
 colorspace     1.3-2    2016-12-14 CRAN (R 3.3.2)                
 CVST           0.2-1    2013-12-10 CRAN (R 3.3.0)                
 data.table   * 1.10.4   2017-02-01 CRAN (R 3.3.3)                
 ddalpha        1.2.1    2016-10-10 CRAN (R 3.3.0)                
 DEoptimR       1.0-8    2016-11-19 CRAN (R 3.3.2)                
 dimRed         0.1.0    2017-05-04 CRAN (R 3.3.2)                
 dplyr          0.7.2    2017-07-20 cran (@0.7.2)                 
 DRR            0.0.2    2016-09-15 CRAN (R 3.3.0)                
 foreach        1.4.3    2015-10-13 CRAN (R 3.3.0)                
 ggplot2      * 2.2.1    2016-12-30 CRAN (R 3.3.2)                
 glue           1.1.1    2017-06-21 CRAN (R 3.3.2)                
 gower          0.1.2    2017-02-23 CRAN (R 3.3.2)                
 gtable         0.2.0    2016-02-26 CRAN (R 3.3.0)                
 ipred          0.9-6    2017-03-01 cran (@0.9-6)                 
 iterators      1.0.8    2015-10-13 CRAN (R 3.3.0)                
 kernlab        0.9-25   2016-10-03 CRAN (R 3.3.0)                
 lattice      * 0.20-35  2017-03-25 CRAN (R 3.3.3)                
 lava           1.5      2017-03-16 cran (@1.5)                   
 lazyeval       0.2.0    2016-06-12 CRAN (R 3.3.0)                
 lubridate      1.6.0    2016-09-13 CRAN (R 3.3.0)                
 magrittr       1.5      2014-11-22 CRAN (R 3.3.0)                
 MASS           7.3-47   2017-04-21 CRAN (R 3.3.3)                
 Matrix       * 1.2-8    2017-01-20 CRAN (R 3.3.3)                
 ModelMetrics   1.1.0    2016-08-26 CRAN (R 3.3.0)                
 munsell        0.4.3    2016-02-13 CRAN (R 3.3.0)                
 nlme           3.1-131  2017-02-06 CRAN (R 3.3.3)                
 nnet           7.3-12   2016-02-02 CRAN (R 3.3.3)                
 pkgconfig      2.0.1    2017-03-21 cran (@2.0.1)                 
 plyr           1.8.4    2016-06-08 CRAN (R 3.3.0)                
 prodlim        1.6.1    2017-03-06 cran (@1.6.1)                 
 purrr          0.2.3    2017-08-02 cran (@0.2.3)                 
 R6             2.2.2    2017-06-17 cran (@2.2.2)                 
 Rcpp           0.12.12  2017-07-15 cran (@0.12.12)               
 RcppRoll       0.2.2    2015-04-05 CRAN (R 3.3.0)                
 recipes        0.1.0    2017-08-16 local (topepo/recipes@25a05ef)
 reshape2       1.4.2    2016-10-22 CRAN (R 3.3.3)                
 rlang          0.1.2    2017-08-09 cran (@0.1.2)                 
 robustbase     0.92-7   2016-12-09 CRAN (R 3.3.2)                
 rpart          4.1-11   2017-04-21 CRAN (R 3.3.3)                
 scales         0.4.1    2016-11-09 CRAN (R 3.3.2)                
 sessioninfo  * 1.0.0    2017-06-21 CRAN (R 3.3.2)                
 stringi        1.1.5    2017-04-07 CRAN (R 3.3.2)                
 stringr        1.2.0    2017-02-18 CRAN (R 3.3.2)                
 survival       2.40-1   2016-10-30 CRAN (R 3.3.3)                
 tibble         1.3.3    2017-05-28 CRAN (R 3.3.2)                
 timeDate       3012.100 2015-01-23 cran (@3012.10)               
 withr          2.0.0    2017-07-28 CRAN (R 3.3.2) 

Give the code that I'm about to check in a try

jeffwong-nflx commented 7 years ago

You have a bug here https://github.com/topepo/caret/commit/20a5f5d6d478652862c854e464649b6b4bdeed7f#diff-f140abde89f5d74a91c27104c9baa51bR227

When sparse is TRUE, it calls model.matrix rather than sparse.model.matrix, so the output of the example is dense.
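For reference, a sketch of the branch with the two calls the other way round. `build_design` is a hypothetical stand-alone helper written for this comment, not the function in the caret source, and it leaves out the contrast handling and column renaming.

```r
library(Matrix)

# Hypothetical helper: choose the constructor based on `sparse`.
build_design <- function(Terms, m, sparse = FALSE) {
  if (sparse) {
    sparse.model.matrix(Terms, m)  # sparse Matrix object
  } else {
    model.matrix(Terms, m)         # base R dense matrix
  }
}

X_demo <- data.frame(a = factor(c("a", "b", "c", "a")), x = 1:4)
Terms <- terms(~ a + x, data = X_demo)
m <- model.frame(Terms, X_demo)

is(build_design(Terms, m, sparse = TRUE), "sparseMatrix")
is.matrix(build_design(Terms, m, sparse = FALSE))
```

A quick check like the last two lines would catch a swapped branch, since a dense result for `sparse = TRUE` fails the first test.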

jeffwong-nflx commented 7 years ago

Bumping this issue :)