alexpghayes / safepredict

Consistent prediction following tidymodels principles
https://alexpghayes.github.io/safepredict/

safepredict


safepredict has two goals: to provide a consistent interface to prediction via the safe_predict() generic, and to accurately quantify prediction uncertainty.
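The idea of one generic with a fixed signature can be sketched with base R's S3 dispatch. The names below (`sketch_predict()` and its methods) are hypothetical, and this is not safepredict's source code. It only shows how a single generic can give every model class the same arguments and the same kind of return value, and how a default method can fail loudly for unsupported models.

```r
library(tibble)

# Hypothetical generic: every method shares (object, new_data, type, ...)
sketch_predict <- function(object, new_data, type = NULL, ...) {
  UseMethod("sketch_predict")
}

# Fallback for classes without a method: raise an error rather than
# returning something with an inconsistent shape
sketch_predict.default <- function(object, new_data, type = NULL, ...) {
  stop("No sketch_predict() method for class ", class(object)[1], call. = FALSE)
}

# Method for lm fits: validates `type` and always returns a tibble
# with one .pred column and one row per row of new_data
sketch_predict.lm <- function(object, new_data, type = "response", ...) {
  type <- match.arg(type, "response")
  tibble(.pred = unname(predict(object, newdata = new_data)))
}

fit_lm <- lm(mpg ~ wt, data = mtcars)
sketch_predict(fit_lm, new_data = mtcars[1:3, ])
```

One design caveat: `glm` objects also carry class `lm`, so in this sketch a `glm` fit would fall through to the `lm` method (and get link-scale values) unless a dedicated `glm` method were added.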

safe_predict() follows the tidymodels prediction specification.

Installation

safepredict is in the early stages of development and is currently available only on GitHub. You can install it with:

# install.packages("devtools")
devtools::install_github("alexpghayes/safepredict")

Arguments

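The three main arguments to safe_predict() are always the same: the fitted model object, new_data (the data to generate predictions for), and type (the kind of prediction to return, such as "prob", "class", "link", or "conf_int" in the examples below).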

Examples

Suppose you fit a logistic regression using glm:

library(tibble)

data <- tibble(
  y = as.factor(rep(c("A", "B"), each = 50)),
  x = c(rnorm(50, 1), rnorm(50, 3))
)

fit <- glm(y ~ x, data, family = binomial)

You can predict class probabilities:

library(safepredict)

test <- tibble(x = rnorm(10, 2))

safe_predict(fit, new_data = test, type = "prob")
#> # A tibble: 10 x 2
#>   .pred_A .pred_B
#>     <dbl>   <dbl>
#> 1  0.333    0.667
#> 2  0.0410   0.959
#> 3  0.619    0.381
#> 4  0.467    0.533
#> 5  0.132    0.868
#> # ... with 5 more rows
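Class probabilities like these can be obtained, at least in spirit, from base R's `predict.glm()`. This is a sketch assuming safepredict builds on the standard `glm` predictions; it is not taken from the package. For a factor response, `glm()` treats the first level ("A") as failure and models the probability of the second level ("B"), so `.pred_A` is the complement of `.pred_B`. The seed is arbitrary, so the numbers will not match the output above.

```r
library(tibble)

set.seed(1)  # arbitrary seed, only to make the sketch reproducible
data <- tibble(
  y = as.factor(rep(c("A", "B"), each = 50)),
  x = c(rnorm(50, 1), rnorm(50, 3))
)
fit <- glm(y ~ x, data, family = binomial)
test <- tibble(x = rnorm(10, 2))

# type = "response" gives P(y = "B"), the second factor level
p_B <- unname(predict(fit, newdata = test, type = "response"))

# One column per class level, using the .pred_<level> naming shown above
probs <- tibble(.pred_A = 1 - p_B, .pred_B = p_B)
probs
```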

or jump straight to hard class decisions:

safe_predict(fit, new_data = test, type = "class")
#> # A tibble: 10 x 1
#>   .pred_class
#>   <fct>      
#> 1 B          
#> 2 B          
#> 3 A          
#> 4 B          
#> 5 B          
#> # ... with 5 more rows
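With two classes, choosing the more probable level is the same as thresholding P(B) at 0.5. The sketch below assumes that cutoff and applies it to the first five `.pred_B` values printed earlier; it recovers the B, B, A, B, B decisions shown above, though it is not necessarily how safepredict makes the call internally.

```r
# First five .pred_B values from the probability output above
p_B <- c(0.667, 0.959, 0.381, 0.533, 0.868)

# Assumed 0.5 cutoff; fixing `levels` keeps both classes in the factor
# even if every prediction happens to fall on one side
pred_class <- factor(ifelse(p_B > 0.5, "B", "A"), levels = c("A", "B"))
pred_class  # B B A B B
```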

You can also get predictions on the link scale (the log-odds, since binomial() uses the logit link by default):

safe_predict(fit, new_data = test, type = "link")
#> # A tibble: 10 x 1
#>    .pred
#>    <dbl>
#> 1  0.696
#> 2  3.15 
#> 3 -0.485
#> 4  0.132
#> 5  1.88 
#> # ... with 5 more rows
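The link-scale and probability outputs are connected by the inverse logit. Applying it to the five link values printed above recovers the `.pred_B` column from the probability example. The helper name `inv_logit` is illustrative; base R's `plogis()` computes the same function.

```r
# Inverse logit: maps the linear predictor (log-odds) to P(y = "B")
inv_logit <- function(eta) 1 / (1 + exp(-eta))

link <- c(0.696, 3.15, -0.485, 0.132, 1.88)  # .pred values printed above
p_B <- inv_logit(link)
round(p_B, 3)  # 0.667 0.959 0.381 0.533 0.868
```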

or confidence intervals on the response scale, here for the predicted probability of class B:

safe_predict(fit, new_data = test, type = "conf_int")
#> # A tibble: 10 x 3
#>   .pred .pred_lower .pred_upper
#>   <dbl>       <dbl>       <dbl>
#> 1 0.667       0.510       0.795
#> 2 0.959       0.862       0.989
#> 3 0.381       0.240       0.545
#> 4 0.533       0.380       0.680
#> 5 0.868       0.724       0.943
#> # ... with 5 more rows
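The printed intervals are consistent with a Wald interval computed on the link scale and then mapped back through the inverse link. For the first row, logit(0.510) ≈ 0.04 and logit(0.795) ≈ 1.36 sit symmetrically around the link value 0.696. The sketch below shows that approach with base R's `predict.glm(se.fit = TRUE)` and `family(fit)$linkinv`. The 95% level is an assumption, and this is not necessarily how safepredict computes its intervals. Because the inverse link is increasing, the lower bound always stays below the upper bound.

```r
library(tibble)

set.seed(1)  # arbitrary seed, only to make the sketch reproducible
data <- tibble(
  y = as.factor(rep(c("A", "B"), each = 50)),
  x = c(rnorm(50, 1), rnorm(50, 3))
)
fit <- glm(y ~ x, data, family = binomial)
test <- tibble(x = rnorm(10, 2))

# Point estimate and standard error on the link (log-odds) scale
link <- predict(fit, newdata = test, type = "link", se.fit = TRUE)
eta <- unname(link$fit)
se <- unname(link$se.fit)

# Assumed 95% Wald interval on the link scale, back-transformed
z <- qnorm(0.975)
inv_link <- family(fit)$linkinv
conf <- tibble(
  .pred       = inv_link(eta),
  .pred_lower = inv_link(eta - z * se),
  .pred_upper = inv_link(eta + z * se)
)
conf
```

Building the interval on the link scale keeps both bounds inside [0, 1], which a symmetric interval on the probability scale would not guarantee.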

Related work