tidymodels / workflows

Modeling Workflows
https://workflows.tidymodels.org/

Internal Error when tuning catboost parameters #220

Closed Milardkh closed 7 months ago

Milardkh commented 7 months ago

I'm having trouble running tune_grid() for catboost with tidymodels, although, judging by examples I found on Kaggle, it worked with earlier versions of catboost. I also tried other tuning functions, but the error stays the same. The result of show_notes(.Last.tune.result) is as follows:

unique notes:
────────────────────────────────────────────────────────────────────────────────
Error in pull_workflow_spec_encoding_tbl():
! Exactly 1 model/engine/mode combination must be located.
ℹ This is an internal error that was detected in the workflows package. Please report it at https://github.com/tidymodels/workflows/issues with a reprex (https://tidyverse.org/help/) and the full backtrace.

# Importing libraries
library(tidyverse)
#> Warning: package 'ggplot2' was built under R version 4.3.2
#> Warning: package 'tidyr' was built under R version 4.3.2
#> Warning: package 'readr' was built under R version 4.3.2
#> Warning: package 'dplyr' was built under R version 4.3.2
#> Warning: package 'stringr' was built under R version 4.3.2
#> Warning: package 'lubridate' was built under R version 4.3.2
library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.3.2
#> Warning: package 'dials' was built under R version 4.3.2
#> Warning: package 'scales' was built under R version 4.3.2
#> Warning: package 'infer' was built under R version 4.3.2
#> Warning: package 'modeldata' was built under R version 4.3.2
#> Warning: package 'parsnip' was built under R version 4.3.2
#> Warning: package 'rsample' was built under R version 4.3.2
#> Warning: package 'tune' was built under R version 4.3.2
#> Warning: package 'workflowsets' was built under R version 4.3.2
#> Warning: package 'yardstick' was built under R version 4.3.2
library(reprex)
#> Warning: package 'reprex' was built under R version 4.3.2
library(readxl)
library(writexl)
#> Warning: package 'writexl' was built under R version 4.3.2
library(doParallel)
#> Loading required package: foreach
#> 
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#> 
#>     accumulate, when
#> Loading required package: iterators
#> Loading required package: parallel

## Parallel processing setup
cores <- parallel::detectCores(logical = FALSE)
cores
#> [1] 4

cl <- makePSOCKcluster(names = cores)
cl
#> socket cluster with 4 nodes on host 'localhost'
registerDoParallel(cl = cl)

df_total_work <- excel_sheets(path = paste0("C:/Users/Mashadservice.com/Desktop/Project3/Miners/",
                                            "Data gathered/استاد خانی/df_total_work.xlsx")) %>% 
  map(.f = ~read_xlsx(path = paste0("C:/Users/Mashadservice.com/Desktop/Project3/Miners/",
                                    "Data gathered/استاد خانی/df_total_work.xlsx"), sheet = .))
df_total_work <- df_total_work[[1]]

class(df_total_work)
#> [1] "tbl_df"     "tbl"        "data.frame"
dim(df_total_work)
#> [1] 2018   42
glimpse(x = df_total_work)
#> Rows: 2,018
#> Columns: 42
#> $ MeterNumber                   <chr> "1729", "1729", "1729", "1729", "1729", …
#> $ DateShamsi                    <chr> "1401-02-14 00:00:00", "1401-02-15 00:00…
#> $ Date                          <dttm> 2022-05-04, 2022-05-05, 2022-05-06, 202…
#> $ Daily                         <dttm> 2022-05-04, 2022-05-05, 2022-05-06, 202…
#> $ ActiveEnergyPosTotal          <dbl> 985371.4, 985454.1, 985570.4, 985684.1, …
#> $ ActiveEnergyPostariff1        <dbl> 539273.6, 539326.7, 539400.7, 539480.1, …
#> $ ActiveEnergyPostariff2        <dbl> 149320.4, 149339.2, 149355.4, 149376.1, …
#> $ ActiveEnergyPostariff3        <dbl> 296777.3, 296788.2, 296814.4, 296827.9, …
#> $ ActiveEnergyNegTotal          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ActiveEnergyNegtariff1        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ActiveEnergyNegtariff2        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ActiveEnergyNegtariff3        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ReactiveEnergyPosTotal        <dbl> 41099.10, 41100.06, 41105.70, 41108.70, …
#> $ ReactiveEnergyPostariff1      <dbl> 19902.96, 19903.68, 19903.68, 19905.66, …
#> $ ReactiveEnergyPostariff2      <dbl> 9277.56, 9277.56, 9277.56, 9278.34, 9279…
#> $ ReactiveEnergyPostariff3      <dbl> 11918.52, 11918.82, 11924.46, 11924.70, …
#> $ ReactiveEnergyNegTotal        <dbl> 57187.56, 57215.22, 57265.44, 57304.68, …
#> $ ReactiveEnergyNegtariff1      <dbl> 32085.66, 32104.50, 32139.42, 32169.06, …
#> $ ReactiveEnergyNegtariff2      <dbl> 9429.90, 9438.42, 9445.62, 9453.24, 9460…
#> $ ReactiveEnergyNegtariff3      <dbl> 15672.00, 15672.24, 15680.34, 15682.26, …
#> $ Tarrif_category               <chr> "Commercial", "Commercial", "Commercial"…
#> $ Contracted_demand             <dbl> 180, 180, 180, 180, 180, 180, 180, 180, …
#> $ Number_of_phases              <chr> "3", "3", "3", "3", "3", "3", "3", "3", …
#> $ Target                        <chr> "Mining", "Mining", "Mining", "Mining", …
#> $ AvgPFt1                       <dbl> 0.5, 2.0, 0.1, 0.6, 2.0, 3.7, 3.0, 4.0, …
#> $ AvgPFt2                       <dbl> 95.3, 96.9, 93.6, 95.7, 95.6, 94.7, 94.7…
#> $ AvgPFt3                       <dbl> 61.0, 57.1, 71.3, 64.1, 71.6, 60.0, 61.2…
#> $ Lag1_ActiveEnergyPosTotal     <dbl> NA, 82.74, 116.28, 113.70, 68.04, 95.04,…
#> $ Lag1_ActiveEnergyPostariff1   <dbl> NA, 53.16, 73.92, 79.44, 38.82, 60.36, 4…
#> $ Lag1_ActiveEnergyPostariff2   <dbl> NA, 18.72, 16.20, 20.70, 21.30, 19.98, 1…
#> $ Lag1_ActiveEnergyPostariff3   <dbl> NA, 10.86, 26.16, 13.50, 7.98, 14.70, 6.…
#> $ Lag1_ReactiveEnergyPosTotal   <dbl> NA, 0.96, 5.64, 3.00, 2.76, 1.80, 0.36, …
#> $ Lag1_ReactiveEnergyPostariff1 <dbl> NA, 0.72, 0.00, 1.98, 1.74, 1.20, 0.30, …
#> $ Lag1_ReactiveEnergyPostariff2 <dbl> NA, 0.00, 0.00, 0.78, 0.96, 0.30, 0.00, …
#> $ Lag1_ReactiveEnergyPostariff3 <dbl> NA, 0.30, 5.64, 0.24, 0.12, 0.18, 0.06, …
#> $ Lag1_ReactiveEnergyNegTotal   <dbl> NA, 27.66, 50.22, 39.24, 22.32, 34.86, 2…
#> $ Lag1_ReactiveEnergyNegtariff1 <dbl> NA, 18.84, 34.92, 29.64, 12.48, 22.50, 1…
#> $ Lag1_ReactiveEnergyNegtariff2 <dbl> NA, 8.52, 7.20, 7.62, 7.56, 7.50, 8.76, …
#> $ Lag1_ReactiveEnergyNegtariff3 <dbl> NA, 0.24, 8.10, 1.92, 2.28, 4.92, 1.86, …
#> $ `5-MA`                        <dbl> NA, NA, NA, 95.160, 92.556, 82.752, 76.4…
#> $ weekday                       <chr> "Wed", "Thu", "Fri", "Sat", "Sun", "Mon"…
#> $ DayType                       <chr> "Weekday", "Weekday", "Weekend", "Weekda…

df_total_work$Number_of_phases <- df_total_work$Number_of_phases %>% factor()

df_total_work$weekday <- df_total_work$weekday %>% factor()

df_total_work$Target <- df_total_work$Target %>% factor()

df_total_work$DayType <- df_total_work$DayType %>% factor()

df_total_work$Tarrif_category   <- df_total_work$Tarrif_category  %>% factor()

df_total_work$weekday <- 
  factor(df_total_work$weekday, 
         levels = c("Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"),  ordered = TRUE)

glimpse(x = df_total_work)
#> Rows: 2,018
#> Columns: 42
#> $ MeterNumber                   <chr> "1729", "1729", "1729", "1729", "1729", …
#> $ DateShamsi                    <chr> "1401-02-14 00:00:00", "1401-02-15 00:00…
#> $ Date                          <dttm> 2022-05-04, 2022-05-05, 2022-05-06, 202…
#> $ Daily                         <dttm> 2022-05-04, 2022-05-05, 2022-05-06, 202…
#> $ ActiveEnergyPosTotal          <dbl> 985371.4, 985454.1, 985570.4, 985684.1, …
#> $ ActiveEnergyPostariff1        <dbl> 539273.6, 539326.7, 539400.7, 539480.1, …
#> $ ActiveEnergyPostariff2        <dbl> 149320.4, 149339.2, 149355.4, 149376.1, …
#> $ ActiveEnergyPostariff3        <dbl> 296777.3, 296788.2, 296814.4, 296827.9, …
#> $ ActiveEnergyNegTotal          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ActiveEnergyNegtariff1        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ActiveEnergyNegtariff2        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ActiveEnergyNegtariff3        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ReactiveEnergyPosTotal        <dbl> 41099.10, 41100.06, 41105.70, 41108.70, …
#> $ ReactiveEnergyPostariff1      <dbl> 19902.96, 19903.68, 19903.68, 19905.66, …
#> $ ReactiveEnergyPostariff2      <dbl> 9277.56, 9277.56, 9277.56, 9278.34, 9279…
#> $ ReactiveEnergyPostariff3      <dbl> 11918.52, 11918.82, 11924.46, 11924.70, …
#> $ ReactiveEnergyNegTotal        <dbl> 57187.56, 57215.22, 57265.44, 57304.68, …
#> $ ReactiveEnergyNegtariff1      <dbl> 32085.66, 32104.50, 32139.42, 32169.06, …
#> $ ReactiveEnergyNegtariff2      <dbl> 9429.90, 9438.42, 9445.62, 9453.24, 9460…
#> $ ReactiveEnergyNegtariff3      <dbl> 15672.00, 15672.24, 15680.34, 15682.26, …
#> $ Tarrif_category               <fct> Commercial, Commercial, Commercial, Comm…
#> $ Contracted_demand             <dbl> 180, 180, 180, 180, 180, 180, 180, 180, …
#> $ Number_of_phases              <fct> 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3…
#> $ Target                        <fct> Mining, Mining, Mining, Mining, Mining, …
#> $ AvgPFt1                       <dbl> 0.5, 2.0, 0.1, 0.6, 2.0, 3.7, 3.0, 4.0, …
#> $ AvgPFt2                       <dbl> 95.3, 96.9, 93.6, 95.7, 95.6, 94.7, 94.7…
#> $ AvgPFt3                       <dbl> 61.0, 57.1, 71.3, 64.1, 71.6, 60.0, 61.2…
#> $ Lag1_ActiveEnergyPosTotal     <dbl> NA, 82.74, 116.28, 113.70, 68.04, 95.04,…
#> $ Lag1_ActiveEnergyPostariff1   <dbl> NA, 53.16, 73.92, 79.44, 38.82, 60.36, 4…
#> $ Lag1_ActiveEnergyPostariff2   <dbl> NA, 18.72, 16.20, 20.70, 21.30, 19.98, 1…
#> $ Lag1_ActiveEnergyPostariff3   <dbl> NA, 10.86, 26.16, 13.50, 7.98, 14.70, 6.…
#> $ Lag1_ReactiveEnergyPosTotal   <dbl> NA, 0.96, 5.64, 3.00, 2.76, 1.80, 0.36, …
#> $ Lag1_ReactiveEnergyPostariff1 <dbl> NA, 0.72, 0.00, 1.98, 1.74, 1.20, 0.30, …
#> $ Lag1_ReactiveEnergyPostariff2 <dbl> NA, 0.00, 0.00, 0.78, 0.96, 0.30, 0.00, …
#> $ Lag1_ReactiveEnergyPostariff3 <dbl> NA, 0.30, 5.64, 0.24, 0.12, 0.18, 0.06, …
#> $ Lag1_ReactiveEnergyNegTotal   <dbl> NA, 27.66, 50.22, 39.24, 22.32, 34.86, 2…
#> $ Lag1_ReactiveEnergyNegtariff1 <dbl> NA, 18.84, 34.92, 29.64, 12.48, 22.50, 1…
#> $ Lag1_ReactiveEnergyNegtariff2 <dbl> NA, 8.52, 7.20, 7.62, 7.56, 7.50, 8.76, …
#> $ Lag1_ReactiveEnergyNegtariff3 <dbl> NA, 0.24, 8.10, 1.92, 2.28, 4.92, 1.86, …
#> $ `5-MA`                        <dbl> NA, NA, NA, 95.160, 92.556, 82.752, 76.4…
#> $ weekday                       <ord> Wed, Thu, Fri, Sat, Sun, Mon, Tue, Wed, …
#> $ DayType                       <fct> Weekday, Weekday, Weekend, Weekday, Week…

class(df_total_work)
#> [1] "tbl_df"     "tbl"        "data.frame"
dim(df_total_work)
#> [1] 2018   42

# ROSE: random over sampling examples method

library(themis)
#> Warning: package 'themis' was built under R version 4.3.2

glimpse(df_total_work)
#> Rows: 2,018
#> Columns: 42
#> $ MeterNumber                   <chr> "1729", "1729", "1729", "1729", "1729", …
#> $ DateShamsi                    <chr> "1401-02-14 00:00:00", "1401-02-15 00:00…
#> $ Date                          <dttm> 2022-05-04, 2022-05-05, 2022-05-06, 202…
#> $ Daily                         <dttm> 2022-05-04, 2022-05-05, 2022-05-06, 202…
#> $ ActiveEnergyPosTotal          <dbl> 985371.4, 985454.1, 985570.4, 985684.1, …
#> $ ActiveEnergyPostariff1        <dbl> 539273.6, 539326.7, 539400.7, 539480.1, …
#> $ ActiveEnergyPostariff2        <dbl> 149320.4, 149339.2, 149355.4, 149376.1, …
#> $ ActiveEnergyPostariff3        <dbl> 296777.3, 296788.2, 296814.4, 296827.9, …
#> $ ActiveEnergyNegTotal          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ActiveEnergyNegtariff1        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ActiveEnergyNegtariff2        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ActiveEnergyNegtariff3        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
#> $ ReactiveEnergyPosTotal        <dbl> 41099.10, 41100.06, 41105.70, 41108.70, …
#> $ ReactiveEnergyPostariff1      <dbl> 19902.96, 19903.68, 19903.68, 19905.66, …
#> $ ReactiveEnergyPostariff2      <dbl> 9277.56, 9277.56, 9277.56, 9278.34, 9279…
#> $ ReactiveEnergyPostariff3      <dbl> 11918.52, 11918.82, 11924.46, 11924.70, …
#> $ ReactiveEnergyNegTotal        <dbl> 57187.56, 57215.22, 57265.44, 57304.68, …
#> $ ReactiveEnergyNegtariff1      <dbl> 32085.66, 32104.50, 32139.42, 32169.06, …
#> $ ReactiveEnergyNegtariff2      <dbl> 9429.90, 9438.42, 9445.62, 9453.24, 9460…
#> $ ReactiveEnergyNegtariff3      <dbl> 15672.00, 15672.24, 15680.34, 15682.26, …
#> $ Tarrif_category               <fct> Commercial, Commercial, Commercial, Comm…
#> $ Contracted_demand             <dbl> 180, 180, 180, 180, 180, 180, 180, 180, …
#> $ Number_of_phases              <fct> 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3…
#> $ Target                        <fct> Mining, Mining, Mining, Mining, Mining, …
#> $ AvgPFt1                       <dbl> 0.5, 2.0, 0.1, 0.6, 2.0, 3.7, 3.0, 4.0, …
#> $ AvgPFt2                       <dbl> 95.3, 96.9, 93.6, 95.7, 95.6, 94.7, 94.7…
#> $ AvgPFt3                       <dbl> 61.0, 57.1, 71.3, 64.1, 71.6, 60.0, 61.2…
#> $ Lag1_ActiveEnergyPosTotal     <dbl> NA, 82.74, 116.28, 113.70, 68.04, 95.04,…
#> $ Lag1_ActiveEnergyPostariff1   <dbl> NA, 53.16, 73.92, 79.44, 38.82, 60.36, 4…
#> $ Lag1_ActiveEnergyPostariff2   <dbl> NA, 18.72, 16.20, 20.70, 21.30, 19.98, 1…
#> $ Lag1_ActiveEnergyPostariff3   <dbl> NA, 10.86, 26.16, 13.50, 7.98, 14.70, 6.…
#> $ Lag1_ReactiveEnergyPosTotal   <dbl> NA, 0.96, 5.64, 3.00, 2.76, 1.80, 0.36, …
#> $ Lag1_ReactiveEnergyPostariff1 <dbl> NA, 0.72, 0.00, 1.98, 1.74, 1.20, 0.30, …
#> $ Lag1_ReactiveEnergyPostariff2 <dbl> NA, 0.00, 0.00, 0.78, 0.96, 0.30, 0.00, …
#> $ Lag1_ReactiveEnergyPostariff3 <dbl> NA, 0.30, 5.64, 0.24, 0.12, 0.18, 0.06, …
#> $ Lag1_ReactiveEnergyNegTotal   <dbl> NA, 27.66, 50.22, 39.24, 22.32, 34.86, 2…
#> $ Lag1_ReactiveEnergyNegtariff1 <dbl> NA, 18.84, 34.92, 29.64, 12.48, 22.50, 1…
#> $ Lag1_ReactiveEnergyNegtariff2 <dbl> NA, 8.52, 7.20, 7.62, 7.56, 7.50, 8.76, …
#> $ Lag1_ReactiveEnergyNegtariff3 <dbl> NA, 0.24, 8.10, 1.92, 2.28, 4.92, 1.86, …
#> $ `5-MA`                        <dbl> NA, NA, NA, 95.160, 92.556, 82.752, 76.4…
#> $ weekday                       <ord> Wed, Thu, Fri, Sat, Sun, Mon, Tue, Wed, …
#> $ DayType                       <fct> Weekday, Weekday, Weekend, Weekday, Week…
df_total_work_reduced <- df_total_work[, 21:length(df_total_work)] %>% 
  filter(!is.na(Target))

sum(is.na(df_total_work_reduced))
#> [1] 68

df_total_work_reduced <- na.omit(object = df_total_work_reduced)

sum(is.na(df_total_work_reduced))
#> [1] 0
dim(df_total_work_reduced)
#> [1] 1991   22

names(df_total_work_reduced)
#>  [1] "Tarrif_category"               "Contracted_demand"            
#>  [3] "Number_of_phases"              "Target"                       
#>  [5] "AvgPFt1"                       "AvgPFt2"                      
#>  [7] "AvgPFt3"                       "Lag1_ActiveEnergyPosTotal"    
#>  [9] "Lag1_ActiveEnergyPostariff1"   "Lag1_ActiveEnergyPostariff2"  
#> [11] "Lag1_ActiveEnergyPostariff3"   "Lag1_ReactiveEnergyPosTotal"  
#> [13] "Lag1_ReactiveEnergyPostariff1" "Lag1_ReactiveEnergyPostariff2"
#> [15] "Lag1_ReactiveEnergyPostariff3" "Lag1_ReactiveEnergyNegTotal"  
#> [17] "Lag1_ReactiveEnergyNegtariff1" "Lag1_ReactiveEnergyNegtariff2"
#> [19] "Lag1_ReactiveEnergyNegtariff3" "5-MA"                         
#> [21] "weekday"                       "DayType"
names(df_total_work_reduced) <- 
  c("TarCat", "Demand", "NumPhases", "Target", "AvgPFt1", "AvgPFt2", "AvgPFt3", 
    "EnergyPosTotal","EnergyPostariff1", "EnergyPostariff2", "EnergyPostariff3",
    "ReEnergyPosTotal", "ReEnergyPostariff1", "ReEnergyPostariff2", "ReEnergyPostariff3", 
    "ReEnergyNegTotal", "ReEnergyNegtariff1", "ReEnergyNegtariff2", "ReEnergyNegtariff3", 
    "MA", "weekday", "DayType")

df_total_work_reduced <- df_total_work_reduced %>% dplyr::select(-NumPhases)

glimpse(df_total_work_reduced)
#> Rows: 1,991
#> Columns: 21
#> $ TarCat             <fct> Commercial, Commercial, Commercial, Commercial, Com…
#> $ Demand             <dbl> 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 1…
#> $ Target             <fct> Mining, Mining, Mining, Mining, Mining, Mining, Min…
#> $ AvgPFt1            <dbl> 0.6, 2.0, 3.7, 3.0, 4.0, 0.0, 79.8, 3.4, 0.1, 1.0, …
#> $ AvgPFt2            <dbl> 95.7, 95.6, 94.7, 94.7, 95.3, 97.5, 99.5, 91.6, 88.…
#> $ AvgPFt3            <dbl> 64.1, 71.6, 60.0, 61.2, 71.1, 65.2, 70.5, 65.0, 60.…
#> $ EnergyPosTotal     <dbl> 113.70, 68.04, 95.04, 69.72, 67.26, 82.32, 131.10, …
#> $ EnergyPostariff1   <dbl> 79.44, 38.82, 60.36, 44.16, 42.78, 53.94, 87.00, 82…
#> $ EnergyPostariff2   <dbl> 20.70, 21.30, 19.98, 18.66, 19.14, 19.56, 19.32, 18…
#> $ EnergyPostariff3   <dbl> 13.50, 7.98, 14.70, 6.78, 5.46, 8.82, 24.78, 43.38,…
#> $ ReEnergyPosTotal   <dbl> 3.00, 2.76, 1.80, 0.36, 0.12, 1.08, 3.18, 0.48, 5.7…
#> $ ReEnergyPostariff1 <dbl> 1.98, 1.74, 1.20, 0.30, 0.06, 0.78, 2.40, 0.18, 2.7…
#> $ ReEnergyPostariff2 <dbl> 0.78, 0.96, 0.30, 0.00, 0.00, 0.36, 0.78, 0.00, 0.3…
#> $ ReEnergyPostariff3 <dbl> 0.24, 0.12, 0.18, 0.06, 0.00, 0.06, 0.06, 0.24, 2.6…
#> $ ReEnergyNegTotal   <dbl> 39.24, 22.32, 34.86, 29.40, 28.68, 32.16, 52.02, 62…
#> $ ReEnergyNegtariff1 <dbl> 29.64, 12.48, 22.50, 18.78, 19.02, 21.18, 35.70, 36…
#> $ ReEnergyNegtariff2 <dbl> 7.62, 7.56, 7.50, 8.76, 8.04, 7.56, 7.50, 7.98, 7.6…
#> $ ReEnergyNegtariff3 <dbl> 1.92, 2.28, 4.92, 1.86, 1.56, 3.54, 8.76, 17.88, 1.…
#> $ MA                 <dbl> 95.160, 92.556, 82.752, 76.476, 89.088, 98.976, 103…
#> $ weekday            <ord> Sat, Sun, Mon, Tue, Wed, Thu, Fri, Sat, Sun, Mon, T…
#> $ DayType            <fct> Weekday, Weekday, Weekday, Weekday, Weekday, Weekda…

table(df_total_work_reduced$Target)
#> 
#>    Mining NotMining 
#>       355      1636

set.seed(123)
splits <- initial_split(data = df_total_work_reduced, prop = 3/4, strata = Target)
splits
#> <Training/Testing/Total>
#> <1493/498/1991>

df_other <- training(splits)
df_test <- testing(splits)

# training set proportions by Target class
df_other %>% count(Target) %>% 
  mutate(prop = n/sum(n))
#> # A tibble: 2 × 3
#>   Target        n  prop
#>   <fct>     <int> <dbl>
#> 1 Mining      266 0.178
#> 2 NotMining  1227 0.822

# test set proportions by Target class
df_test %>% count(Target) %>% 
  mutate(prop = n/sum(n))
#> # A tibble: 2 × 3
#>   Target        n  prop
#>   <fct>     <int> <dbl>
#> 1 Mining       89 0.179
#> 2 NotMining   409 0.821

set.seed(234)
val_set <- validation_split(data = df_other, prop = 0.8, strata = Target)
#> Warning: `validation_split()` was deprecated in rsample 1.2.0.
#> ℹ Please use `initial_validation_split()` instead.
#> This warning is displayed once every 8 hours.
#> Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
#> generated.

val_set
#> # Validation Set Split (0.8/0.2)  using stratification 
#> # A tibble: 1 × 2
#>   splits             id        
#>   <list>             <chr>     
#> 1 <split [1193/300]> validation

# Gradient boosting algorithms - catboost
library(tidymodels)
library(treesnip)
library(finetune)
#> Warning: package 'finetune' was built under R version 4.3.2
library(catboost)

cb_model <- boost_tree(mode = "classification",
                       engine = "catboost",
                       mtry = tune(),
                       trees = tune(),
                       min_n = tune(),
                       tree_depth = tune(),
                       learn_rate = tune()) # parameters to be tuned

cb_model
#> Boosted Tree Model Specification (classification)
#> 
#> Main Arguments:
#>   mtry = tune()
#>   trees = tune()
#>   min_n = tune()
#>   tree_depth = tune()
#>   learn_rate = tune()
#> 
#> Computational engine: catboost

cb_wf <- 
  workflow() %>% 
  add_model(cb_model) %>% 
  add_formula(Target ~ .)
cb_wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> Target ~ .
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (classification)
#> 
#> Main Arguments:
#>   mtry = tune()
#>   trees = tune()
#>   min_n = tune()
#>   tree_depth = tune()
#>   learn_rate = tune()
#> 
#> Computational engine: catboost

cb_grid <- grid_latin_hypercube( 
  finalize(object = mtry(), x = df_other),
  trees(),
  min_n(),
  tree_depth(),
  #  loss_reduction(),
  #  sample_size = sample_prop(),  # A number for the number (or proportion) of data
  # that is exposed to the fitting routine. For xgboost, 
  # the sampling is done at each iteration 
  # while C5.0 samples once during training.

  learn_rate(), 
  size = 30   # A single integer for the total number of parameter value combinations
  # returned. If duplicate combinations are generated from this size, the smaller, unique set is returned.
)

finalize(object = mtry(), x = df_other)
#> # Randomly Selected Predictors (quantitative)
#> Range: [1, 21]

cb_grid
#> # A tibble: 30 × 5
#>     mtry trees min_n tree_depth learn_rate
#>    <int> <int> <int>      <int>      <dbl>
#>  1     1  1408    37          6   5.38e- 6
#>  2     7  1086    22          3   1.21e- 4
#>  3    18  1330     5          9   1.26e- 9
#>  4     3   955     6          4   1.42e- 3
#>  5    18  1850    20          4   6.98e- 2
#>  6     9   624    23          2   2.85e- 8
#>  7     6   224     7          5   1.97e- 9
#>  8     5   191     3         13   1.64e-10
#>  9     3  1577    11          3   1.77e- 6
#> 10    17  1528    17          8   4.21e- 3
#> # ℹ 20 more rows

cb_wf <- workflow() %>% 
  add_formula(formula = Target ~ .) %>% 
  add_model(spec = cb_model)

cb_wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> Target ~ .
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (classification)
#> 
#> Main Arguments:
#>   mtry = tune()
#>   trees = tune()
#>   min_n = tune()
#>   tree_depth = tune()
#>   learn_rate = tune()
#> 
#> Computational engine: catboost

cb_res <- tune_grid(      # Model tuning process via grid search
  object = cb_wf, 
  resamples = val_set, 
  grid = cb_grid, 
  control = control_grid(save_pred = TRUE)
)
#> Warning: All models failed. Run `show_notes(.Last.tune.result)` for more
#> information.

show_notes(.Last.tune.result)
#> unique notes:
#> ────────────────────────────────────────────────────────────────────────────────
#> Error in `pull_workflow_spec_encoding_tbl()`:
#> ! Exactly 1 model/engine/mode combination must be located.
#> ℹ This is an internal error that was detected in the workflows package.
#>   Please report it at <https://github.com/tidymodels/workflows/issues> with a reprex (<https://tidyverse.org/help/>) and the full backtrace.

cb_res
#> # Tuning results
#> # Validation Set Split (0.8/0.2)  using stratification 
#> # A tibble: 1 × 5
#>   splits             id         .metrics .notes            .predictions
#>   <list>             <chr>      <list>   <list>            <list>      
#> 1 <split [1193/300]> validation <NULL>   <tibble [30 × 3]> <NULL>      
#> 
#> There were issues with some computations:
#> 
#>   - Error(s) x30: Error in `pull_workflow_spec_encoding_tbl()`: ! Exactly 1 model/e...
#> 
#> Run `show_notes(.Last.tune.result)` for more information.

cb_res %>% collect_metrics() %>% View()    
#> Error in `estimate_tune_results()`:
#> ! All models failed. Run `show_notes(.Last.tune.result)` for more information.

Created on 2024-02-20 with reprex v2.1.0

Additionally, my machine and session information are as follows:

sessionInfo()
#> R version 4.3.1 (2023-06-16 ucrt)
#> Platform: x86_64-w64-mingw32/x64 (64-bit)
#> Running under: Windows 10 x64 (build 18363)
#> 
#> Matrix products: default
#> 
#> 
#> locale:
#> [1] LC_COLLATE=English_United States.utf8 
#> [2] LC_CTYPE=English_United States.utf8   
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C                          
#> [5] LC_TIME=English_United States.utf8    
#> 
#> time zone: Asia/Tehran
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> loaded via a namespace (and not attached):
#>  [1] styler_1.10.2     digest_0.6.34     fastmap_1.1.1     xfun_0.42        
#>  [5] magrittr_2.0.3    glue_1.7.0        R.utils_2.12.3    knitr_1.45       
#>  [9] htmltools_0.5.7   rmarkdown_2.25    lifecycle_1.0.4   cli_3.6.2        
#> [13] R.methodsS3_1.8.2 vctrs_0.6.5       reprex_2.1.0      withr_3.0.0      
#> [17] compiler_4.3.1    R.oo_1.26.0       R.cache_0.16.0    purrr_1.0.2      
#> [21] rstudioapi_0.15.0 tools_4.3.1       evaluate_0.23     yaml_2.3.8       
#> [25] rlang_1.1.3       fs_1.6.3

Created on 2024-02-20 with reprex v2.1.0
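
Separately from the error, the reprex above triggers a deprecation warning for `validation_split()` and points to `initial_validation_split()`. The following is only a sketch of that replacement on a made-up toy data frame (`toy`, `three_way`, and `val_set_new` are illustrative names, not objects from the reprex). The proportions 0.6 and 0.15 are an assumption chosen to approximate the original 75% split followed by an 80/20 validation split.

```r
library(rsample)

set.seed(234)

# Toy stand-in for df_total_work_reduced (made-up data, for illustration only)
toy <- data.frame(
  Target = factor(rep(c("Mining", "NotMining"), times = c(40, 160))),
  x = rnorm(200)
)

# One three-way split: about 60% training, 15% validation, 25% testing
three_way <- initial_validation_split(toy, prop = c(0.6, 0.15), strata = Target)

# Resampling object that can be passed as `resamples` to tune_grid()
val_set_new <- validation_set(three_way)

# Held-out test data, kept aside until the final evaluation
df_test_new <- testing(three_way)
```

This only addresses the deprecation warning; it is not expected to change the `pull_workflow_spec_encoding_tbl()` error.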

simonpcouch commented 7 months ago

Thanks for the issue! The error:

#> Error in `pull_workflow_spec_encoding_tbl()`:
#> ! Exactly 1 model/engine/mode combination must be located.
#> ...

...indicates that treesnip, the package that supplies the catboost model definition, has a problem with how it defines the model. The tidymodels team doesn't maintain treesnip, and unfortunately there are several open issues about catboost support in its package repository. If catboost does make it to CRAN, we will support catboost in bonsai and resolve this issue as part of that work. :)
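
For readers who want to see what the error message is checking, here is a minimal diagnostic sketch. It assumes, based on the wording of the error, that workflows looks up the model's encoding table in parsnip and expects exactly one row for the spec's engine and mode. `count_encodings()` is a hypothetical helper written for this illustration; it is not part of parsnip or workflows.

```r
library(parsnip)

# Hypothetical helper: count the encoding rows that parsnip has registered
# for a model spec's engine/mode combination.
count_encodings <- function(spec) {
  enc <- get_encoding(class(spec)[1])
  sum(enc$engine == spec$engine & enc$mode == spec$mode)
}

# An engine that parsnip registers itself, used here as a reference point
xgb_spec <- boost_tree(mode = "classification", engine = "xgboost")
n_xgb <- count_encodings(xgb_spec)
n_xgb
```

Running the same helper on `cb_model` from the reprex, with treesnip loaded, and getting any value other than 1 would be consistent with the error reported above.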

github-actions[bot] commented 7 months ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.