mlverse / tabnet

An R implementation of TabNet
https://mlverse.github.io/tabnet/

Issues with `device = "mps"` on Sonoma OS #144

Open cgoo4 opened 10 months ago

cgoo4 commented 10 months ago

Running the example below in a fresh R session, tabnet_pretrain() works with device = "mps", but tabnet_fit() hangs with no message, and I have to terminate R to recover. Session info attached.

library(tabnet)
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id")), new_role = "predictor")

set.seed(1)

tab_pre <- tab_rec |> 
  tabnet_pretrain(train, device = "mps", checkpoint_epochs = 2)

tab_fit <- tab_rec |>
  tabnet_fit(train, tabnet_model = tab_pre, from_epoch = 2, device = "cpu") # hangs with "mps"

test |> bind_cols(predict(tab_fit, test))
#> # A tibble: 2,465 × 24
#>    funded_amnt term    int_rate sub_grade addr_state verification_status
#>          <int> <fct>      <dbl> <fct>     <fct>      <fct>              
#>  1       10000 term_36    11.5  B5        TX         Source_Verified    
#>  2        7000 term_36    13.0  C2        CA         Source_Verified    
#>  3       35000 term_36    11.5  B5        TN         Source_Verified    
#>  4       15000 term_36    10.8  B4        TX         Not_Verified       
#>  5       27200 term_60    10.8  B4        NC         Not_Verified       
#>  6       12000 term_36    14.5  C4        OR         Source_Verified    
#>  7       15025 term_36    13.7  C3        MA         Source_Verified    
#>  8       20000 term_36     5.32 A1        WI         Not_Verified       
#>  9       20000 term_36    12.0  C1        VA         Verified           
#> 10       10000 term_36    10.8  B4        NC         Verified           
#> # ℹ 2,455 more rows
#> # ℹ 18 more variables: annual_inc <dbl>, emp_length <fct>, delinq_2yrs <int>,
#> #   inq_last_6mths <int>, revol_util <dbl>, acc_now_delinq <int>,
#> #   open_il_6m <int>, open_il_12m <int>, open_il_24m <int>, total_bal_il <int>,
#> #   all_util <int>, inq_fi <int>, inq_last_12m <int>, delinq_amnt <int>,
#> #   num_il_tl <int>, total_il_high_credit_limit <int>, Class <fct>,
#> #   .pred_class <fct>

Created on 2024-01-12 with reprex v2.0.2

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.3.2 (2023-10-31)
#>  os       macOS Sonoma 14.2.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/London
#>  date     2024-01-12
#>  pandoc   3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [2] CRAN (R 4.3.0)
#>  bit            4.0.5      2022-11-15 [2] CRAN (R 4.3.0)
#>  bit64          4.0.5      2020-08-30 [2] CRAN (R 4.3.0)
#>  broom        * 1.0.5      2023-06-09 [2] CRAN (R 4.3.0)
#>  callr          3.7.3      2022-11-02 [2] CRAN (R 4.3.0)
#>  class          7.3-22     2023-05-03 [2] CRAN (R 4.3.2)
#>  cli            3.6.2      2023-12-11 [1] CRAN (R 4.3.1)
#>  codetools      0.2-19     2023-02-01 [2] CRAN (R 4.3.2)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.3.0)
#>  coro           1.0.3      2022-07-19 [2] CRAN (R 4.3.0)
#>  data.table     1.14.10    2023-12-08 [1] CRAN (R 4.3.1)
#>  dials        * 1.2.0      2023-04-03 [1] CRAN (R 4.3.0)
#>  DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.3.1)
#>  digest         0.6.33     2023-07-07 [1] CRAN (R 4.3.0)
#>  dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.3.1)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.3.0)
#>  evaluate       0.23       2023-11-01 [2] CRAN (R 4.3.1)
#>  fansi          1.0.6      2023-12-08 [1] CRAN (R 4.3.1)
#>  fastmap        1.1.1      2023-02-24 [2] CRAN (R 4.3.0)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.3.0)
#>  fs             1.6.3      2023-07-20 [2] CRAN (R 4.3.0)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.3.0)
#>  future         1.33.1     2023-12-22 [1] CRAN (R 4.3.1)
#>  future.apply   1.11.1     2023-12-21 [1] CRAN (R 4.3.1)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.3.0)
#>  ggplot2      * 3.4.4      2023-10-12 [1] CRAN (R 4.3.1)
#>  globals        0.16.2     2022-11-21 [1] CRAN (R 4.3.0)
#>  glue           1.7.0      2024-01-09 [1] CRAN (R 4.3.1)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.3.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.3.0)
#>  gtable         0.3.4      2023-08-21 [1] CRAN (R 4.3.0)
#>  hardhat        1.3.0      2023-03-30 [1] CRAN (R 4.3.0)
#>  htmltools      0.5.7      2023-11-03 [2] CRAN (R 4.3.1)
#>  infer        * 1.0.5      2023-09-06 [2] CRAN (R 4.3.0)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.3.0)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.3.0)
#>  jsonlite       1.8.8      2023-12-04 [2] CRAN (R 4.3.1)
#>  knitr          1.45       2023-10-30 [2] CRAN (R 4.3.1)
#>  lattice        0.22-5     2023-10-24 [2] CRAN (R 4.3.1)
#>  lava           1.7.3      2023-11-04 [1] CRAN (R 4.3.1)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.3.0)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.3.1)
#>  listenv        0.9.0      2022-12-16 [1] CRAN (R 4.3.0)
#>  lubridate      1.9.3      2023-09-27 [1] CRAN (R 4.3.1)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.3.0)
#>  MASS           7.3-60     2023-05-04 [2] CRAN (R 4.3.2)
#>  Matrix         1.6-4      2023-11-30 [2] CRAN (R 4.3.1)
#>  modeldata    * 1.2.0      2023-08-09 [2] CRAN (R 4.3.0)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.3.0)
#>  nnet           7.3-19     2023-05-03 [2] CRAN (R 4.3.2)
#>  parallelly     1.36.0     2023-05-26 [1] CRAN (R 4.3.0)
#>  parsnip      * 1.1.1      2023-08-17 [1] CRAN (R 4.3.0)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.3.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.3.0)
#>  processx       3.8.3      2023-12-10 [2] CRAN (R 4.3.1)
#>  prodlim        2023.08.28 2023-08-28 [1] CRAN (R 4.3.0)
#>  ps             1.7.5      2023-04-18 [2] CRAN (R 4.3.0)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.3.0)
#>  R.cache        0.16.0     2022-07-21 [2] CRAN (R 4.3.0)
#>  R.methodsS3    1.8.2      2022-06-13 [2] CRAN (R 4.3.0)
#>  R.oo           1.25.0     2022-06-12 [2] CRAN (R 4.3.0)
#>  R.utils        2.12.3     2023-11-18 [2] CRAN (R 4.3.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.3.0)
#>  Rcpp           1.0.12     2024-01-09 [1] CRAN (R 4.3.1)
#>  recipes      * 1.0.9      2023-12-13 [1] CRAN (R 4.3.1)
#>  reprex         2.0.2      2022-08-17 [2] CRAN (R 4.3.0)
#>  rlang          1.1.3      2024-01-10 [1] CRAN (R 4.3.1)
#>  rmarkdown      2.25       2023-09-18 [2] CRAN (R 4.3.1)
#>  rpart          4.1.23     2023-12-05 [2] CRAN (R 4.3.1)
#>  rsample      * 1.2.0      2023-08-23 [1] CRAN (R 4.3.0)
#>  rstudioapi     0.15.0     2023-07-07 [2] CRAN (R 4.3.0)
#>  safetensors    0.1.2      2023-09-12 [2] CRAN (R 4.3.0)
#>  scales       * 1.2.1      2022-08-20 [1] CRAN (R 4.3.2)
#>  sessioninfo    1.2.2      2021-12-06 [2] CRAN (R 4.3.0)
#>  styler         1.10.2     2023-08-29 [2] CRAN (R 4.3.0)
#>  survival       3.5-7      2023-08-14 [2] CRAN (R 4.3.2)
#>  tabnet       * 0.5.0.9000 2024-01-11 [1] Github (mlverse/tabnet@962bafa)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.3.0)
#>  tidymodels   * 1.1.1      2023-08-24 [2] CRAN (R 4.3.0)
#>  tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.3.0)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.3.0)
#>  timechange     0.2.0      2023-01-11 [1] CRAN (R 4.3.0)
#>  timeDate       4032.109   2023-12-14 [1] CRAN (R 4.3.1)
#>  torch          0.12.0     2024-01-05 [1] Github (mlverse/torch@23071c1)
#>  tune         * 1.1.2      2023-08-23 [1] CRAN (R 4.3.0)
#>  utf8           1.2.4      2023-10-22 [1] CRAN (R 4.3.1)
#>  vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.3.1)
#>  withr          2.5.2      2023-10-30 [1] CRAN (R 4.3.1)
#>  workflows    * 1.1.3      2023-02-22 [1] CRAN (R 4.3.0)
#>  workflowsets * 1.0.1      2023-04-06 [2] CRAN (R 4.3.0)
#>  xfun           0.41       2023-11-01 [2] CRAN (R 4.3.1)
#>  yaml           2.3.8      2023-12-11 [2] CRAN (R 4.3.1)
#>  yardstick    * 1.2.0      2023-04-21 [1] CRAN (R 4.3.0)
#>  zeallot        0.1.0      2018-01-28 [2] CRAN (R 4.3.0)
#>
#>  [1] /Users/carlgoodwin/Library/R/arm64/4.3/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
```

cregouby commented 10 months ago

Hello @cgoo4

It smells like a silent OOM of the MPS device. If you run it from the terminal or as a test case, you may see:

! MPS backend out of memory (MPS allocated: 0 bytes, other allocations: 0 bytes, max allowed: 1.70 GB). Tried to allocate 0 bytes on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
Exception raised from alloc_buffer_block at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSAllocator.mm:235 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >) + 81 (0x105c22ca1 in libc10.dylib)
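
The error text itself names PYTORCH_MPS_HIGH_WATERMARK_RATIO. A minimal sketch of trying it from R follows. It assumes the libtorch build bundled with the torch R package reads this variable the way PyTorch does, which nothing in this thread confirms. The variable has to be set before torch is loaded, and the message warns that removing the limit may cause system failure.

```r
# Assumption: libtorch reads this variable when it initialises the MPS
# allocator, as PyTorch documents. "0.0" disables the upper limit and
# may destabilise the system, so treat this as a diagnostic only.
Sys.setenv(PYTORCH_MPS_HIGH_WATERMARK_RATIO = "0.0")

library(torch)

# Confirm the variable is visible to the process and that MPS is usable
Sys.getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO")
backends_mps_is_available()
```

If the hang persists with the limit removed, memory pressure is probably not the real cause.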

I would recommend drastically lowering virtual_batch_size =, and possibly batch_size =, in your tabnet_config().
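
A minimal sketch of that suggestion, with illustrative values (256 and 64 are not taken from this thread). The name small_config is hypothetical, and tab_rec, train and tab_pre refer to the objects from the reprex in the first comment.

```r
library(tabnet)

# Illustrative values only: pick sizes that suit your data and device
small_config <- tabnet_config(
  batch_size = 256,
  virtual_batch_size = 64,
  device = "mps"
)

# Pass the configuration to the fit; objects come from the reprex above
# tab_fit <- tab_rec |>
#   tabnet_fit(train, tabnet_model = tab_pre, from_epoch = 2, config = small_config)
```

Keeping the settings in one tabnet_config() object makes it easy to reuse the same values for tabnet_pretrain() and tabnet_fit().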

Let me know if it helps

cgoo4 commented 10 months ago

Hi @cregouby - I've tried progressively reducing virtual_batch_size to 16, but tabnet_fit still hangs. tabnet_pretrain runs using the defaults. Both produce the message.

I'm using an M2 Max with 64 GB memory.

(image attached)

cregouby commented 10 months ago

Hello @cgoo4, the problem is not linked to using tabnet_model =: the second run of any training on device = "mps" hits the issue.

There are currently many open issues about device = "mps" on macOS Sonoma (tensorflow, pytorch, ...), so I guess there is nothing we can do until the Apple developers fix the MPS code.
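
Until then, one way to keep scripts working is to fall back to the CPU, as the reprex already does for tabnet_fit(). A minimal sketch, assuming a hypothetical helper named choose_device (not part of tabnet or torch) that uses MPS only when it is requested explicitly and is available:

```r
library(torch)

# Hypothetical helper: opt in to MPS explicitly, otherwise train on the CPU
choose_device <- function(allow_mps = FALSE) {
  if (allow_mps && backends_mps_is_available()) "mps" else "cpu"
}

device <- choose_device(allow_mps = FALSE)

# Objects come from the reprex in the first comment
# tab_fit <- tab_rec |>
#   tabnet_fit(train, tabnet_model = tab_pre, from_epoch = 2, device = device)
```

Defaulting to the CPU keeps the hang from being triggered by accident; setting allow_mps = TRUE re-enables MPS once the backend is fixed.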