curso-r / treesnip

Parsnip backends for `tree`, `lightGBM` and `Catboost`
https://curso-r.github.io/treesnip
GNU General Public License v3.0
85 stars 13 forks source link

catboost and predict threshold #58

Open pecto2020 opened 2 years ago

pecto2020 commented 2 years ago

Hi, when a catboost model is fitted through tidymodels/treesnip, `predict()` does not seem to use the default probability threshold of 0.5 for class predictions, but something else. Does catboost apply class weights during training? If so, how can I change that in tidymodels/treesnip? Below is a comparison between catboost and a random forest. Thanks!

# Attach the modelling stack: tidymodels (parsnip, rsample, yardstick, dplyr,
# ...), the Pima Indians diabetes data (mlbench), and the catboost engine
# together with its parsnip bindings (catboost, treesnip).
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(mlbench)
library(catboost)
library(treesnip)
# Register catboost and treesnip as packages required by the "catboost"
# engine of boost_tree(). NOTE(review): presumably so they get loaded wherever
# the model is fit (e.g. parallel workers) -- confirm this is still needed
# with the installed treesnip version.
set_dependency("boost_tree", eng = "catboost", "catboost")
set_dependency("boost_tree", eng = "catboost", "treesnip")

# Load data: 768 rows; `diabetes` is the binary outcome ("neg"/"pos").
data(PimaIndiansDiabetes)
diabetes_orig<-PimaIndiansDiabetes

# Set random seed so the train/test split below is reproducible.
set.seed(123)
# Create initial split: 3/4 of the rows for training, the rest for testing.
diabetes_split <- initial_split(diabetes_orig, prop = 3/4)
diabetes_split
#> <Analysis/Assess/Total>
#> <576/192/768>
# Create training set (576 rows).
diabetes_train <- training(diabetes_split)
# Create test set (192 rows).
diabetes_test <- testing(diabetes_split)

# Train Random Forest ----

# Model specification: ranger random forest for classification with default
# hyperparameters.
trees_spec <- rand_forest() %>%
  set_mode("classification") %>%
  set_engine("ranger")

# Fit on training data; `diabetes` ("neg"/"pos") is the outcome.
trees_fit <- trees_spec %>%
  fit(diabetes ~ ., data = diabetes_train)

# Predict: hard class (.pred_class), class probabilities (.pred_<level>,
# e.g. .pred_pos) and the observed outcome, side by side.
trees_pred <- predict(trees_fit, diabetes_test) %>%
  bind_cols(predict(trees_fit, diabetes_test, type = "prob")) %>%
  bind_cols(diabetes_test %>% select(diabetes))

# Get metrics with "pos" (the second factor level) as the event. The
# yardstick arguments are `truth` and `event_level` (singular): a misspelled
# `event_levels` is swallowed by `...` and ignored, which silently makes
# `sens` score the first level ("neg") instead.
trees_perf <- trees_pred %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    trees_pred %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# Change threshold: derive the hard class from .pred_pos with an explicit
# 0.5 cut-off. Building the factor with the truth column's levels keeps both
# classes even if every prediction falls on one side of the threshold
# (as.factor() would drop the unused level, and yardstick rejects truth and
# estimate factors whose levels differ).
trees_05 <- trees_pred %>%
  mutate(
    .pred_class = factor(
      if_else(.pred_pos > 0.5, "pos", "neg"),
      levels = levels(diabetes)
    )
  )

# Get metrics for the thresholded classes.
trees_perf_05 <- trees_05 %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    trees_05 %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# NOTE: the `#>` output below was recorded when the `sens()` calls passed the
# misspelled `event_levels`, so its `sens` rows measure the "neg" class.
# Re-run the reprex to regenerate them.
trees_perf
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.823
#> 2 sens    binary         0.856
trees_perf_05
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.823
#> 2 sens    binary         0.856

# Train CatBoost ----

# Model specification: boost_tree() with depth-10 trees on the "catboost"
# engine (bindings provided by treesnip). `nthread = 4` is passed through as
# an engine argument. NOTE(review): confirm treesnip forwards it to
# catboost's thread setting.
catboost_spec <- boost_tree(tree_depth = 10) %>%
  set_mode("classification") %>%
  set_engine("catboost", nthread = 4)

# Fit on training data; `diabetes` ("neg"/"pos") is the outcome.
catboost_fit <- catboost_spec %>%
  fit(diabetes ~ ., data = diabetes_train)

# Predict: hard class (.pred_class), class probabilities (.pred_<level>,
# e.g. .pred_pos) and the observed outcome, side by side.
catboost_pred <- predict(catboost_fit, diabetes_test) %>%
  bind_cols(predict(catboost_fit, diabetes_test, type = "prob")) %>%
  bind_cols(diabetes_test %>% select(diabetes))

