mlverse / tabnet

An R implementation of TabNet
https://mlverse.github.io/tabnet/

Error: "'xxx' is not an exported object from 'namespace:dials'" when trying to tune the hyperparameters of TabNet #159

Closed: AKALeon closed this issue 3 months ago

AKALeon commented 3 months ago

I encountered this error when trying to tune TabNet's hyperparameters, both on macOS (Apple M1) and on Posit Cloud.

The code is as follows:

library(tabnet)
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

rec <- recipe(Class ~ ., train) %>%
  step_normalize(all_numeric())

mod <- tabnet(epochs = 1, batch_size = 16384, penalty = 0.000001,
              virtual_batch_size = 512, momentum = 0.6, feature_reusage = 1.5,
              # hyperparameters to be tuned
              decision_width = tune(), attention_width = tune(),
              num_steps = tune(), learn_rate = tune()) %>%
  set_engine("torch", verbose = TRUE) %>%
  set_mode("classification")

wf <- workflow() %>%
  add_model(mod) %>%
  add_recipe(rec)

grid <-
  wf %>%
  parameters() %>%
  update(
    decision_width = decision_width(range = c(20, 40)),
    attention_width = attention_width(range = c(20, 40)),
    num_steps = num_steps(range = c(4, 6)),
    learn_rate = learn_rate(range = c(-2.5, -1))
  ) %>%
  grid_max_entropy(size = 8)

The error occurs when I run:

wf %>%
  parameters()
Error in `mutate()`:
ℹ In argument: `object = purrr::map(call_info, eval_call_info)`.
Caused by error in `purrr::map()`:
ℹ In index: 1.
Caused by error in `.f()`:
! Error when calling decision_width(): Error : 'decision_width' is not an exported object from 'namespace:dials'
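
The traceback shows the failure happens while the model's tunable-parameter table is evaluated (`purrr::map(call_info, eval_call_info)`). A minimal diagnostic sketch, assuming `tabnet` and `tidymodels` are installed as in the report: it prints, for each tunable argument of the spec, the package and function that will be called to build the parameter object, then checks which namespace actually exports `decision_width()`, `attention_width()` and `num_steps()`. The names `field`, `lookup`, `params` and `exports` are illustrative only, and the printed values depend on the installed `tabnet` version.

```r
library(tabnet)
library(tidymodels)

# Same tuned arguments as in the report; other arguments keep their defaults
mod <- tabnet(decision_width = tune(), attention_width = tune(),
              num_steps = tune(), learn_rate = tune()) %>%
  set_engine("torch") %>%
  set_mode("classification")

# One row per tunable argument; `call_info` says which package/function
# is called to build each parameter object
tun <- generics::tunable(mod)

# Hypothetical helper: pull one field out of a call_info entry, NA if missing
field <- function(ci, what) {
  if (is.null(ci[[what]])) NA_character_ else as.character(ci[[what]])
}

lookup <- data.frame(
  name = tun$name,
  pkg  = vapply(tun$call_info, field, character(1), what = "pkg"),
  fun  = vapply(tun$call_info, field, character(1), what = "fun")
)
print(lookup)

# Which namespace actually exports these parameter functions?
params <- c("decision_width", "attention_width", "num_steps")
exports <- data.frame(
  param     = params,
  in_dials  = params %in% getNamespaceExports("dials"),
  in_tabnet = params %in% getNamespaceExports("tabnet")
)
print(exports)
```

If `lookup` names `dials` as the package for `decision_width` while `exports` shows that function only in `tabnet`, that mismatch would presumably explain the error above.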

When I tried to use tune_grid() or tune_race_anova(), I got the same error.
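
Since `tune_grid()` and `tune_race_anova()` appear to fail at the same step (they also build the parameter set from the spec's tunable table), one possible stopgap until this is fixed is to avoid `tune()` altogether: plug candidate values into the spec with `parsnip::set_args()` and score each candidate with `fit_resamples()`. This is a sketch under stated assumptions, not a fix, and it is untested against the package versions in this report: the two-row candidate grid, the 3-fold cross-validation and the `roc_auc` metric are arbitrary choices, and `candidates`, `folds`, `base_mod` and `results` are illustrative names.

```r
library(tabnet)
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
folds <- vfold_cv(train, v = 3, strata = Class)  # assumption: 3-fold CV

rec <- recipe(Class ~ ., train) %>%
  step_normalize(all_numeric())

# Same fixed settings as the report, with no tune() placeholders
base_mod <- tabnet(epochs = 1, batch_size = 16384, penalty = 0.000001,
                   virtual_batch_size = 512, momentum = 0.6,
                   feature_reusage = 1.5) %>%
  set_engine("torch") %>%
  set_mode("classification")

# Hypothetical candidates picked from the ranges in the report
# (learn_rate's range c(-2.5, -1) is on the log10 scale)
candidates <- tibble(
  candidate       = 1:2,
  decision_width  = c(20L, 40L),
  attention_width = c(20L, 40L),
  num_steps       = c(4L, 6L),
  learn_rate      = 10^c(-2.5, -1)
)

results <- purrr::map_dfr(candidates$candidate, function(i) {
  row <- candidates[i, ]
  # !! injects the concrete value, so no tune() is left in the spec and
  # the parameter-set lookup that fails above is never needed
  mod_i <- set_args(base_mod,
                    decision_width  = !!row$decision_width,
                    attention_width = !!row$attention_width,
                    num_steps       = !!row$num_steps,
                    learn_rate      = !!row$learn_rate)
  workflow() %>%
    add_model(mod_i) %>%
    add_recipe(rec) %>%
    fit_resamples(folds, metrics = metric_set(roc_auc)) %>%
    collect_metrics() %>%
    mutate(candidate = i)
})

# Attach the hyperparameter values to each candidate's resampled metric
results <- left_join(results, candidates, by = "candidate")
print(results)
```

`arrange(results, desc(mean))` then ranks the candidates by mean ROC AUC. This trades the convenience of `tune_grid()` (parallelism, grid helpers) for not depending on the parameter-set lookup.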

cregouby commented 3 months ago

Hello @AKALeon, please give me some time to look into this.