tidyverse / modelr

Helper functions for modelling
https://modelr.tidyverse.org
GNU General Public License v3.0

Feature request: pass the output from fit_with to gather_predictions #29

Closed: Tutuchan closed this issue 6 years ago

Tutuchan commented 8 years ago

Unless I'm mistaken, gather_predictions does not interact well with the output from fit_with:

library(modelr)

mtcars_formulas <- formulas(
  ~disp,
  additive = ~drat + cyl,
  interaction = ~drat * cyl,
  full = add_predictors(interaction, ~am, ~vs)
)

models <- mtcars %>% 
  fit_with(lm, mtcars_formulas)
> models

$additive

Call:
.f(formula = disp ~ drat + cyl, data = data)

Coefficients:
(Intercept)         drat          cyl  
      18.72       -35.83        55.09  

$interaction

Call:
.f(formula = disp ~ drat * cyl, data = data)

Coefficients:
(Intercept)         drat          cyl     drat:cyl  
   -193.487       20.390       88.675       -9.155  

$full

Call:
.f(formula = disp ~ drat * cyl + am + vs, data = data)

Coefficients:
(Intercept)         drat          cyl           am           vs     drat:cyl  
    -164.83        27.58        75.84       -38.22       -15.53        -6.97 

# This works: each model is passed as its own named argument
mtcars %>% 
  gather_predictions(additive = models$additive, full = models$full)
# This doesn't: the whole list is passed as a single argument
mtcars %>% 
  gather_predictions(models)

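I believe this is due to the models <- tibble::lst(...) line in gather_predictions: the whole list returned by fit_with ends up as a single element named models, rather than one element per fitted model. I tried using do.call, but I don't think there is an easy way to make this work with the current code.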
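A minimal sketch of that cause, assuming only the tibble package and base R's lm(). For illustration, two models are fitted directly with lm() instead of through fit_with. An unspliced list reaches tibble::lst() as one argument, while !!! splices it into one argument per model:

```r
library(tibble)

# Two fitted models in a named list, standing in for the output of fit_with
models <- list(
  additive = lm(disp ~ drat + cyl, data = mtcars),
  interaction = lm(disp ~ drat * cyl, data = mtcars)
)

# Passed as-is, the whole list becomes a single element named "models"
wrapped <- tibble::lst(models)
length(wrapped)  # 1
names(wrapped)   # "models"

# Spliced with !!!, each model becomes its own named element
spliced <- tibble::lst(!!!models)
length(spliced)  # 2
names(spliced)   # "additive" "interaction"
```

In the first case, gather_predictions would try to predict from a plain list rather than from a fitted model.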

One possible solution would be to add a class to the list output from fit_with and check in gather_predictions if the input inherits from this class and proceed accordingly.

Here is a tentative implementation. It's not very elegant, and I'm pretty sure there is a better way:

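gather_predictions <- function(data, ..., .pred = "pred", .model = "model") {
  # Collect the models passed through ... into a named list
  models <- tibble::lst(...)
  # Let's suppose the output from fit_with has the class fit_with (for example);
  # if so, unwrap it so that each fitted model becomes its own list element
  if (inherits(models[[1]], "fit_with")) models <- models[[1]]

  # Add a prediction column for each model, then stack the results,
  # recording which model produced each row in the .model column
  df <- purrr::map2(models, .pred, add_predictions, data = data)
  names(df) <- names(models)
  dplyr::bind_rows(df, .id = .model)
}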
> mtcars %>% 
+   gather_predictions(models)

         model  mpg cyl  disp  hp drat    wt  qsec vs am gear carb      pred
1     additive 21.0   6 160.0 110 3.90 2.620 16.46  0  1    4    4 209.52005
2     additive 21.0   6 160.0 110 3.90 2.875 17.02  0  1    4    4 209.52005
3     additive 22.8   4 108.0  93 3.85 2.320 18.61  1  1    4    1 101.13041
4     additive 21.4   6 258.0 110 3.08 3.215 19.44  1  0    3    1 238.90112
5     additive 18.7   8 360.0 175 3.15 3.440 17.02  0  0    3    2 346.57415
6     additive 18.1   6 225.0 105 2.76 3.460 20.22  1  0    3    1 250.36691
7     additive 14.3   8 360.0 245 3.21 3.570 15.84  0  0    3    4 344.42432
8     additive 24.4   4 146.7  62 3.69 3.190 20.00  1  0    4    2 106.86330
9     additive 22.8   4 140.8  95 3.92 3.150 22.90  1  0    4    2  98.62227
10    additive 19.2   6 167.6 123 3.92 3.440 18.30  1  0    4    4 208.80344
...

Once again, there may be something very obvious that I didn't see. Any thoughts?

hadley commented 6 years ago

You need to use !!! here to splice the list of models into separate arguments:

library(modelr)
mtcars_formulas <- formulas(
  ~disp,
  additive = ~drat + cyl,
  interaction = ~drat * cyl,
  full = add_predictors(interaction, ~am, ~vs)
)

models <- mtcars %>% 
  fit_with(lm, mtcars_formulas)

mtcars %>% 
  gather_predictions(!!!models) %>%
  tibble::as_tibble()
#> # A tibble: 96 x 13
#>    model   mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1 addi…  21       6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2 addi…  21       6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3 addi…  22.8     4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4 addi…  21.4     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5 addi…  18.7     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6 addi…  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7 addi…  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8 addi…  24.4     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9 addi…  22.8     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10 addi…  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4
#> # ... with 86 more rows, and 1 more variable: pred <dbl>

Created on 2018-05-10 by the reprex package (v0.2.0).

This is a general pattern that we haven't yet figured out how to document, but basically, anywhere in the tidyverse that normally takes individual arguments in ..., you can always supply !!!list(x, y, z) instead.
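A minimal illustration of that pattern, assuming the rlang package. count_args() is a hypothetical toy function whose dots are collected with rlang::list2(), the same "tidy dots" machinery that supports !!! splicing:

```r
library(rlang)

# Hypothetical toy function: counts the arguments it receives in ...,
# collected with rlang::list2() so that !!! splicing is honoured
count_args <- function(...) length(list2(...))

args <- list(x = 1, y = 2, z = 3)

count_args(args)     # 1: the whole list arrives as a single argument
count_args(!!!args)  # 3: the list is spliced into three separate arguments
```

gather_predictions(models) versus gather_predictions(!!!models) follows the same pattern.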