mlverse / tabnet

An R implementation of TabNet
https://mlverse.github.io/tabnet/
Other
108 stars 13 forks source link

Feature request for `case-weights` #145

Open cgoo4 opened 9 months ago

cgoo4 commented 9 months ago

Would it be possible to add support for case weights in TabNet?

This would help with class imbalance and make it easier to compare (and blend) the results of TabNet and XGBoost.

(I will probably upsample the minority class in the meantime as an alternative approach.)

This would be the desired workflow:

library(tabnet)
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")

# Majority-to-minority ratio: giving each "bad" row this weight (and each
# "good" row weight 1) makes both classes carry the same total weight.
class_ratio <- with(lending_club, sum(Class == "good") / sum(Class == "bad"))

# Store the per-row weights as a hardhat importance-weights column.
lending_club <- lending_club |>
  mutate(
    case_wts = importance_weights(
      if_else(Class == "bad", class_ratio, 1)
    )
  )

# Stratified train/test split on the outcome.
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)

# Class is the outcome. Every column not already holding an outcome, id or
# case_weights role becomes a predictor.
# NOTE(review): presumably `case_wts` picks up the "case_weights" role
# automatically because it is an importance-weights column; confirm with
# `summary(tab_rec)`.
tab_rec <- recipe(train) |>
  update_role(Class, new_role = "outcome") |>
  update_role(
    -has_role(c("outcome", "id", "case_weights")),
    new_role = "predictor"
  )

set.seed(1)

# TabNet classification spec (10 epochs) using the torch engine on the CPU.
tab_mod <- tabnet(epochs = 10) |> 
  set_engine("torch", device = "cpu") |> 
  set_mode("classification")

# Bundle the spec and recipe, and register `case_wts` as the case-weights
# column that the workflow passes to the model fit.
tab_wf <- workflow() |> 
  add_model(tab_mod) |> 
  add_recipe(tab_rec) |> 
  add_case_weights(case_wts)

# Fitting aborts inside parsnip's check_case_weights(): case weights were not
# enabled for the tabnet model implementation at the time of this report
# (see the captured error and backtrace below).
tab_fit <- tab_wf |> fit(train)
#> Error in `check_case_weights()`:
#> ! Case weights are not enabled by the underlying model implementation.
#> Backtrace:
#>      ▆
#>   1. ├─generics::fit(tab_wf, train)
#>   2. └─workflows:::fit.workflow(tab_wf, train)
#>   3.   └─workflows::.fit_model(workflow, control)
#>   4.     ├─generics::fit(action_model, workflow = workflow, control = control)
#>   5.     └─workflows:::fit.action_model(...)
#>   6.       └─workflows:::fit_from_xy(spec, mold, case_weights, control_parsnip)
#>   7.         ├─generics::fit_xy(...)
#>   8.         └─parsnip::fit_xy.model_spec(...)
#>   9.           └─parsnip:::check_case_weights(case_weights, object)
#>  10.             └─rlang::abort("Case weights are not enabled by the underlying model implementation.")

Created on 2024-01-12 with reprex v2.0.2

cregouby commented 8 months ago

Hello @cgoo4, I finally did it. Would you like to test it and report whether it fits your needs? One way to install it is:

# Install tabnet from the `feature/case_weight` branch of mlverse/tabnet on GitHub.
pak::pkg_install("mlverse/tabnet@feature/case_weight")
cgoo4 commented 8 months ago

Hi @cregouby - Thank you.

Ahead of trying it on my own data, I've made a quick test using the toy lending_club data. Untuned TabNet and XGBoost models, with and without case weights, show comparable results!

library(tabnet)
library(tidymodels)
library(modeldata)
library(patchwork)

data("lending_club", package = "modeldata")

# Majority-to-minority ratio: weighting each "bad" row by this (and each
# "good" row by 1) gives both classes the same total weight.
class_ratio <- with(lending_club, sum(Class == "good") / sum(Class == "bad"))

# Per-row weights stored as a hardhat importance-weights column.
lending_club <- lending_club |>
  mutate(
    case_wts = importance_weights(
      if_else(Class == "bad", class_ratio, 1)
    )
  )

# Reproducible, outcome-stratified train/test split.
set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)

# Class is the outcome. Every column not already holding an outcome, id or
# case_weights role becomes a predictor.
# NOTE(review): presumably `case_wts` gets the "case_weights" role
# automatically; confirm with `summary(tab_rec)`.
tab_rec <- recipe(train) |>
  update_role(Class, new_role = "outcome") |>
  update_role(
    -has_role(c("outcome", "id", "case_weights")),
    new_role = "predictor"
  )

