mlverse / tabnet

An R implementation of TabNet
https://mlverse.github.io/tabnet/

Error with parameter tuning #85

Closed: JunaidMB closed this issue 2 years ago

JunaidMB commented 2 years ago

Hello,

I'm following the TabNet tutorial found here, but with a different and much smaller dataset, and I'm running into problems with hyperparameter tuning. I get the following error:

Error in `dplyr::filter()`:
! Problem while computing `..1 = .metric == metric`.
Caused by error:
! object '.metric' not found
Run `rlang::last_error()` to see where the error occurred.
Warning message:
All models failed. See the `.notes` column. 

When I inspect the .notes column, I see this error message:

"internal: Error in UseMethod(\"filter\"): no applicable method for 'filter' applied to an object of class \"NULL\""

I'm not sure why this error occurs. If it were related to the size or column types of the dataset, I would expect a different error message. Instead, this looks like a problem with how tabnet interacts with tidymodels?
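Both messages can be reproduced in isolation with dplyr. The sketch below is only an illustration and does not claim that these are the exact internal calls tune or finetune make. Filtering an empty results tibble on `.metric` fails because that column does not exist. Calling `filter()` on `NULL` fails because the generic has no method for `NULL`. An empty or `NULL` metrics object is what you would expect when every model fails, which is consistent with the "All models failed" warning.

```r
library(dplyr)

# 1. An empty results tibble with no `.metric` column, as could happen when
#    every model fails: filter() cannot find the column it is asked to test.
empty_metrics <- tibble::tibble()
metric_error <- tryCatch(
  filter(empty_metrics, .metric == "accuracy"),
  error = function(e) conditionMessage(e)
)

# 2. NULL instead of a data frame: the filter() generic has no NULL method.
null_error <- tryCatch(
  filter(NULL, .metric == "accuracy"),
  error = function(e) conditionMessage(e)
)

cat(metric_error, "\n\n", null_error, "\n")
```

The second message matches the text found in the .notes column. So both symptoms point to the model fits failing upstream, not to the metric itself.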

The full code to reproduce the error is:

library(tidymodels)
library(parameters)
library(skimr)
library(remotes)
library(tidyverse)
library(parallel)
library(doParallel)
library(vip)
library(themis)
library(lme4)
library(BradleyTerry2)
library(finetune)
library(butcher)
library(lobstr)
library(lubridate)
library(NHSRdatasets)
library(torch)
library(tabnet)
library(yardstick)

set.seed(777)
torch_manual_seed(777)

# Read in data ----
## A stranded patient is a patient who has been in hospital for longer than 7 days; we also call these Long Waiters.
strand_pat <- NHSRdatasets::stranded_data %>% 
  setNames(c("stranded_class", "age", "care_home_ref_flag", "medically_safe_flag", 
             "hcop_flag", "needs_mental_health_support_flag", "previous_care_in_last_12_month", "admit_date", "frail_descrip")) %>% 
  mutate(stranded_class = factor(stranded_class),
         admit_date = as.Date(admit_date, format = "%d/%m/%Y")) %>% 
  drop_na()

# Partition into training and test data splits ----
split <- initial_split(strand_pat)
train_data <- training(split)
test_data <- testing(split)  

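# Create Recipe ----
## Minority-to-majority ratio used by step_upsample() below.
## NOTE: `upsample_ratio` is not defined in the original snippet; 1 (a fully
## balanced training set, step_upsample()'s default) is an assumed value.
upsample_ratio <- 1

## Define Recipe to be applied to the dataset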
stranded_rec <- 
  recipe(stranded_class ~ ., data = train_data) %>% 
  # Make a day of week and month feature from admit date and remove raw admit date
  step_date(admit_date, features = c("dow", "month")) %>% 
  step_rm(admit_date) %>% 
  # Upsample minority (positive) class
  themis::step_upsample(stranded_class, over_ratio = as.numeric(upsample_ratio)) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

## Prepare and Bake recipe on training and test data
stranded_recipe_prep <- prep(stranded_rec, training = train_data)

stranded_train_bake <- bake(stranded_recipe_prep, new_data = NULL)
stranded_test_bake <- bake(stranded_recipe_prep, new_data = test_data)

# hyperparameter settings (apart from epochs) as per the TabNet paper (TabNet-S)
tabnet_model <- tabnet(epochs = 1, batch_size = 256, decision_width = tune(), attention_width = tune(),
              num_steps = tune(), penalty = 0.000001, virtual_batch_size = 256, momentum = 0.6,
              feature_reusage = 1.5, learn_rate = tune()) %>%
  set_engine("torch", verbose = TRUE) %>%
  set_mode("classification")

# Create Workflow to connect recipe and model
tabnet_workflow <- workflow() %>%
  add_model(tabnet_model) %>%
  add_recipe(stranded_rec)

# Specify parameter tuning grid
grid <-
  tabnet_workflow %>%
  tune::parameters() %>%
  update(
    decision_width = decision_width(range = c(20, 40)),
    attention_width = attention_width(range = c(20, 40)),
    num_steps = num_steps(range = c(4, 6)),
    learn_rate = learn_rate(range = c(-2.5, -1))
  ) %>%
  grid_max_entropy(size = 8)

folds <- vfold_cv(train_data, v = 5)
set.seed(777)

res <- tabnet_workflow %>% 
  tune_race_win_loss(
    resamples = folds,
    grid = grid,
    metrics = metric_set(accuracy),
    control = control_race()
  )

Any help would be much appreciated. Thanks!

cregouby commented 2 years ago

Hello @JunaidMB

Running your code with the GitHub version of the package does not raise any issue:

[Epoch 001] Loss: 1.145955                                                                                   
[Epoch 001] Loss: 1.350600                                                                                   
[Epoch 001] Loss: 1.083805                                                                                   
[Epoch 001] Loss: 1.049350                                                                                   
[Epoch 001] Loss: 1.270870                                                                                   
[Epoch 001] Loss: 1.425476                                                                                   
[Epoch 001] Loss: 0.762844                                                                                   
[Epoch 001] Loss: 1.036789                                                                                   
...
[Epoch 001] Loss: 0.926028                                                                                   
[Epoch 001] Loss: 0.965075                                                                                   
[Epoch 001] Loss: 1.117391                                                                                   
> res
# Tuning results
# 5-fold cross-validation 
# A tibble: 5 × 5
  splits            id    .order .metrics         .notes          
  <list>            <chr>  <int> <list>           <list>          
1 <split [419/105]> Fold1      3 <tibble [8 × 8]> <tibble [0 × 1]>
2 <split [419/105]> Fold3      1 <tibble [8 × 8]> <tibble [0 × 1]>
3 <split [419/105]> Fold4      2 <tibble [8 × 8]> <tibble [0 × 1]>
4 <split [419/105]> Fold2      4 <tibble [8 × 8]> <tibble [0 × 1]>
5 <split [420/104]> Fold5      5 <tibble [8 × 8]> <tibble [0 × 1]>

Could you try the latest GitHub version of the package?
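Here is a minimal sketch of installing the development version from the mlverse/tabnet repository named at the top of this thread. It uses remotes, which the reproduction code already loads. It assumes network access and that you restart R afterwards, so that library(tabnet) picks up the new version.

```r
# Make sure remotes is available, then install tabnet from GitHub.
if (!requireNamespace("remotes", quietly = TRUE)) {
  install.packages("remotes")
}
remotes::install_github("mlverse/tabnet")

# Check which version is now installed on disk.
# Restart R before calling library(tabnet) so this version is the one loaded.
packageVersion("tabnet")
```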

JunaidMB commented 2 years ago

Hi @cregouby

Using the latest GitHub version of the package fixed it, thank you very much! For reference, my full workflow using tidymodels is here:

https://github.com/JunaidMB/playing_with_tabnet/blob/master/tabnet_nhs_stranded.R

I would consider this issue closed!

wtbxsjy commented 2 years ago

I got the same error message when fitting with torch. After repeatedly loading and detaching packages, I found that the error is related to the future-based packages (furrr, doFuture, etc.). In my case, the error occurs when I set plan(multisession). With plan(sequential), the example works properly. You may have enabled parallel processing somewhere else before running the example. Hope this helps.
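This is a minimal sketch of the workaround described above: switch future back to sequential execution before tuning. The foreach::registerDoSEQ() call is an extra assumption on my part, not part of the comment. It covers the case where a foreach backend such as doParallel, which the reproduction code loads, was registered earlier in the session.

```r
library(future)

# Run futures sequentially (no background workers) instead of plan(multisession).
plan(sequential)

# Assumption: also reset any previously registered foreach backend
# (e.g. from doParallel::registerDoParallel()) to sequential execution.
foreach::registerDoSEQ()

# ...then rerun the tuning call (e.g. tune_race_win_loss()) as before.
```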