topepo / caret

caret (Classification And Regression Training): an R package containing miscellaneous functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

Change the time measurement method for final model fitting #1283

Closed GFabien closed 2 years ago

GFabien commented 2 years ago

As described in #1277, I was surprised that training a linear regression model is much slower with caret than calling stats::lm directly. Digging into this, I found that the cause is the use of system.time to measure the execution time of the final model fit. Replacing it with two calls to proc.time seems to solve the issue. Two benchmarks (without and with the change) are shown below.

library(caret)
library(microbenchmark)
data("iris")

X <- iris[, -5]
y <- iris[, 4]

lm_method <- getModelInfo("lm", regex = FALSE)[[1]]

base_lm <- function() {
  stats::lm(Petal.Width ~ ., data = X)
}
caret_lm <- function() {
  caret::train(Petal.Width ~ .,
               data = X,
               method = "lm",
               trControl = caret::trainControl(method = "none")
  )
}
caret_lm2 <- function() {
  caret::train(Petal.Width ~ .,
               data = X,
               method = lm_method,
               trControl = caret::trainControl(method = "none")
  )
}

Old version

res <- microbenchmark(NULL, base_lm(), caret_lm(), caret_lm2(), times = 50L)
print(res, unit = "ms")
#> Unit: milliseconds
#>         expr        min         lq         mean      median         uq        max neval
#>         NULL   0.000010   0.000013   0.00003208   0.0000230   0.000050   0.000093    50
#>    base_lm()   0.896871   0.942088   1.01018128   0.9657385   0.984955   2.731153    50
#>   caret_lm() 190.335801 192.603855 194.10944456 193.8592495 195.120467 199.377730    50
#>  caret_lm2() 164.702259 166.670015 176.11732996 167.5912635 169.138243 565.590524    50

New version

res <- microbenchmark(NULL, base_lm(), caret_lm(), caret_lm2(), times = 50L)
print(res, unit = "ms")
#> Unit: milliseconds
#>         expr       min        lq        mean     median        uq        max neval
#>         NULL  0.000000  0.000002  0.00001196  0.0000090  0.000022   0.000035    50
#>    base_lm()  0.534286  0.604855  0.66191388  0.6302575  0.676265   1.826794    50
#>   caret_lm() 21.805373 24.843148 33.86266576 25.9300345 27.759111 401.051678    50
#>  caret_lm2()  2.564492  2.944808  3.40049146  3.1206245  3.303301  11.708001    50

The only difference between caret_lm and caret_lm2 is that caret_lm2 skips the call to getModelInfo.
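
For reference, the timing change proposed at the top of this comment can be sketched roughly as follows. This is a minimal illustration, not caret's actual internal code: the helper names `time_with_system_time` and `time_with_proc_time` and the `fit_fun` argument are hypothetical. One plausible explanation for the gap, which this thread does not confirm, is that `system.time()` triggers a garbage collection before timing by default (`gcFirst = TRUE`), whereas `proc.time()` only reads the clock.

```r
# Hypothetical helpers for illustration only; not caret's internal code.

# Variant 1: time the fit with system.time().
# By default (gcFirst = TRUE) a garbage collection runs before timing starts.
time_with_system_time <- function(fit_fun) {
  times <- system.time(fit <- fit_fun())
  list(fit = fit, times = times)
}

# Variant 2: time the fit with two calls to proc.time().
# No garbage collection is forced; the difference keeps the "proc_time" class.
time_with_proc_time <- function(fit_fun) {
  start <- proc.time()
  fit <- fit_fun()
  times <- proc.time() - start
  list(fit = fit, times = times)
}

# Example model fit, mirroring base_lm() above.
fit_fun <- function() stats::lm(Petal.Width ~ ., data = iris[, -5])

old_style <- time_with_system_time(fit_fun)
new_style <- time_with_proc_time(fit_fun)

print(old_style$times)
print(new_style$times)
```

Both helpers return the fitted model together with a `proc_time` object, so code that reads `times[["elapsed"]]` should work with either variant.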

topepo commented 2 years ago

Thanks for the contribution and sorry for the long wait.

system.time() is generally preferred but I don't think it is a big deal to use proc.time(). I'll update and merge 👍