Open cgoo4 opened 9 months ago
Hello @cgoo4, I finally did it. Would you like to test it and report whether it fits your needs? One way to install it is:
pak::pkg_install("mlverse/tabnet@feature/case_weight")
Hi @cregouby - Thank you.
Ahead of trying it on my own data, I've made a quick test using the toy `lending_club` data. Untuned TabNet and XGBoost models, with and without case weights, show comparable results!
library(tabnet)
library(tidymodels)
library(modeldata)
library(patchwork)

data("lending_club", package = "modeldata")

class_ratio <- lending_club |>
  summarise(sum(Class == "good") / sum(Class == "bad")) |>
  pull()

lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)

tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")

xgb_rec <- tab_rec |>
  step_dummy(term, sub_grade, addr_state, verification_status, emp_length)

tab_mod <- tabnet(epochs = 100) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")

xgb_mod <- boost_tree(trees = 100) |>
  set_engine("xgboost") |>
  set_mode("classification")

tab_wf <- workflow() |>
  add_model(tab_mod) |>
  add_recipe(tab_rec) |>
  add_case_weights(case_wts)

xgb_wf <- workflow() |>
  add_model(xgb_mod) |>
  add_recipe(xgb_rec) |>
  add_case_weights(case_wts)

tab_fit <- tab_wf |> fit(train)
xgb_fit <- xgb_wf |> fit(train)

tab_test <- tab_fit |> augment(test)
xgb_test <- xgb_fit |> augment(test)

p1 <- tab_test |>
  pr_curve(Class, .pred_good, case_weights = case_wts) |>
  autoplot() +
  ggtitle("TabNet with Case Weights") +
  theme(plot.title = element_text(size = 9))

p2 <- tab_test |>
  pr_curve(Class, .pred_good) |>
  autoplot() +
  ggtitle("TabNet WITHOUT") +
  theme(plot.title = element_text(size = 9))

p3 <- xgb_test |>
  pr_curve(Class, .pred_good, case_weights = case_wts) |>
  autoplot() +
  ggtitle("XGBoost with Case Weights") +
  theme(plot.title = element_text(size = 9))

p4 <- xgb_test |>
  pr_curve(Class, .pred_good) |>
  autoplot() +
  ggtitle("XGBoost WITHOUT") +
  theme(plot.title = element_text(size = 9))

p1 + p2 + p3 + p4
Created on 2024-02-18 with reprex v2.1.0
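For readers skimming the reprex above: the weighting gives each "bad" row a weight equal to the good/bad ratio, so both classes end up with the same total weight. A minimal base-R sketch on made-up data (the vector `cls` is an illustrative assumption, not `lending_club`) shows the arithmetic:

```r
# Toy outcome vector (made-up data, for illustration only)
cls <- factor(c(rep("good", 8), rep("bad", 2)), levels = c("bad", "good"))

# Same ratio as in the reprex: number of "good" rows per "bad" row
class_ratio <- sum(cls == "good") / sum(cls == "bad")

# Minority rows get the ratio as weight, majority rows get 1
wts <- ifelse(cls == "bad", class_ratio, 1)

# Total weight per class; the two totals should match
class_totals <- tapply(wts, cls, sum)
print(class_totals)
```

In the reprex itself the numeric weights are additionally wrapped in `importance_weights()` so that tidymodels treats the column as case weights.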
In 0.6.0.9000 I'm getting the message "Configured `weights` will not be used":
(It's the same example as per above where the weights were being passed along from the workflow.)
library(tabnet)
library(tidymodels)

data("lending_club", package = "modeldata")

class_ratio <- lending_club |>
  summarise(sum(Class == "good") / sum(Class == "bad")) |>
  pull()

lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)

tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")

tab_mod <- tabnet(epochs = 10) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")

tab_wf <- workflow() |>
  add_model(tab_mod) |>
  add_recipe(tab_rec) |>
  add_case_weights(case_wts)

tab_fit <- tab_wf |> fit(train)
#> Configured `weights` will not be used

tab_test <- tab_fit |> augment(test)
Created on 2024-08-09 with reprex v2.1.1
Hello @cgoo4,
I added the message on purpose, but it may be misleading. It means 'the tabnet model will be fit without using the case_weights variable', since that is how tabnet actually treats the case_weights variable: the weights are set apart for later use by other downstream tidymodels packages.
Any proposal for a more informative message?
Hi @cregouby - Thank you for clarifying.
If it's possible to set case_weights more than one way, e.g. in a tidymodels `workflow()` and also in `tabnet_fit()`, then maybe the message could say the latter is being overridden by the former?
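To make the suggestion concrete, here is a rough sketch of what such a message could look like. The helper `resolve_weights()` and its arguments are hypothetical, written only for illustration; this is not tabnet's code and says nothing about how tabnet resolves weights internally.

```r
# Hypothetical helper (NOT part of tabnet): decides which weights win
# and tells the user when one source overrides the other.
resolve_weights <- function(workflow_weights = NULL, fit_weights = NULL) {
  if (!is.null(workflow_weights) && !is.null(fit_weights)) {
    message(
      "`weights` passed to the fit function are overridden by ",
      "the case weights set in the workflow."
    )
    return(workflow_weights)
  }
  # Only one source (or none) was supplied: use whichever exists
  if (!is.null(workflow_weights)) workflow_weights else fit_weights
}

# Both sources set: the workflow weights are returned and a message is shown
chosen <- resolve_weights(workflow_weights = c(1, 2), fit_weights = c(3, 4))
```

The point is only that the message names both sources and states which one wins, which may be easier to act on than a bare "will not be used".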
Would it be possible to add support for case weights in TabNet?
This would help with class imbalance and make it easier to compare (and blend) the results of TabNet and XGBoost.
(I will probably upsample the minority class in the meantime as an alternative approach.)
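For reference, upsampling the minority class can be sketched in base R as below. The data frame `df` is made-up toy data, and this is only one simple way to do it (random resampling with replacement until both classes have the same size).

```r
set.seed(123)

# Toy data (made up): 8 "good" rows, 2 "bad" rows
df <- data.frame(
  x = 1:10,
  Class = factor(c(rep("good", 8), rep("bad", 2)))
)

minority <- df[df$Class == "bad", ]

# Number of extra minority rows needed to match the majority class
n_extra <- sum(df$Class == "good") - nrow(minority)

# Draw the extra rows from the minority class, with replacement
extra <- minority[sample(nrow(minority), n_extra, replace = TRUE), ]

upsampled <- rbind(df, extra)
class_counts <- table(upsampled$Class)
print(class_counts)
```

Upsampling should be applied to the training set only, so that the test set keeps the original class distribution.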
This would be the desired workflow:
Created on 2024-01-12 with reprex v2.0.2