tidyverse / modelr

Helper functions for modelling
https://modelr.tidyverse.org

rsquare() yields NA with `crossv_loo()` #104

Status: Closed (sjspielman closed 4 years ago)

sjspielman commented 4 years ago

Hi modelr team,

I'm encountering unexpected behavior when calling rsquare() on the output of crossv_loo(); the same call behaves as expected with crossv_kfold(). Every r-squared value from leave-one-out is NA, whereas k-fold returns sensible doubles.

Reprex as follows:

library(modelr)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(purrr)
# Fit a linear model of Sepal.Length on all other columns
my_iris_model <- function(input_data) {
  lm(Sepal.Length ~ ., data = input_data)
}
### WORKS: 3-fold cross-validation gives a numeric rsquare per fold
crossv_kfold(iris, k = 3) %>%
  mutate(fitted_model = map(train, my_iris_model)) %>%
  mutate(rsquare = map2_dbl(fitted_model, test, modelr::rsquare))
#> Warning: `as_data_frame()` is deprecated as of tibble 2.0.0.
#> Please use `as_tibble()` instead.
#> The signature and semantics have changed, see `?as_tibble`.
#> This warning is displayed once every 8 hours.
#> Call `lifecycle::last_warnings()` to see where this warning was generated.
#> # A tibble: 3 x 5
#>   train        test         .id   fitted_model rsquare
#>   <named list> <named list> <chr> <named list>   <dbl>
#> 1 <resample>   <resample>   1     <lm>           0.867
#> 2 <resample>   <resample>   2     <lm>           0.855
#> 3 <resample>   <resample>   3     <lm>           0.838
### NA: leave-one-out cross-validation gives NA for every rsquare
crossv_loo(iris) %>%
  mutate(fitted_model = map(train, my_iris_model)) %>%
  mutate(rsquare = map2_dbl(fitted_model, test, modelr::rsquare))
#> # A tibble: 150 x 5
#>    train        test           .id fitted_model rsquare
#>    <named list> <named list> <int> <named list>   <dbl>
#>  1 <resample>   <resample>       1 <lm>              NA
#>  2 <resample>   <resample>       2 <lm>              NA
#>  3 <resample>   <resample>       3 <lm>              NA
#>  4 <resample>   <resample>       4 <lm>              NA
#>  5 <resample>   <resample>       5 <lm>              NA
#>  6 <resample>   <resample>       6 <lm>              NA
#>  7 <resample>   <resample>       7 <lm>              NA
#>  8 <resample>   <resample>       8 <lm>              NA
#>  9 <resample>   <resample>       9 <lm>              NA
#> 10 <resample>   <resample>      10 <lm>              NA
#> # … with 140 more rows

Thanks for any insights or, if necessary, bug fixes.

Best, Stephanie

PS: session info, in case it's needed:

> sessionInfo()
R version 3.6.1 (2019-07-05)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS Catalina 10.15.4

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods  
[7] base     

other attached packages:
 [1] reprex_0.3.0    ggforce_0.3.1   broom_0.5.5    
 [4] modelr_0.1.6    forcats_0.5.0   stringr_1.4.0  
 [7] dplyr_0.8.5     purrr_0.3.3     readr_1.3.1    
[10] tidyr_1.0.2     tibble_3.0.0    ggplot2_3.3.0  
[13] tidyverse_1.3.0

loaded via a namespace (and not attached):
 [1] xfun_0.12        tidyselect_1.0.0 haven_2.2.0     
 [4] lattice_0.20-41  colorspace_1.4-1 vctrs_0.2.4     
 [7] generics_0.0.2   htmltools_0.4.0  utf8_1.1.4      
[10] rlang_0.4.5      pillar_1.4.3     glue_1.4.0      
[13] withr_2.1.2      DBI_1.1.0        tweenr_1.0.1    
[16] dbplyr_1.4.2     readxl_1.3.1     lifecycle_0.2.0 
[19] munsell_0.5.0    gtable_0.3.0     cellranger_1.1.0
[22] rvest_0.3.5      evaluate_0.14    knitr_1.28      
[25] labeling_0.3     callr_3.4.3      ps_1.3.2        
[28] fansi_0.4.1      Rcpp_1.0.4       clipr_0.7.0     
[31] scales_1.1.0     backports_1.1.6  jsonlite_1.6.1  
[34] farver_2.0.3     fs_1.4.0         hms_0.5.3       
[37] digest_0.6.25    stringi_1.4.6    processx_3.4.2  
[40] polyclip_1.10-0  grid_3.6.1       cli_2.0.2       
[43] tools_3.6.1      magrittr_1.5     whisker_0.4     
[46] crayon_1.3.4     pkgconfig_2.0.3  ellipsis_0.3.0  
[49] MASS_7.3-51.5    xml2_1.3.0       lubridate_1.7.8 
[52] rmarkdown_2.1    assertthat_0.2.1 httr_1.4.1      
[55] rstudioapi_0.11  R6_2.4.1         nlme_3.1-145    
[58] compiler_3.6.1

sjspielman commented 4 years ago

Ah, never mind, this is obviously silly: R^2 can't be calculated in LOO. Each test set holds a single observation, so there is no variance in the response to explain.
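For anyone who lands here later, a minimal sketch of the point above, using base R only. The first call shows that the sample variance of a single value is NA, which is why a variance-based R^2 on a one-row test set cannot be a number. The second part is an assumed workaround, not something modelr provides: pool the one held-out prediction from each leave-one-out fit and compute a single out-of-sample R^2 over all of them. The names `loo_pred` and `pooled_rsquare` are made up for illustration, and using the full-sample mean in the total sum of squares is just one possible convention.

```r
# The sample variance of one value is undefined, so R returns NA.
var(iris$Sepal.Length[1])
#> [1] NA

# Assumed workaround (not part of modelr): refit the model once per
# held-out row and keep the single out-of-sample prediction each time.
n <- nrow(iris)
loo_pred <- vapply(seq_len(n), function(i) {
  fit <- lm(Sepal.Length ~ ., data = iris[-i, ])
  unname(predict(fit, newdata = iris[i, , drop = FALSE]))
}, numeric(1))

# One pooled R^2 across all held-out predictions:
# 1 - (sum of squared prediction errors) / (total sum of squares).
sse <- sum((iris$Sepal.Length - loo_pred)^2)
sst <- sum((iris$Sepal.Length - mean(iris$Sepal.Length))^2)
pooled_rsquare <- 1 - sse / sst
pooled_rsquare
```

If a per-row metric is still wanted inside the crossv_loo() pipeline, an error-based measure such as the squared prediction error is defined for a single observation, whereas R^2 is not.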