# Get metrics with "pos" (the second factor level) as the event. The
# yardstick argument is `event_level` (singular): a misspelled `event_levels`
# is swallowed by `...` and ignored, which silently makes `sens` score the
# first level ("neg") instead.
catboost_perf <- catboost_pred %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    catboost_pred %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# Change threshold: derive the hard class from .pred_pos with an explicit
# 0.5 cut-off instead of relying on the engine's own class prediction (in the
# recorded output below the two disagree: sens 1 vs 0.992). Building the
# factor with the truth column's levels keeps both classes even if every
# prediction falls on one side of the threshold.
catboost_05 <- catboost_pred %>%
  mutate(
    .pred_class = factor(
      if_else(.pred_pos > 0.5, "pos", "neg"),
      levels = levels(diabetes)
    )
  )

# Get metrics for the thresholded classes.
catboost_perf_05 <- catboost_05 %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    catboost_05 %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# NOTE: the `#>` output below was recorded when the `sens()` calls passed the
# misspelled `event_levels`, so its `sens` rows measure the "neg" class.
# Re-run the reprex to regenerate them.
catboost_perf
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.801
#> 2 sens    binary         1
catboost_perf_05
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.801
#> 2 sens    binary         0.992

Created on 2022-02-02 by the reprex package (v2.0.1)

pecto2020 commented 2 years ago

Notably, using catboost through caret seems to work as expected: its class predictions match a 0.5 threshold on the predicted probabilities.

library(mlbench)
library(catboost)
library(caret)
#> Loading required package: ggplot2
#> Loading required package: lattice
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip

# Load data: 768 rows; `diabetes` is the binary outcome ("neg"/"pos").
data(PimaIndiansDiabetes)
diabetes_orig<-PimaIndiansDiabetes

# Set random seed so the split is reproducible. The seed and split call match
# the tidymodels comparison, so the output shows the same 576/192/768 partition.
set.seed(123)
# Create initial split: 3/4 of the rows for training, the rest for testing.
diabetes_split <- initial_split(diabetes_orig, prop = 3/4)
diabetes_split
#> <Analysis/Assess/Total>
#> <576/192/768>
# Create training set (576 rows).
diabetes_train <- training(diabetes_split)
# Create test set (192 rows).
diabetes_test <- testing(diabetes_split)

# Resampling scheme: 3-fold cross-validation. Class probabilities are kept
# because twoClassSummary needs them to report ROC / Sens / Spec, and the
# hold-out predictions are saved on the fitted object.
fitControl <- trainControl(
  method = "cv",
  number = 3,
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  savePredictions = TRUE
)

# Tune catboost through its caret interface (tuneLength = 3 controls the
# size of the tuning grid), picking the candidate with the best
# cross-validated ROC AUC.
model <- train(
  x = select(diabetes_train, -diabetes),
  y = diabetes_train$diabetes,
  method = catboost.caret,
  metric = "ROC",
  trControl = fitControl,
  tuneLength = 3
)

# Predict: caret's hard class (.pred_class), its class probabilities
# (columns `neg` and `pos`) and the observed outcome, side by side.
preds1 <- tibble(.pred_class = predict(model, diabetes_test)) %>%
  bind_cols(predict(model, diabetes_test, type = "prob")) %>%
  bind_cols(diabetes_test %>% select(diabetes))

# Get metrics with "pos" (the second factor level) as the event. The
# yardstick argument is `event_level` (singular): a misspelled `event_levels`
# is swallowed by `...` and ignored, which silently makes `sens` score the
# first level ("neg") instead.
# NOTE: the `#>` output below was recorded with the misspelled argument, so
# its `sens` row measures the "neg" class; re-run to regenerate.
preds1 %>%
  roc_auc(truth = diabetes, pos, event_level = "second") %>%
  bind_rows(
    preds1 %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.821
#> 2 sens    binary         0.848

# Change threshold: derive the hard class from `pos` with an explicit 0.5
# cut-off. Building the factor with the truth column's levels keeps both
# classes even if every prediction falls on one side of the threshold
# (as.factor() would drop the unused level, and yardstick rejects truth and
# estimate factors whose levels differ).
preds1_05 <- preds1 %>%
  mutate(
    .pred_class = factor(
      if_else(pos > 0.5, "pos", "neg"),
      levels = levels(diabetes)
    )
  )

# Get metrics for the thresholded classes (same stale-output caveat as above).
preds1_05 %>%
  roc_auc(truth = diabetes, pos, event_level = "second") %>%
  bind_rows(
    preds1_05 %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.821
#> 2 sens    binary         0.848
Created on 2022-02-02 by the reprex package (v2.0.1)