# XGBoost variant of the recipe: also dummy-encode the listed nominal columns.
xgb_rec <- step_dummy(
  tab_rec,
  term, sub_grade, addr_state, verification_status, emp_length
)

# Untuned specs: 100 epochs for TabNet (CPU torch), 100 trees for XGBoost.
tab_mod <- tabnet(epochs = 100) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")

xgb_mod <- boost_tree(trees = 100) |>
  set_engine("xgboost") |>
  set_mode("classification")

# Each workflow pairs a recipe with a spec and registers `case_wts` as the
# case-weights column used during fitting.
tab_wf <- workflow(preprocessor = tab_rec, spec = tab_mod) |>
  add_case_weights(case_wts)
xgb_wf <- workflow(preprocessor = xgb_rec, spec = xgb_mod) |>
  add_case_weights(case_wts)

# Fit on the training split, then bind predictions onto the test split.
tab_fit <- fit(tab_wf, data = train)
xgb_fit <- fit(xgb_wf, data = train)

tab_test <- augment(tab_fit, new_data = test)
xgb_test <- augment(xgb_fit, new_data = test)

# Precision-recall curves for both models, with and without case weights
# applied to the metric.
#
# NOTE(review): yardstick treats the *first* level of `truth` as the event
# unless `event_level` says otherwise. This assumes `levels(Class)` is
# c("bad", "good"); confirm with `levels(lending_club$Class)`. Under that
# assumption, scoring `.pred_good` needs `event_level = "second"`. With the
# default, each curve would pair the "bad" event with the predicted
# probability of "good".
#
# p1/p3 weight the metric by `case_wts`. p2/p4 score the same fitted models
# (both trained through workflows with `add_case_weights()`) without metric
# weights.
p1 <- tab_test |>
  pr_curve(
    Class, .pred_good,
    case_weights = case_wts, event_level = "second"
  ) |>
  autoplot() +
  ggtitle("TabNet with Case Weights") +
  theme(plot.title = element_text(size = 9))

p2 <- tab_test |>
  pr_curve(Class, .pred_good, event_level = "second") |>
  autoplot() +
  ggtitle("TabNet WITHOUT") +
  theme(plot.title = element_text(size = 9))

p3 <- xgb_test |>
  pr_curve(
    Class, .pred_good,
    case_weights = case_wts, event_level = "second"
  ) |>
  autoplot() +
  ggtitle("XGBoost with Case Weights") +
  theme(plot.title = element_text(size = 9))

p4 <- xgb_test |>
  pr_curve(Class, .pred_good, event_level = "second") |>
  autoplot() +
  ggtitle("XGBoost WITHOUT") +
  theme(plot.title = element_text(size = 9))

# Compose the four panels with patchwork.
p1 + p2 + p3 + p4

Created on 2024-02-18 with reprex v2.1.0

cgoo4 commented 2 months ago

In 0.6.0.9000, I'm getting the message "Configured `weights` will not be used":

(It's the same example as per above where the weights were being passed along from the workflow.)

library(tabnet)
library(tidymodels)

data("lending_club", package = "modeldata")

# Majority-to-minority ratio: weighting each "bad" row by this (and each
# "good" row by 1) gives both classes the same total weight.
class_ratio <- with(lending_club, sum(Class == "good") / sum(Class == "bad"))

# Per-row weights stored as a hardhat importance-weights column.
lending_club <- lending_club |>
  mutate(
    case_wts = importance_weights(
      if_else(Class == "bad", class_ratio, 1)
    )
  )

# Reproducible, outcome-stratified train/test split.
set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)

# Class is the outcome. Every column not already holding an outcome, id or
# case_weights role becomes a predictor.
# NOTE(review): presumably `case_wts` gets the "case_weights" role
# automatically; confirm with `summary(tab_rec)`.
tab_rec <- recipe(train) |>
  update_role(Class, new_role = "outcome") |>
  update_role(
    -has_role(c("outcome", "id", "case_weights")),
    new_role = "predictor"
  )

# TabNet classification spec (10 epochs) using the torch engine on the CPU.
tab_mod <- tabnet(epochs = 10) |> 
  set_engine("torch", device = "cpu") |> 
  set_mode("classification")

# Bundle the spec and recipe, and register `case_wts` as the case-weights
# column passed to the model fit.
tab_wf <- workflow() |> 
  add_model(tab_mod) |> 
  add_recipe(tab_rec) |> 
  add_case_weights(case_wts)

