mjskay / tidybayes

Bayesian analysis + tidy data + geoms (R package)
http://mjskay.github.io/tidybayes
GNU General Public License v3.0

`ndraws` in `add_[epred|linpred|predicted]_draws()` does not work #298

JohannesNE closed this issue 2 years ago

JohannesNE commented 2 years ago

When I set `ndraws` in `add_[epred|linpred|predicted]_draws()`, I still get all draws.

library(brms)
#> Loading required package: Rcpp
#> Loading 'brms' package (version 2.15.0). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#> 
#> Attaching package: 'brms'
#> The following object is masked from 'package:stats':
#> 
#>     ar
library(tidybayes)
#> 
#> Attaching package: 'tidybayes'
#> The following objects are masked from 'package:brms':
#> 
#>     dstudent_t, pstudent_t, qstudent_t, rstudent_t

m_mpg_am = brm(
  mpg ~ log(hp) * am, 
  data = mtcars, 
  family = lognormal
)
#> Compiling Stan program...

mtcars |>
  add_epred_draws(m_mpg_am, ndraws = 1)
#> # A tibble: 128,000 × 16
#> # Groups:   mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb, .row [32]
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb  .row
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
#>  1    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  2    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  3    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  4    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  5    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  6    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  7    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  8    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  9    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#> 10    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#> # … with 127,990 more rows, and 4 more variables: .chain <int>,
#> #   .iteration <int>, .draw <int>, .epred <dbl>

mtcars |>
  add_predicted_draws(m_mpg_am, ndraws = 1)
#> # A tibble: 128,000 × 16
#> # Groups:   mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb, .row [32]
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb  .row
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
#>  1    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  2    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  3    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  4    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  5    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  6    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  7    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  8    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  9    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#> 10    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#> # … with 127,990 more rows, and 4 more variables: .chain <int>,
#> #   .iteration <int>, .draw <int>, .prediction <dbl>

mtcars |>
  add_linpred_draws(m_mpg_am, ndraws = 1)
#> # A tibble: 128,000 × 16
#> # Groups:   mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb, .row [32]
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb  .row
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
#>  1    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  2    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  3    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  4    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  5    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  6    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  7    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  8    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#>  9    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#> 10    21     6   160   110   3.9  2.62  16.5     0     1     4     4     1
#> # … with 127,990 more rows, and 4 more variables: .chain <int>,
#> #   .iteration <int>, .draw <int>, .linpred <dbl>

sessionInfo()
#> R version 4.1.0 (2021-05-18)
#> Platform: x86_64-pc-linux-gnu (64-bit)
#> Running under: Pop!_OS 21.10
#> 
#> Matrix products: default
#> BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/libmkl_rt.so
#> 
#> locale:
#>  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_DK.UTF-8        LC_COLLATE=en_US.UTF-8    
#>  [5] LC_MONETARY=en_DK.UTF-8    LC_MESSAGES=en_US.UTF-8   
#>  [7] LC_PAPER=en_DK.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_DK.UTF-8 LC_IDENTIFICATION=C       
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] tidybayes_3.0.2.9000 brms_2.15.0          Rcpp_1.0.8          
#> 
#> loaded via a namespace (and not attached):
#>   [1] TH.data_1.0-10       minqa_1.2.4          colorspace_2.0-3    
#>   [4] ellipsis_0.3.2       ggridges_0.5.3       rsconnect_0.8.24    
#>   [7] estimability_1.3     markdown_1.1         base64enc_0.1-3     
#>  [10] fs_1.5.2             farver_2.1.0         rstan_2.21.3        
#>  [13] svUnit_1.0.6         DT_0.17              fansi_1.0.3         
#>  [16] mvtnorm_1.1-2        bridgesampling_1.1-2 codetools_0.2-18    
#>  [19] splines_4.1.0        knitr_1.37           shinythemes_1.2.0   
#>  [22] bayesplot_1.8.1      projpred_2.0.2       nloptr_1.2.2.2      
#>  [25] ggdist_3.1.0         shiny_1.7.1          compiler_4.1.0      
#>  [28] emmeans_1.6.2-1      backports_1.4.1      assertthat_0.2.1    
#>  [31] Matrix_1.3-4         fastmap_1.1.0        cli_3.2.0           
#>  [34] later_1.2.0          htmltools_0.5.2      prettyunits_1.1.1   
#>  [37] tools_4.1.0          igraph_1.2.6         coda_0.19-4         
#>  [40] gtable_0.3.0         glue_1.6.2           posterior_1.2.0     
#>  [43] reshape2_1.4.4       dplyr_1.0.8          styler_1.5.1.9000   
#>  [46] vctrs_0.4.0          nlme_3.1-152         crosstalk_1.1.1     
#>  [49] tensorA_0.36.2       xfun_0.30            stringr_1.4.0       
#>  [52] ps_1.6.0             lme4_1.1-27.1        mime_0.12           
#>  [55] miniUI_0.1.1.1       lifecycle_1.0.1      gtools_3.9.2        
#>  [58] MASS_7.3-54          zoo_1.8-9            scales_1.1.1        
#>  [61] colourpicker_1.1.0   promises_1.2.0.1     Brobdingnag_1.2-6   
#>  [64] sandwich_3.0-1       parallel_4.1.0       inline_0.3.19       
#>  [67] shinystan_2.5.0      gamm4_0.2-6          yaml_2.3.5          
#>  [70] gridExtra_2.3        ggplot2_3.3.5        loo_2.4.1           
#>  [73] StanHeaders_2.21.0-7 stringi_1.7.6        highr_0.9           
#>  [76] dygraphs_1.1.1.6     checkmate_2.0.0      boot_1.3-28         
#>  [79] pkgbuild_1.3.1       rlang_1.0.2          pkgconfig_2.0.3     
#>  [82] matrixStats_0.61.0   distributional_0.3.0 evaluate_0.15       
#>  [85] lattice_0.20-44      purrr_0.3.4          rstantools_2.1.1    
#>  [88] htmlwidgets_1.5.4    tidyselect_1.1.2     processx_3.5.2      
#>  [91] plyr_1.8.6           magrittr_2.0.3       R6_2.5.1            
#>  [94] generics_0.1.2       multcomp_1.4-17      DBI_1.1.1           
#>  [97] pillar_1.7.0         withr_2.5.0          mgcv_1.8-36         
#> [100] xts_0.12.1           survival_3.2-11      abind_1.4-5         
#> [103] tibble_3.1.6         crayon_1.5.1         arrayhelpers_1.1-0  
#> [106] utf8_1.2.2           rmarkdown_2.12.2     grid_4.1.0          
#> [109] callr_3.7.0          threejs_0.3.3        reprex_2.0.0        
#> [112] digest_0.6.29        xtable_1.8-4         tidyr_1.2.0         
#> [115] httpuv_1.6.2         RcppParallel_5.1.5   stats4_4.1.0        
#> [118] munsell_0.5.0        shinyjs_2.0.0

Created on 2022-04-07 by the reprex package (v2.0.0)

mjskay commented 2 years ago

Strange... it works fine for me with both the CRAN and GitHub versions of tidybayes, and you seem to be on the latest version of tidybayes. One thing I notice, though, is that you are using brms 2.15 rather than the latest (2.16). Does it work if you update brms?

JohannesNE commented 2 years ago

Hmm. Updating brms to 2.16 did not help. The problem seems to be a mismatch between `ndraws` in `epred_draws()` and `nsamples` in `rstantools::posterior_epred()`. If I run

mtcars |>
  add_epred_draws(m_mpg_am, nsamples = 1)

I only get one draw (32 rows). I guess `nsamples` is simply passed on through `...`.
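The base-R sketch below illustrates the mechanism guessed at above. It assumes, as this thread suggests, that the wrapper forwards its `ndraws` argument to a method that (like brms before 2.16) only has a `nsamples` argument plus a `...` catch-all. Both `old_posterior_epred()` and `add_epred_draws_sketch()` are hypothetical simplifications, not the actual brms or tidybayes source; the 4000 total draws simply mirror the 128,000 rows (32 × 4000) in the reprex output.

```r
# Hypothetical stand-in for an older prediction method whose draw-count
# argument is still named `nsamples` (as in brms < 2.16). An argument it
# does not recognise, such as `ndraws`, is silently absorbed by `...`.
old_posterior_epred <- function(object, newdata, nsamples = NULL, ...) {
  total_draws <- 4000L  # pretend the fit has 4 chains x 1000 draws
  n <- if (is.null(nsamples)) total_draws else nsamples
  # draws x observations matrix, like posterior_epred() output
  matrix(0, nrow = n, ncol = nrow(newdata))
}

# Hypothetical wrapper in the style of add_epred_draws(): it passes its
# own `ndraws` plus anything in `...` on to the underlying method and
# returns the number of rows a long-format (tidy) result would have.
add_epred_draws_sketch <- function(newdata, object, ndraws = NULL, ...) {
  draws <- old_posterior_epred(object, newdata, ndraws = ndraws, ...)
  nrow(draws) * ncol(draws)
}

add_epred_draws_sketch(mtcars, object = NULL, ndraws = 1)    # 128000: ndraws ignored
add_epred_draws_sketch(mtcars, object = NULL, nsamples = 1)  # 32: nsamples honoured
```

Because `...` accepts any named argument, a renamed argument produces no error; the call just falls back to using all draws, which matches the symptom in the report.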

mjskay commented 2 years ago

Hmm, that's bizarre: the `nsamples` argument was renamed to `ndraws` in brms 2.16, so I don't think `nsamples` should work on 2.16. Are you sure the update succeeded? Did you restart your R session?
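A small base-R check for the situation diagnosed here, where a package has been updated on disk but the old version is still loaded in the running session. The helper name `namespace_is_stale()` is hypothetical. It compares `getNamespaceVersion()`, which reports the version that was loaded into the session, with `utils::packageVersion()`, which should read the installed package metadata; this assumes the update was installed into the same library the old version was loaded from.

```r
# Hypothetical helper: TRUE when the version of `pkg` loaded in this R
# session differs from the version currently installed in its library,
# e.g. after install.packages("brms") without restarting R.
namespace_is_stale <- function(pkg) {
  # A package that is not loaded cannot be stale.
  if (!isNamespaceLoaded(pkg)) return(FALSE)
  # Version recorded when the namespace was loaded into this session.
  loaded <- package_version(getNamespaceVersion(pkg))
  # Version read from the installed package's metadata.
  installed <- utils::packageVersion(pkg)
  # any() turns the version comparison into a plain TRUE/FALSE.
  any(loaded != installed)
}

namespace_is_stale("brms")  # TRUE means: restart R before re-running the reprex
```

Comparing `package_version` objects rather than raw strings avoids false positives for versions written with `-` separators, such as `1.0-10`.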

JohannesNE commented 2 years ago

I had not restarted the session 😞. After restarting, it works! Thank you for the help, and I apologize for the inconvenience.

mjskay commented 2 years ago

No problem! Glad to help.