tidymodels / tidypredict

Run predictions inside the database
https://tidypredict.tidymodels.org

Ranger produces list of trees #84

Open ndiquattro opened 4 years ago

ndiquattro commented 4 years ago

Hello, thanks for your work on this package, it is very exciting! I was trying to follow the docs on using a ranger RF model, but tidypredict_fit() seems to return a list of trees (one case_when() per tree) rather than one statement. Is the intention that we execute all the trees on the DB and then calculate the prediction from the results? I don't get that impression from the docs. Thanks!

library(ranger)
library(tidypredict)
library(dplyr, warn.conflicts = FALSE)

test_mod <- ranger(Species ~ ., iris, num.trees = 100)

trees <- tidypredict_fit(test_mod)

# Is list of trees
str(trees, max.level = 1, list.len = 3)
#> List of 100
#>  $ : language case_when(Petal.Width < 0.8 ~ "setosa", Sepal.Length < 5.75 & Petal.Width >=      0.8 ~ "versicolor", Petal.Width| __truncated__ ...
#>  $ : language case_when(Petal.Length < 2.45 ~ "setosa", Petal.Width >= 1.7 & Petal.Length >=      2.45 ~ "virginica", Petal.Len| __truncated__ ...
#>  $ : language case_when(Petal.Width < 0.8 ~ "setosa", Petal.Length < 4.9 & Petal.Width <      1.75 & Petal.Width >= 0.8 ~ "vers| __truncated__ ...
#>   [list output truncated]

# One example
trees[[1]]
#> case_when(Petal.Width < 0.8 ~ "setosa", Sepal.Length < 5.75 & 
#>     Petal.Width >= 0.8 ~ "versicolor", Petal.Width >= 1.75 & 
#>     Sepal.Length >= 5.75 & Petal.Width >= 0.8 ~ "virginica", 
#>     Petal.Length < 4.75 & Sepal.Width < 2.25 & Petal.Width < 
#>         1.75 & Sepal.Length >= 5.75 & Petal.Width >= 0.8 ~ "versicolor", 
#>     Petal.Length >= 4.75 & Sepal.Width < 2.25 & Petal.Width < 
#>         1.75 & Sepal.Length >= 5.75 & Petal.Width >= 0.8 ~ "virginica", 
#>     Petal.Width < 1.55 & Sepal.Width >= 2.25 & Petal.Width < 
#>         1.75 & Sepal.Length >= 5.75 & Petal.Width >= 0.8 ~ "versicolor", 
#>     Petal.Width >= 1.65 & Petal.Width >= 1.55 & Sepal.Width >= 
#>         2.25 & Petal.Width < 1.75 & Sepal.Length >= 5.75 & Petal.Width >= 
#>         0.8 ~ "versicolor", Petal.Length < 5.45 & Petal.Width < 
#>         1.65 & Petal.Width >= 1.55 & Sepal.Width >= 2.25 & Petal.Width < 
#>         1.75 & Sepal.Length >= 5.75 & Petal.Width >= 0.8 ~ "versicolor", 
#>     Petal.Length >= 5.45 & Petal.Width < 1.65 & Petal.Width >= 
#>         1.55 & Sepal.Width >= 2.25 & Petal.Width < 1.75 & Sepal.Length >= 
#>         5.75 & Petal.Width >= 0.8 ~ "virginica")

# The approach suggested in an older issue doesn't work
iris %>%
  tidypredict_to_column(test_mod)
#> Error in tidypredict_to_column(., test_mod): tidypredict_to_column does not support tree based models

Created on 2020-08-23 by the reprex package (v0.3.0)
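Each element of the list is an ordinary case_when() call, so, assuming the list structure shown in the reprex above, the trees can be spliced into a single dplyr mutate() to get one vote column per tree. The minimal sketch below does that locally and then renders the first two trees to SQL with dbplyr's lazy_frame() and simulate_postgres(), which only generate SQL text and never connect to a database. The tree_1, tree_2, ... column names and the seed are illustrative choices, not part of tidypredict.

```r
library(ranger)
library(tidypredict)
library(dplyr, warn.conflicts = FALSE)
library(dbplyr, warn.conflicts = FALSE)

set.seed(1)  # illustrative seed so the sketch is reproducible
test_mod <- ranger(Species ~ ., iris, num.trees = 100)
trees <- tidypredict_fit(test_mod)

# Hypothetical naming scheme: give each tree expression its own column name
names(trees) <- paste0("tree_", seq_along(trees))

# Local evaluation: one character column of class votes per tree
tree_votes <- iris %>%
  mutate(!!!trees) %>%
  select(starts_with("tree_"))

# SQL rendering only (no connection): each tree becomes a CASE WHEN column
iris_lazy <- lazy_frame(
  Sepal.Length = 0, Sepal.Width = 0, Petal.Length = 0, Petal.Width = 0,
  con = simulate_postgres()
)
tree_sql <- iris_lazy %>%
  mutate(!!!trees[1:2]) %>%
  sql_render()

tree_sql
```

This produces per-tree votes, not a prediction: the votes still have to be aggregated (for example by counting them per row) to get a class, which is what tidypredict_to_column() declines to do for tree-based models.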

topepo commented 3 years ago

I looked at the documentation and agree that it needs to be revised.

I think the intention was for users to do some dplyr post-processing of the per-tree predictions to get them into whatever format they need.

Here's some code that uses dplyr, purrr, and tidyr:

library(ranger)
library(tidypredict)
library(dplyr, warn.conflicts = FALSE)

test_mod <- ranger(Species ~ ., iris, num.trees = 100)

trees <- tidypredict_fit(test_mod)

new_samples <- iris[c(1, 51, 101), ]

# Evaluate every tree on the new rows: one vote per (tree, sample) pair
votes <- 
  purrr::map_dfr(trees, 
                 ~ tibble(.pred = rlang::eval_tidy(.x, new_samples),
                          .row = seq_len(nrow(new_samples))
                 )
  )

# Majority vote: keep the most frequent class per row (one row even on ties)
class_pred <-
  votes %>% 
  group_by(.row) %>% 
  count(.pred) %>% 
  slice_max(n, with_ties = FALSE) %>% 
  ungroup() %>% 
  select(-n)

class_pred
#> # A tibble: 3 x 2
#>    .row .pred     
#>   <int> <chr>     
#> 1     1 setosa    
#> 2     2 versicolor
#> 3     3 virginica

# Vote shares: divide each count by the number of trees that voted on that row
class_prob <- 
  votes %>% 
  group_by(.row) %>% 
  count(.pred) %>% 
  mutate(prob = n / sum(n)) %>% 
  ungroup() %>% 
  select(-n) %>% 
  tidyr::pivot_wider(id_cols = ".row", names_from = ".pred",
                     values_from = "prob", values_fill = 0)

class_prob
#> # A tibble: 3 x 4
#>    .row setosa versicolor virginica
#>   <int>  <dbl>      <dbl>     <dbl>
#> 1     1      1       0         0   
#> 2     2      0       0.98      0.02
#> 3     3      0       0         1

Created on 2020-12-04 by the reprex package (v0.3.0)
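If one expression per class is preferred over one column per tree, the tree expressions can also be folded into vote counts before anything is evaluated. This is a minimal sketch, not a tidypredict feature: for each class it builds a sum of ifelse(tree == class, 1, 0) terms with base R's bquote() and Reduce(), evaluates all classes in one mutate(), and compares the majority vote with predict() on the ranger model. It assumes a classification forest fit with the default probability = FALSE, where predict() returns the majority vote; on exact ties the two approaches could break the tie differently. The vote_exprs object, the votes_ prefix, and the seed are hypothetical names chosen for illustration. Because only case_when(), ifelse(), == and + appear in the expressions, dbplyr should be able to translate the same mutate() to SQL, but that has not been checked against a specific backend here.

```r
library(ranger)
library(tidypredict)
library(dplyr, warn.conflicts = FALSE)

set.seed(1)  # illustrative seed so the sketch is reproducible
test_mod <- ranger(Species ~ ., iris, num.trees = 100)
trees <- tidypredict_fit(test_mod)
classes <- levels(iris$Species)

# For each class, one expression counting the trees that vote for it:
# ifelse(tree_1 == cls, 1, 0) + ifelse(tree_2 == cls, 1, 0) + ...
vote_exprs <- lapply(classes, function(cls) {
  terms <- lapply(trees, function(tree) bquote(ifelse(.(tree) == .(cls), 1, 0)))
  Reduce(function(lhs, rhs) bquote(.(lhs) + .(rhs)), terms)
})
names(vote_exprs) <- paste0("votes_", classes)

new_samples <- iris[c(1, 51, 101), ]

# One mutate() with one column per class instead of one per tree
vote_counts <- new_samples %>%
  mutate(!!!vote_exprs) %>%
  select(starts_with("votes_"))

# Vote shares per class (rows sum to 1 when every tree returns a class)
vote_shares <- vote_counts / length(trees)

# Majority vote: the class with the most votes in each row
majority <- classes[max.col(as.matrix(vote_counts), ties.method = "first")]

# ranger's own prediction for the same rows, for comparison
ranger_pred <- as.character(predict(test_mod, data = new_samples)$predictions)

data.frame(majority, ranger_pred)
```

The trade-off is query size: each class expression repeats every tree, so the generated SQL grows with the number of trees times the number of classes, whereas the per-tree columns approach repeats each tree once but leaves the counting to a later step.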