mlverse / tabnet

An R implementation of TabNet
https://mlverse.github.io/tabnet/

GPU not in use with tidymodels and tabnet #80

Closed: vidarsumo closed this issue 2 years ago

vidarsumo commented 2 years ago

I'm trying out tabnet in R, which uses a torch backend.

I'm using tidymodels to tune a set of hyperparameters, but I'm not sure whether it's using the GPU. This is on an Azure VM with CUDA 10.2, cuDNN 7.6 and an NVIDIA V100 GPU.

    tabnet_spec <- tabnet(
        mode               = "regression",
        epochs             = 5,
        num_steps          = tune(),
        feature_reusage    = tune(),
        learn_rate         = tune(),
        batch_size         = tune(),
        virtual_batch_size = tune()
        ) %>%
        set_engine("torch")

    wflw_tabnet <- workflow() %>%
        add_model(tabnet_spec) %>%
        add_recipe(recipe)

    tabnet_tune_results <- tune_race_anova(
        object     = wflw_tabnet,
        resamples  = resamples_kfold,
        param_info = parameters(wflw_tabnet),
        grid       = 10,
        control    = control_race(verbose = TRUE)
    )

When I run this code, the GPU load is around 0%, sometimes going up to 5% but straight back down to 0%, while the CPU is above 80%. Is there anything I need to do for the GPU to be used? I also tried adding dev = "cuda:0" to set_engine(), but without success: the GPU load still stays mostly around 0%.
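One way to sample GPU load from R, rather than from a task-manager screenshot, is to poll `nvidia-smi`. This is a minimal sketch, assuming `nvidia-smi` is on the PATH; `gpu_utilization()` and `parse_gpu_util()` are hypothetical helpers, not part of torch or tabnet:

```r
# Parse the output of
#   nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader,nounits
# which prints one integer percentage per GPU, one GPU per line.
parse_gpu_util <- function(lines) {
  as.integer(trimws(lines))
}

# Hypothetical helper: current utilization (%) of each visible GPU,
# or NA_integer_ when nvidia-smi cannot be found.
gpu_utilization <- function() {
  if (!nzchar(Sys.which("nvidia-smi"))) {
    return(NA_integer_)
  }
  out <- system2(
    "nvidia-smi",
    c("--query-gpu=utilization.gpu", "--format=csv,noheader,nounits"),
    stdout = TRUE
  )
  parse_gpu_util(out)
}

# Usage: R is single-threaded, so run this in a second R session while the
# fit runs in the first one. Samples once per second for 10 seconds.
# for (i in 1:10) { print(gpu_utilization()); Sys.sleep(1) }
```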


CUDA version:

    C:\Users\vidar>nvcc --version
    nvcc: NVIDIA (R) Cuda compiler driver
    Copyright (c) 2005-2019 NVIDIA Corporation
    Built on Wed_Oct_23_19:32:27_Pacific_Daylight_Time_2019
    Cuda compilation tools, release 10.2, V10.2.89

Session info:

> sessioninfo::session_info()
- Session info ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 setting  value
 version  R version 4.1.1 (2021-08-10)
 os       Windows Server x64 (build 17763)
 system   x86_64, mingw32
 ui       RStudio
 language (EN)
 collate  English_United States.1252
 ctype    English_United States.1252
 tz       GMT
 date     2022-01-19
 rstudio  1.4.1717 Juliet Rose (desktop)
 pandoc   2.14.2 @ C:\\PROGRA~3\\CHOCOL~1\\bin\\pandoc.exe

- Packages -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 ! package            * version    date (UTC) lib source
   assertthat           0.2.1      2019-03-21 [1] CRAN (R 4.1.1)
   AzureAuth            1.3.3      2021-09-13 [1] CRAN (R 4.1.1)
   AzureGraph           1.3.2      2021-11-28 [1] CRAN (R 4.1.2)
   AzureKeyVault      * 1.0.5      2021-09-16 [1] CRAN (R 4.1.1)
   AzureRMR             2.4.3      2021-10-23 [1] CRAN (R 4.1.1)
   AzureStor          * 3.6.0      2021-12-17 [1] CRAN (R 4.1.2)
   backports            1.4.1      2021-12-13 [1] CRAN (R 4.1.2)
   bit                  4.0.4      2020-08-04 [1] CRAN (R 4.1.1)
   bit64                4.0.5      2020-08-30 [1] CRAN (R 4.1.1)
   broom              * 0.7.11     2022-01-03 [1] CRAN (R 4.1.2)
   callr                3.7.0      2021-04-20 [1] CRAN (R 4.1.1)
   cellranger           1.1.0      2016-07-27 [1] CRAN (R 4.1.1)
   class                7.3-20     2022-01-13 [1] CRAN (R 4.1.2)
   cli                  3.1.0      2021-10-27 [1] CRAN (R 4.1.1)
   codetools            0.2-18     2020-11-04 [1] CRAN (R 4.1.1)
   colorspace           2.0-2      2021-06-24 [1] CRAN (R 4.1.1)
   coro                 1.0.2      2021-12-03 [1] CRAN (R 4.1.2)
   crayon               1.4.2      2021-10-29 [1] CRAN (R 4.1.1)
   curl                 4.3.2      2021-06-23 [1] CRAN (R 4.1.1)
   DBI                  1.1.2      2021-12-20 [1] CRAN (R 4.1.2)
   dbplyr               2.1.1      2021-04-06 [1] CRAN (R 4.1.1)
   dials              * 0.0.10     2021-09-10 [1] CRAN (R 4.1.1)
   DiceDesign           1.9        2021-02-13 [1] CRAN (R 4.1.1)
   digest               0.6.29     2021-12-01 [1] CRAN (R 4.1.2)
   doParallel         * 1.0.16     2020-10-16 [1] CRAN (R 4.1.1)
   dplyr              * 1.0.7      2021-06-18 [1] CRAN (R 4.1.1)
   ellipsis             0.3.2      2021-04-29 [1] CRAN (R 4.1.1)
   fansi                1.0.2      2022-01-14 [1] CRAN (R 4.1.1)
   finetune           * 0.1.0      2021-07-21 [1] CRAN (R 4.1.1)
   forcats            * 0.5.1      2021-01-27 [1] CRAN (R 4.1.1)
   foreach            * 1.5.1      2020-10-15 [1] CRAN (R 4.1.1)
   forecast           * 8.16       2022-01-10 [1] CRAN (R 4.1.2)
   fracdiff             1.5-1      2020-01-24 [1] CRAN (R 4.1.1)
   fs                   1.5.2      2021-12-08 [1] CRAN (R 4.1.2)
   furrr                0.2.3      2021-06-25 [1] CRAN (R 4.1.1)
   future               1.23.0     2021-10-31 [1] CRAN (R 4.1.1)
   future.apply         1.8.1      2021-08-10 [1] CRAN (R 4.1.1)
   generics             0.1.1      2021-10-25 [1] CRAN (R 4.1.1)
   ggplot2            * 3.3.5      2021-06-25 [1] CRAN (R 4.1.1)
   ggtext               0.1.1      2020-12-17 [1] CRAN (R 4.1.1)
   globals              0.14.0     2020-11-22 [1] CRAN (R 4.1.1)
   glue                 1.6.0      2021-12-17 [1] CRAN (R 4.1.2)
   gower                0.2.2      2020-06-23 [1] CRAN (R 4.1.1)
   GPfit                1.0-8      2019-02-08 [1] CRAN (R 4.1.1)
   greybox            * 1.0.2      2021-12-01 [1] CRAN (R 4.1.2)
   gridtext             0.1.4      2020-12-10 [1] CRAN (R 4.1.1)
   gtable               0.3.0      2019-03-25 [1] CRAN (R 4.1.1)
   hardhat              0.1.6      2021-07-14 [1] CRAN (R 4.1.1)
   haven                2.4.3      2021-08-04 [1] CRAN (R 4.1.1)
   here                 1.0.1      2020-12-13 [1] CRAN (R 4.1.1)
   hms                  1.1.1      2021-09-26 [1] CRAN (R 4.1.1)
   httr                 1.4.2      2020-07-20 [1] CRAN (R 4.1.1)
   imputeTS           * 3.2        2021-01-16 [1] CRAN (R 4.1.1)
   infer              * 1.0.0      2021-08-13 [1] CRAN (R 4.1.1)
   ipred                0.9-12     2021-09-15 [1] CRAN (R 4.1.1)
   iterators          * 1.0.13     2020-10-15 [1] CRAN (R 4.1.1)
   janitor              2.1.0      2021-01-05 [1] CRAN (R 4.1.1)
   jsonlite             1.7.3      2022-01-17 [1] CRAN (R 4.1.2)
   lamW                 2.1.1      2022-01-19 [1] CRAN (R 4.1.1)
   lattice              0.20-45    2021-09-22 [1] CRAN (R 4.1.1)
   lava                 1.6.10     2021-09-02 [1] CRAN (R 4.1.1)
   lhs                  1.1.3      2021-09-08 [1] CRAN (R 4.1.1)
   lifecycle            1.0.1      2021-09-24 [1] CRAN (R 4.1.1)
   listenv              0.8.0      2019-12-05 [1] CRAN (R 4.1.1)
   lmtest               0.9-39     2021-11-07 [1] CRAN (R 4.1.1)
   lubridate          * 1.8.0      2021-10-07 [1] CRAN (R 4.1.1)
   magrittr             2.0.1      2020-11-17 [1] CRAN (R 4.1.1)
   MAPA               * 2.0.4      2018-01-05 [1] CRAN (R 4.1.1)
   MASS                 7.3-55     2022-01-13 [1] CRAN (R 4.1.2)
   Matrix               1.4-0      2021-12-08 [1] CRAN (R 4.1.2)
   modeldata          * 0.1.1      2021-07-14 [1] CRAN (R 4.1.1)
   modelr               0.1.8      2020-05-19 [1] CRAN (R 4.1.1)
   modeltime          * 1.1.1      2022-01-12 [1] CRAN (R 4.1.2)
   modeltime.ensemble * 1.0.0      2021-10-19 [1] CRAN (R 4.1.1)
   modeltime.gluonts  * 0.3.1      2021-11-07 [1] Github (business-science/modeltime.gluonts@bcd3e8f)
   modeltime.resample * 0.2.0      2021-03-14 [1] CRAN (R 4.1.1)
   munsell              0.5.0      2018-06-12 [1] CRAN (R 4.1.1)
   nlme                 3.1-155    2022-01-13 [1] CRAN (R 4.1.2)
   nloptr               1.2.2.3    2021-11-02 [1] CRAN (R 4.1.2)
   nnet                 7.3-17     2022-01-13 [1] CRAN (R 4.1.2)
   parallelly           1.30.0     2021-12-17 [1] CRAN (R 4.1.2)
   parsnip            * 0.1.7      2021-07-21 [1] CRAN (R 4.1.1)
   pillar               1.6.4      2021-10-18 [1] CRAN (R 4.1.1)
   pkgconfig            2.0.3      2019-09-22 [1] CRAN (R 4.1.1)
   plyr                 1.8.6      2020-03-03 [1] CRAN (R 4.1.1)
   png                  0.1-7      2013-12-03 [1] CRAN (R 4.1.1)
   pracma               2.3.6      2021-12-07 [1] CRAN (R 4.1.2)
   pROC                 1.18.0     2021-09-03 [1] CRAN (R 4.1.1)
   processx             3.5.2      2021-04-30 [1] CRAN (R 4.1.1)
   prodlim              2019.11.13 2019-11-17 [1] CRAN (R 4.1.1)
   ps                   1.6.0      2021-02-28 [1] CRAN (R 4.1.1)
   purrr              * 0.3.4      2020-04-17 [1] CRAN (R 4.1.1)
   quadprog             1.5-8      2019-11-20 [1] CRAN (R 4.1.1)
   quantmod             0.4.18     2020-12-09 [1] CRAN (R 4.1.1)
   R6                   2.5.1      2021-08-19 [1] CRAN (R 4.1.1)
   rappdirs             0.3.3      2021-01-31 [1] CRAN (R 4.1.1)
   RColorBrewer       * 1.1-2      2014-12-07 [1] CRAN (R 4.1.1)
   Rcpp                 1.0.8      2022-01-13 [1] CRAN (R 4.1.2)
 D RcppParallel         5.1.5      2022-01-05 [1] CRAN (R 4.1.2)
   readr              * 2.1.1      2021-11-30 [1] CRAN (R 4.1.1)
   readxl               1.3.1      2019-03-13 [1] CRAN (R 4.1.1)
   recipes            * 0.1.17     2021-09-27 [1] CRAN (R 4.1.1)
   reprex               2.0.1      2021-08-05 [1] CRAN (R 4.1.1)
   reticulate           1.23       2022-01-14 [1] CRAN (R 4.1.2)
   rlang                0.4.12     2021-10-18 [1] CRAN (R 4.1.1)
   rpart                4.1-15     2019-04-12 [1] CRAN (R 4.1.1)
   rprojroot            2.0.2      2020-11-15 [1] CRAN (R 4.1.1)
   rsample            * 0.1.1      2021-11-08 [1] CRAN (R 4.1.1)
   rstudioapi           0.13       2020-11-12 [1] CRAN (R 4.1.1)
   rvest                1.0.2      2021-10-16 [1] CRAN (R 4.1.1)
   scales             * 1.1.1      2020-05-11 [1] CRAN (R 4.1.1)
   sessioninfo          1.2.2      2021-12-06 [1] CRAN (R 4.1.2)
   smooth             * 3.1.4      2021-12-01 [1] CRAN (R 4.1.2)
   snakecase            0.11.0     2019-05-25 [1] CRAN (R 4.1.1)
   StanHeaders          2.21.0-7   2020-12-17 [1] CRAN (R 4.1.1)
   statmod              1.4.36     2021-05-10 [1] CRAN (R 4.1.1)
   stinepack            1.4        2018-07-30 [1] CRAN (R 4.1.1)
   stringi              1.7.6      2021-11-29 [1] CRAN (R 4.1.1)
   stringr            * 1.4.0      2019-02-10 [1] CRAN (R 4.1.1)
   sumots               0.1.0      2022-01-19 [1] Github (vidarsumo/sumots@6bd71a3)
   survival             3.2-13     2021-08-24 [1] CRAN (R 4.1.1)
   tabnet             * 0.3.0      2021-10-11 [1] CRAN (R 4.1.2)
   texreg               1.37.5     2020-06-18 [1] CRAN (R 4.1.1)
   tibble             * 3.1.6      2021-11-07 [1] CRAN (R 4.1.1)
   tidymodels         * 0.1.4      2021-10-01 [1] CRAN (R 4.1.1)
   tidyr              * 1.1.4      2021-09-27 [1] CRAN (R 4.1.1)
   tidyselect           1.1.1      2021-04-30 [1] CRAN (R 4.1.1)
   tidyverse          * 1.3.1      2021-04-15 [1] CRAN (R 4.1.1)
   timeDate             3043.102   2018-02-21 [1] CRAN (R 4.1.1)
   timetk             * 2.6.2      2021-11-16 [1] CRAN (R 4.1.2)
   torch              * 0.6.0      2021-10-07 [1] CRAN (R 4.1.2)
   treesnip           * 0.1.0.9000 2021-10-17 [1] Github (curso-r/treesnip@60aade5)
   tseries              0.10-49    2021-11-16 [1] CRAN (R 4.1.2)
   tsintermittent     * 1.9        2016-03-10 [1] CRAN (R 4.1.1)
   TTR                  0.24.3     2021-12-12 [1] CRAN (R 4.1.1)
   tune               * 0.1.6      2021-07-21 [1] CRAN (R 4.1.1)
   tzdb                 0.2.0      2021-10-27 [1] CRAN (R 4.1.1)
   urca                 1.3-0      2016-09-06 [1] CRAN (R 4.1.1)
   utf8                 1.2.2      2021-07-24 [1] CRAN (R 4.1.1)
   vctrs                0.3.8      2021-04-29 [1] CRAN (R 4.1.1)
   withr                2.4.3      2021-11-30 [1] CRAN (R 4.1.1)
   workflows          * 0.2.4      2021-10-12 [1] CRAN (R 4.1.1)
   workflowsets       * 0.1.0      2021-07-22 [1] CRAN (R 4.1.1)
   xml2                 1.3.3      2021-11-30 [1] CRAN (R 4.1.2)
   xts                  0.12.1     2020-09-09 [1] CRAN (R 4.1.1)
   yardstick          * 0.0.9      2021-11-22 [1] CRAN (R 4.1.1)
   zoo                  1.8-9      2021-03-09 [1] CRAN (R 4.1.1)

 [1] C:/Program Files/R/R-4.1.1/library

 D -- DLL MD5 mismatch, broken installation.

- Python configuration -------------------------------------------------------------------------------------------------------------------------------------------------------------------
 python:         C:/Miniconda/envs/r-gluonts/python.exe
 libpython:      C:/Miniconda/envs/r-gluonts/python37.dll
 pythonhome:     C:/Miniconda/envs/r-gluonts
 version:        3.7.1 | packaged by conda-forge | (default, Mar 13 2019, 13:32:59) [MSC v.1900 64 bit (AMD64)]
 Architecture:   64bit
 numpy:          C:/Miniconda/envs/r-gluonts/Lib/site-packages/numpy
 numpy_version:  1.21.4
 numpy:          C:\MINICO~1\envs\R-GLUO~1\lib\site-packages\numpy\__init__.p

 NOTE: Python version was forced by use_python function

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
vidarsumo commented 2 years ago

cuda_is_available() returns TRUE, and torch_tensor(1, device = "cuda") gives:

    torch_tensor
     1
    [ CUDAFloatType{1} ]
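`cuda_is_available()` only shows that a CUDA device is visible. To check that GPU kernels actually run, and run faster than on the CPU, a rough timing sketch can help. It assumes torch is installed with its CUDA backend; `time_matmul()` is a hypothetical helper and the matrix size is arbitrary:

```r
# Hypothetical helper: time `reps` n-by-n matrix multiplications on `device`.
time_matmul <- function(device, n = 2048, reps = 10) {
  a <- torch::torch_randn(n, n, device = device)
  b <- torch::torch_randn(n, n, device = device)
  # Warm-up run so one-off CUDA initialisation is not included in the timing.
  torch::torch_mm(a, b)$sum()$item()
  system.time({
    for (i in seq_len(reps)) {
      # $item() copies a scalar back to R, which waits for the queued GPU
      # kernels to finish, so we time the computation, not just the launch.
      torch::torch_mm(a, b)$sum()$item()
    }
  })[["elapsed"]]
}

# Only run when torch, its libtorch backend and a CUDA device are all present.
if (requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed() &&
    torch::cuda_is_available()) {
  cat("CPU  elapsed (s):", time_matmul("cpu"), "\n")
  cat("CUDA elapsed (s):", time_matmul("cuda"), "\n")
}
```

If the CUDA timing is not clearly lower than the CPU timing, the problem is likely in the torch/CUDA setup rather than in tabnet.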
vidarsumo commented 2 years ago

I tried a few things, including installing CUDA 11.1 (instead of 10.2) and cuDNN 8.1.1. I followed these instructions: https://docs.nvidia.com/cuda/archive/11.1.1/cuda-quick-start-guide/index.html.

I then installed the development version of torch and forced CUDA version 11.1:

    remotes::install_github("mlverse/torch")

    Sys.setenv(CUDA="11.1")

    library(torch)
    trying URL 'https://download.pytorch.org/libtorch/cu111/libtorch-win-shared-with-deps-1.9.1%2Bcu111.zip'
    Content type 'application/zip' length 3058035948 bytes (2916.4 MB)
    downloaded 2916.4 MB

    trying URL 'https://storage.googleapis.com/torch-lantern-builds/refs/heads/master/latest/Windows-gpu-111.zip'
    Content type 'application/zip' length 1780277 bytes (1.7 MB)
    downloaded 1.7 MB

I have an NVIDIA Tesla P100, and it doesn't seem to be used when I train a model such as TabNet.

When I run the code below, the CPU goes to ~90% but the GPU stays at 0-1%. It's the same speed as on my laptop.

    library(torch)
    library(tabnet)
    library(tidymodels)
    library(tidyverse)
    library(modeldata)

    data(credit_data)

    credit_data <- credit_data %>% drop_na()

    model_spec_tabnet_tune <- tabnet(
      mode   = "regression",
      epochs = 20
      ) %>%
      set_engine("torch", verbose = TRUE) %>%
      fit(Price ~ ., data = credit_data)

Does anyone have any idea why the GPU is not in use?

Note that the GPU is working fine for other algorithms like XGBoost.

cregouby commented 2 years ago

Hello @vidarsumo,

The parsnip tabnet() function currently does not pass any device= argument to tabnet_config(), so the device defaults to "auto" (https://github.com/mlverse/tabnet/blob/f4f815f43e017ab3b6169e730b74037397041033/R/parsnip.R#L241). With device = "auto", "cuda" is chosen when CUDA is available, so your code should use the GPU as soon as it is correctly set up.
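The "auto" rule described above can be mirrored in a few lines. This is a hypothetical sketch of the logic, not tabnet's own code; `resolve_device()` is an illustrative name, and the CUDA check is injectable so the logic can be tested without a GPU:

```r
# Hypothetical helper mirroring the described behaviour of device = "auto":
# use "cuda" when a CUDA device is available, otherwise fall back to "cpu".
# Any explicit value ("cpu", "cuda", ...) is passed through unchanged.
resolve_device <- function(device = "auto",
                           cuda_available = function() torch::cuda_is_available()) {
  if (identical(device, "auto")) {
    if (isTRUE(cuda_available())) "cuda" else "cpu"
  } else {
    device
  }
}
```

With a working torch install, calling `resolve_device()` with no arguments consults `torch::cuda_is_available()`, which is the same check reported as TRUE earlier in this thread.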

Could you confirm that the GPU is used as expected outside of a workflow(), by fitting your model with the native tabnet_fit(Price ~ ., data = credit_data, device = "cuda") function, where you can pass device = "cuda" explicitly?

If not, your infrastructure setup needs fixing; if so, I'll take a deeper look at the parsnip machinery.

Hope it helps,

vidarsumo commented 2 years ago

Hi @cregouby,

If I use the native tabnet_fit(), the GPU load stays constantly at ~2% and sometimes jumps to 20-30% for a fraction of a second, then drops back to ~2% or even 0%.

How can I find out whether something on my side needs to be fixed? I followed every step in the setup guide (https://cran.r-project.org/web/packages/torch/vignettes/installation.html), and cuda_is_available() returns TRUE.

cregouby commented 2 years ago

OK, so the GPU setup is correct and the GPU is used by the native tabnet_fit(). However:

  1. Your first issue is that batch_size is far too low. With such a high-end setup you should feed the GPU much more data, and be able to set batch_size to 500e3 or 5e6. For performance reasons, you should align virtual_batch_size accordingly (as in #79). The limit is the GPU RAM, so the rule of thumb is to keep increasing batch_size until you get a CUDA OOM (out-of-memory) error.
  2. You are hitting bug #78, which will take us some time to solve. It is the typical pattern of a GPU waiting for single-threaded data preprocessing, computing the torch tensors too fast, and then waiting again.
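The "increase until OOM" rule of thumb can be automated with a doubling search. This is a minimal sketch, assuming that a CUDA OOM surfaces in R as an error whose message contains "out of memory"; `find_max_batch_size()` and the `fit_once()` wrapper are hypothetical, and the 1/4 ratio for virtual_batch_size is only an example (it matches the 2^20 / 2^18 split tried later in this thread), chosen so that virtual_batch_size divides batch_size evenly:

```r
# Hypothetical helper: double the batch size until the fit fails with an
# out-of-memory error, then return the last size that worked (NA if none did).
# `try_fit` is any function that trains with the given batch size and errors
# on OOM.
find_max_batch_size <- function(try_fit, start = 1024, limit = 2^24) {
  best <- NA_real_
  bs <- start
  while (bs <= limit) {
    ok <- tryCatch({
      try_fit(bs)
      TRUE
    }, error = function(e) {
      # Only treat out-of-memory errors as "too big"; re-raise anything else.
      if (grepl("out of memory", conditionMessage(e), ignore.case = TRUE)) {
        FALSE
      } else {
        stop(e)
      }
    })
    if (!ok) break
    best <- bs
    bs <- bs * 2
  }
  best
}

# Example wrapper (hypothetical), one short epoch per candidate size:
# fit_once <- function(bs) tabnet::tabnet_fit(
#   Price ~ ., data = credit_data, epochs = 1, device = "cuda",
#   batch_size = bs, virtual_batch_size = bs / 4)
# find_max_batch_size(fit_once)
```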

After fixing point 1, you may see GPU activity even with the workflow tabnet() method.

Hope it helps

vidarsumo commented 2 years ago

OK, I set batch_size to 2^20 and virtual_batch_size to 2^18. The GPU went up to 100% for 1-2 seconds while the first epoch finished, then the load dropped to 0% for a few minutes, went back up to 100% for 1-2 seconds, and dropped to 0% again for a few minutes.

So the GPU is working for sure.
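Some quick arithmetic on these settings helps read that utilization pattern. This is a sketch, assuming ghost batch normalization splits each batch into about batch_size / virtual_batch_size chunks; `epoch_layout()` is a hypothetical helper, and the 3e6 row count is illustrative, not taken from this issue:

```r
# Illustrative arithmetic: optimizer steps (batches) per epoch, and how many
# ghost-batch-norm chunks a full batch is split into.
epoch_layout <- function(n_rows, batch_size, virtual_batch_size) {
  rows_in_full_batch <- min(n_rows, batch_size)
  c(
    batches_per_epoch = ceiling(n_rows / batch_size),
    virtual_batches_per_full_batch = ceiling(rows_in_full_batch / virtual_batch_size)
  )
}

epoch_layout(n_rows = 3e6, batch_size = 2^20, virtual_batch_size = 2^18)
# batches_per_epoch = 3, virtual_batches_per_full_batch = 4
```

When the data has fewer rows than batch_size, every epoch is a single batch, so the GPU does one short burst of work per epoch; any long idle stretch between bursts is then time spent outside the GPU, which fits the CPU-bound pattern described for #78.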

sametsoekel commented 2 years ago

Hi, could you please share a reproducible example of how you managed to use 100% of your GPU? I followed all the steps in this issue and updated batch_size and virtual_batch_size accordingly, but I'm still using only ~0-2% of my GPU.