# With tabnet 0.6.0.9000 the fit succeeds but emits the message below. Per the
# maintainer's reply in this thread, it signals that tabnet itself fits
# without using the case-weights variable.
tab_fit <- tab_wf |> fit(train)
#> Configured `weights` will not be used

# Bind test-set predictions onto `test`.
tab_test <- tab_fit |> augment(test)

Created on 2024-08-09 with reprex v2.1.1

Session info ``` r sessioninfo::session_info() #> ─ Session info ─────────────────────────────────────────────────────────────── #> setting value #> version R version 4.4.1 (2024-06-14) #> os macOS Sonoma 14.5 #> system aarch64, darwin20 #> ui X11 #> language (EN) #> collate en_US.UTF-8 #> ctype en_US.UTF-8 #> tz Europe/London #> date 2024-08-09 #> pandoc 3.2.1 @ /opt/homebrew/bin/ (via rmarkdown) #> #> ─ Packages ─────────────────────────────────────────────────────────────────── #> package * version date (UTC) lib source #> backports 1.5.0 2024-05-23 [1] CRAN (R 4.4.0) #> bit 4.0.5 2022-11-15 [1] CRAN (R 4.4.0) #> bit64 4.0.5 2020-08-30 [1] CRAN (R 4.4.0) #> broom * 1.0.6 2024-05-17 [1] CRAN (R 4.4.0) #> callr 3.7.6 2024-03-25 [1] CRAN (R 4.4.0) #> class 7.3-22 2023-05-03 [2] CRAN (R 4.4.1) #> cli 3.6.3 2024-06-21 [1] CRAN (R 4.4.0) #> codetools 0.2-20 2024-03-31 [2] CRAN (R 4.4.1) #> colorspace 2.1-1 2024-07-26 [1] CRAN (R 4.4.0) #> coro 1.0.4 2024-03-11 [1] CRAN (R 4.4.0) #> data.table 1.15.4 2024-03-30 [1] CRAN (R 4.4.0) #> dials * 1.3.0 2024-07-30 [1] CRAN (R 4.4.0) #> DiceDesign 1.10 2023-12-07 [1] CRAN (R 4.4.0) #> digest 0.6.36 2024-06-23 [1] CRAN (R 4.4.0) #> dplyr * 1.1.4 2023-11-17 [1] CRAN (R 4.4.0) #> evaluate 0.24.0 2024-06-10 [1] CRAN (R 4.4.0) #> fansi 1.0.6 2023-12-08 [1] CRAN (R 4.4.0) #> fastmap 1.2.0 2024-05-15 [1] CRAN (R 4.4.0) #> foreach 1.5.2 2022-02-02 [1] CRAN (R 4.4.0) #> fs 1.6.4 2024-04-25 [1] CRAN (R 4.4.0) #> furrr 0.3.1 2022-08-15 [1] CRAN (R 4.4.0) #> future 1.34.0 2024-07-29 [1] CRAN (R 4.4.0) #> future.apply 1.11.2 2024-03-28 [1] CRAN (R 4.4.0) #> generics 0.1.3 2022-07-05 [1] CRAN (R 4.4.0) #> ggplot2 * 3.5.1 2024-04-23 [1] CRAN (R 4.4.0) #> globals 0.16.3 2024-03-08 [1] CRAN (R 4.4.0) #> glue 1.7.0 2024-01-09 [1] CRAN (R 4.4.0) #> gower 1.0.1 2022-12-22 [1] CRAN (R 4.4.0) #> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.4.0) #> gtable 0.3.5 2024-04-22 [1] CRAN (R 4.4.0) #> hardhat 1.4.0 2024-06-02 [1] CRAN (R 4.4.0) #> htmltools 0.5.8.1 
2024-04-04 [1] CRAN (R 4.4.0) #> infer * 1.0.7 2024-03-25 [1] CRAN (R 4.4.0) #> ipred 0.9-15 2024-07-18 [1] CRAN (R 4.4.0) #> iterators 1.0.14 2022-02-05 [1] CRAN (R 4.4.0) #> jsonlite 1.8.8 2023-12-04 [1] CRAN (R 4.4.0) #> knitr 1.48 2024-07-07 [1] CRAN (R 4.4.0) #> lattice 0.22-6 2024-03-20 [2] CRAN (R 4.4.1) #> lava 1.8.0 2024-03-05 [1] CRAN (R 4.4.0) #> lhs 1.2.0 2024-06-30 [1] CRAN (R 4.4.0) #> lifecycle 1.0.4 2023-11-07 [1] CRAN (R 4.4.0) #> listenv 0.9.1 2024-01-29 [1] CRAN (R 4.4.0) #> lubridate 1.9.3 2023-09-27 [1] CRAN (R 4.4.0) #> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.4.0) #> MASS 7.3-60.2 2024-04-26 [2] CRAN (R 4.4.1) #> Matrix 1.7-0 2024-04-26 [2] CRAN (R 4.4.1) #> modeldata * 1.4.0 2024-06-19 [1] CRAN (R 4.4.0) #> munsell 0.5.1 2024-04-01 [1] CRAN (R 4.4.0) #> nnet 7.3-19 2023-05-03 [2] CRAN (R 4.4.1) #> parallelly 1.38.0 2024-07-27 [1] CRAN (R 4.4.0) #> parsnip * 1.2.1 2024-03-22 [1] CRAN (R 4.4.0) #> pillar 1.9.0 2023-03-22 [1] CRAN (R 4.4.0) #> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.4.0) #> processx 3.8.4 2024-03-16 [1] CRAN (R 4.4.0) #> prodlim 2024.06.25 2024-06-24 [1] CRAN (R 4.4.0) #> ps 1.7.7 2024-07-02 [1] CRAN (R 4.4.0) #> purrr * 1.0.2 2023-08-10 [1] CRAN (R 4.4.0) #> R6 2.5.1 2021-08-19 [1] CRAN (R 4.4.0) #> Rcpp 1.0.13 2024-07-17 [1] CRAN (R 4.4.0) #> recipes * 1.1.0 2024-07-04 [1] CRAN (R 4.4.0) #> reprex 2.1.1 2024-07-06 [1] CRAN (R 4.4.0) #> rlang 1.1.4 2024-06-04 [1] CRAN (R 4.4.0) #> rmarkdown 2.27 2024-05-17 [1] CRAN (R 4.4.0) #> rpart 4.1.23 2023-12-05 [2] CRAN (R 4.4.1) #> rsample * 1.2.1 2024-03-25 [1] CRAN (R 4.4.0) #> rstudioapi 0.16.0 2024-03-24 [1] CRAN (R 4.4.0) #> safetensors 0.1.2 2023-09-12 [1] CRAN (R 4.4.0) #> scales * 1.3.0 2023-11-28 [1] CRAN (R 4.4.0) #> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.4.0) #> survival 3.6-4 2024-04-24 [2] CRAN (R 4.4.1) #> tabnet * 0.6.0.9000 2024-08-09 [1] Github (mlverse/tabnet@c8c82d2) #> tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.4.0) #> tidymodels * 1.2.0 2024-03-25 [1] CRAN (R 
4.4.0) #> tidyr * 1.3.1 2024-01-24 [1] CRAN (R 4.4.0) #> tidyselect 1.2.1 2024-03-11 [1] CRAN (R 4.4.0) #> timechange 0.3.0 2024-01-18 [1] CRAN (R 4.4.0) #> timeDate 4032.109 2023-12-14 [1] CRAN (R 4.4.0) #> torch 0.13.0 2024-05-21 [1] CRAN (R 4.4.0) #> tune * 1.2.1 2024-04-18 [1] CRAN (R 4.4.0) #> utf8 1.2.4 2023-10-22 [1] CRAN (R 4.4.0) #> vctrs 0.6.5 2023-12-01 [1] CRAN (R 4.4.0) #> withr 3.0.1 2024-07-31 [1] CRAN (R 4.4.0) #> workflows * 1.1.4 2024-02-19 [1] CRAN (R 4.4.0) #> workflowsets * 1.1.0 2024-03-21 [1] CRAN (R 4.4.0) #> xfun 0.46 2024-07-18 [1] CRAN (R 4.4.0) #> yaml 2.3.10 2024-07-26 [1] CRAN (R 4.4.0) #> yardstick * 1.3.1 2024-03-21 [1] CRAN (R 4.4.0) #> zeallot 0.1.0 2018-01-28 [1] CRAN (R 4.4.0) #> #> [1] /Users/carlgoodwin/Library/R/arm64/4.4/library #> [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library #> #> ────────────────────────────────────────────────────────────────────────────── ```
cregouby commented 2 months ago

Hello @cgoo4,

I added the message on purpose, but it may be misleading. It means "the tabnet model will be fit without using the case_weights variable." That is how tabnet actually uses the case_weights variable: the weights are set aside for later use by other downstream tidymodels packages.

Any suggestions for a more informative message?

cgoo4 commented 2 months ago

Hi @cregouby - Thank you for clarifying.

If it's possible to set case_weights in more than one way, e.g. in a tidymodels workflow() and also in tabnet_fit(), then maybe the message could say that the latter is being overridden by the former?