business-science / modeltime.gluonts

GluonTS Deep Learning with Modeltime
https://business-science.github.io/modeltime.gluonts/

Error using GPU with latest version of modeltime.gluonts with gluonts 0.8.0 and torch #36

Open vidarsumo opened 2 years ago


Note that these errors only occur when using the environment created below. There is no error when using the default r-gluonts environment; both gluonts_deepar and torch work with r-gluonts.

What I did: I updated mxnet to 1.7 following https://ts.gluon.ai/install.html, then installed the latest (dev) version of modeltime.gluonts. I did a fresh install that included PyTorch. I then created a new environment following https://business-science.github.io/modeltime.gluonts/articles/using-gpus.html:

reticulate::py_install(
    envname  = "my_gluonts_env_08",
    python_version = "3.7.1",
    packages = c(
        "mxnet-cu100",

        "gluonts==0.8.0",
        "pandas",
        "numpy",
        "pathlib"
    ),
    method = "conda",
    pip = TRUE
)

There is no error when training the model using the gluonts_deepar engine:

library(dplyr)
my_env <- reticulate::conda_list() %>%
  filter(name == "my_gluonts_env_08") %>%
  pull(python)

Sys.setenv(GLUONTS_PYTHON = my_env)

library(modeltime.gluonts)
library(tidyverse)
library(tidymodels)

m750_train <- m750[1:(306-24), ]
m750_test  <- m750 %>% filter(!date %in% m750_train$date)

model_fit_deepar <- deep_ar(
  id                    = "id",
  freq                  = "M",
  prediction_length     = 24,
  lookback_length       = 36,
  epochs                = 1, 
  num_batches_per_epoch = 500,
  learn_rate            = 0.001,
  num_layers            = 3,
  num_cells             = 80,
  dropout               = 0.10
) %>%
  set_engine("gluonts_deepar") %>%
  fit(value ~ date + id, m750_train)
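The split above holds out the final 24 rows as the test set, matching prediction_length = 24. As a sketch of that arithmetic only (not the reporter's code), the base-R snippet below uses a synthetic 306-row monthly series as a stand-in for m750. The row count of 306 is assumed from the `1:(306-24)` indexing; the start date and values are arbitrary.

```r
n_total <- 306   # rows in m750, as implied by m750[1:(306-24), ]
horizon <- 24    # hold-out size, matching prediction_length = 24

# Synthetic stand-in for m750: arbitrary start date, random-walk values
set.seed(42)
m750_sim <- data.frame(
  id    = "M750",
  date  = seq(as.Date("1990-01-01"), by = "month", length.out = n_total),
  value = cumsum(rnorm(n_total))
)

# Same logic as the original: first 282 rows train, remaining dates test
train <- m750_sim[1:(n_total - horizon), ]
test  <- m750_sim[!m750_sim$date %in% train$date, ]

cat("train rows:", nrow(train), "| test rows:", nrow(test), "\n")
cat("train ends:", format(max(train$date)),
    "| test starts:", format(min(test$date)), "\n")
```

Because the dates are unique and sorted, `tail(m750_sim, horizon)` selects the same 24 rows; the `%in%` filter simply mirrors the `filter(!date %in% ...)` call in the original.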

But when I try to make a forecast, I get an error:

model_fit_deepar %>% 
  modeltime_table() %>% 
  modeltime_forecast(new_data = m750_test, actual_data = m750)

  Error: Problem occurred during prediction. Error in py_iter_next(it, completed): MXNetError: vector<T> too long

Detailed traceback:
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\gluonts\mx\model\predictor.py", line 171, in predict
    num_samples=num_samples,
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\gluonts\model\forecast_generator.py", line 174, in __call__
    outputs = predict_to_numpy(prediction_net, inputs)
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\functools.py", line 824, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\gluonts\mx\model\predictor.py", line 50, in _
    return prediction_net(*inputs).asnumpy()
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\mxnet\gluon\block.py", line 548, in __call__
    out = self.forward(*args)
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\mxnet\gluon\block.py", line 925, in forward
    return self.hybrid_forward(ndarray, x, *args, **params)
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\gluonts\model\deepar\_network.py", line 1162, in hybrid_forward
    begin_states=state,
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\gluonts\model\deepar\_network.py", line 1093, in sampling_decoder
    new_samples = distr.sample(dtype=self.dtype)
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\gluonts\mx\distribution\transformed_distribution.py", line 135, in sample
    num_samples=num_samples, dtype=dtype
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\gluonts\mx\distribution\student_t.py", line 117, in sample
    num_samples=num_samples,
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\gluonts\mx\distribution\distribution.py", line 416, in _sample_multiple
    samples = sample_func(*args_expanded, **kwargs_expanded)
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\gluonts\mx\distribution\student_t.py", line 105, in s
    alpha=nu / 2.0, beta=2.0 / (nu * F.square(sigma)), dtype=dtype
  File "<string>", line 68, in sample_gamma
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\mxnet\_ctypes\ndarray.py", line 92, in _imperative_invoke
    ctypes.byref(out_stypes)))
  File "C:\Miniconda\envs\my_gluonts_env_08\lib\site-packages\mxnet\base.py", line 253, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))

Error: Problem with `filter()` input `..1`.
i Input `..1` is `.model_desc == "ACTUAL" | .key == "prediction"`.
x object '.key' not found
Run `rlang::last_error()` to see where the error occurred.
In addition: Warning message:
Unknown or uninitialised column: `.key`. 

The same happens when calibrating, the step needed to compute accuracy:

model_fit_deepar %>% 
  modeltime_table() %>% 
  modeltime_calibrate(m750_test)

  -- Model Calibration Failure Report ------------------------
# A tibble: 1 x 4
  .model_id .model   .model_desc .nested.col
      <int> <list>   <chr>       <lgl>      
1         1 <fit[+]> DEEPAR      NA         
All models failed Modeltime Calibration:
- Model 1: Failed Calibration.

Potential Solution: Use `modeltime_calibrate(quiet = FALSE)` AND Check the Error/Warning Messages for clues as to why your model(s) failed calibration.
-- End Model Calibration Failure Report --------------------

Error: All models failed Modeltime Calibration.

Then I get a different error when I switch to the torch engine:

model_fit_deepar <- deep_ar(
  id                    = "id",
  freq                  = "M",
  prediction_length     = 24,
  lookback_length       = 36,
  epochs                = 1, 
  num_batches_per_epoch = 500,
  learn_rate            = 0.001,
  num_layers            = 3,
  num_cells             = 80,
  dropout               = 0.10
) %>%
  set_engine("torch") %>%
  fit(value ~ date + id, m750_train)

  Error in py_get_attr_impl(x, name, silent) : 
  AttributeError: module 'gluonts' has no attribute 'torch'
Timing stopped at: 0 0 0.06
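One way to narrow this down is to check which Python modules the custom environment can actually import. The `py_install()` call above lists mxnet and gluonts but no PyTorch packages, so a plausible (unconfirmed) cause is that torch is simply missing from `my_gluonts_env_08`. The helper below is hypothetical, not part of modeltime.gluonts; it relies only on reticulate's `use_python()` and `py_module_available()`, and the module list (including `pytorch_lightning`) is an assumption about what the torch engine needs.

```r
# Hypothetical helper (not part of modeltime.gluonts): report which Python
# modules the interpreter pointed to by GLUONTS_PYTHON can import.
check_gluonts_modules <- function(
    python  = Sys.getenv("GLUONTS_PYTHON"),
    # Assumed set of modules relevant to the mxnet and torch engines
    modules = c("mxnet", "gluonts", "torch", "pytorch_lightning", "gluonts.torch")
) {
  # Bind reticulate to the custom environment; with required = TRUE this
  # errors if a different Python was already initialised in the session.
  reticulate::use_python(python, required = TRUE)
  # Named logical vector: TRUE if the module imports, FALSE otherwise
  vapply(modules, reticulate::py_module_available, logical(1))
}
```

Run `check_gluonts_modules()` in a fresh R session, after `Sys.setenv(GLUONTS_PYTHON = my_env)` but before `library(modeltime.gluonts)`, since loading the package may already initialise Python.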

Session info

sessioninfo::session_info()
- Session info ------------------------------------------------------------------------------------------------------------------------------------------------------------------
 setting  value                       
 version  R version 4.0.3 (2020-10-10)
 os       Windows Server x64          
 system   x86_64, mingw32             
 ui       RStudio                     
 language (EN)                        
 collate  English_United States.1252  
 ctype    English_United States.1252  
 tz       GMT                         
 date     2021-08-16                  

- Packages ----------------------------------------------------------------------------------------------------------------------------------------------------------------------
 ! package           * version    date       lib source                                             
   assertthat          0.2.1      2019-03-21 [1] CRAN (R 4.0.3)                                     
   backports           1.2.1      2020-12-09 [1] CRAN (R 4.0.3)                                     
   broom             * 0.7.9      2021-07-27 [1] CRAN (R 4.0.5)                                     
   cellranger          1.1.0      2016-07-27 [1] CRAN (R 4.0.3)                                     
   class               7.3-19     2021-05-03 [1] CRAN (R 4.0.5)                                     
   cli                 3.0.1      2021-07-17 [1] CRAN (R 4.0.5)                                     
   codetools           0.2-18     2020-11-04 [1] CRAN (R 4.0.3)                                     
   colorspace          2.0-2      2021-06-24 [1] CRAN (R 4.0.5)                                     
   crayon              1.4.1      2021-02-08 [1] CRAN (R 4.0.3)                                     
   DBI                 1.1.1      2021-01-15 [1] CRAN (R 4.0.5)                                     
   dbplyr              2.1.1      2021-04-06 [1] CRAN (R 4.0.5)                                     
   dials             * 0.0.9      2020-09-16 [1] CRAN (R 4.0.3)                                     
   DiceDesign          1.9        2021-02-13 [1] CRAN (R 4.0.4)                                     
   digest              0.6.27     2020-10-24 [1] CRAN (R 4.0.3)                                     
   dplyr             * 1.0.7      2021-06-18 [1] CRAN (R 4.0.5)                                     
   ellipsis            0.3.2      2021-04-29 [1] CRAN (R 4.0.5)                                     
   fansi               0.5.0      2021-05-25 [1] CRAN (R 4.0.5)                                     
   forcats           * 0.5.1      2021-01-27 [1] CRAN (R 4.0.3)                                     
   foreach             1.5.1      2020-10-15 [1] CRAN (R 4.0.3)                                     
   fs                  1.5.0      2020-07-31 [1] CRAN (R 4.0.3)                                     
   furrr               0.2.3      2021-06-25 [1] CRAN (R 4.0.5)                                     
   future              1.21.0     2020-12-10 [1] CRAN (R 4.0.3)                                     
   generics            0.1.0      2020-10-31 [1] CRAN (R 4.0.3)                                     
   ggplot2           * 3.3.5      2021-06-25 [1] CRAN (R 4.0.5)                                     
   globals             0.14.0     2020-11-22 [1] CRAN (R 4.0.3)                                     
   glue                1.4.2      2020-08-27 [1] CRAN (R 4.0.3)                                     
   gower               0.2.2      2020-06-23 [1] CRAN (R 4.0.3)                                     
   GPfit               1.0-8      2019-02-08 [1] CRAN (R 4.0.3)                                     
   gtable              0.3.0      2019-03-25 [1] CRAN (R 4.0.3)                                     
   hardhat             0.1.6      2021-07-14 [1] CRAN (R 4.0.5)                                     
   haven               2.4.3      2021-08-04 [1] CRAN (R 4.0.5)                                     
   hms                 1.1.0      2021-05-17 [1] CRAN (R 4.0.5)                                     
   httr                1.4.2      2020-07-20 [1] CRAN (R 4.0.3)                                     
   infer             * 0.5.4      2021-01-13 [1] CRAN (R 4.0.3)                                     
   ipred               0.9-11     2021-03-12 [1] CRAN (R 4.0.4)                                     
   iterators           1.0.13     2020-10-15 [1] CRAN (R 4.0.3)                                     
   jsonlite            1.7.2      2020-12-09 [1] CRAN (R 4.0.3)                                     
   lattice             0.20-44    2021-05-02 [1] CRAN (R 4.0.5)                                     
   lava                1.6.9      2021-03-11 [1] CRAN (R 4.0.4)                                     
   lhs                 1.1.1      2020-10-05 [1] CRAN (R 4.0.3)                                     
   lifecycle           1.0.0      2021-02-15 [1] CRAN (R 4.0.3)                                     
   listenv             0.8.0      2019-12-05 [1] CRAN (R 4.0.3)                                     
   lubridate           1.7.10     2021-02-26 [1] CRAN (R 4.0.4)                                     
   magrittr            2.0.1      2020-11-17 [1] CRAN (R 4.0.3)                                     
   MASS                7.3-54     2021-05-03 [1] CRAN (R 4.0.5)                                     
   Matrix              1.3-4      2021-06-01 [1] CRAN (R 4.0.5)                                     
   modeldata         * 0.1.1      2021-07-14 [1] CRAN (R 4.0.5)                                     
   modelr              0.1.8      2020-05-19 [1] CRAN (R 4.0.3)                                     
   modeltime         * 0.7.0      2021-07-16 [1] CRAN (R 4.0.5)                                     
   modeltime.gluonts * 0.3.1      2021-08-13 [1] Github (business-science/modeltime.gluonts@bcd3e8f)
   munsell             0.5.0      2018-06-12 [1] CRAN (R 4.0.3)                                     
   nnet                7.3-16     2021-05-03 [1] CRAN (R 4.0.5)                                     
   parallelly          1.27.0     2021-07-19 [1] CRAN (R 4.0.5)                                     
   parsnip           * 0.1.7      2021-07-21 [1] CRAN (R 4.0.5)                                     
   pillar              1.6.2      2021-07-29 [1] CRAN (R 4.0.5)                                     
   pkgconfig           2.0.3      2019-09-22 [1] CRAN (R 4.0.3)                                     
   plyr                1.8.6      2020-03-03 [1] CRAN (R 4.0.3)                                     
   png                 0.1-7      2013-12-03 [1] CRAN (R 4.0.3)                                     
   pROC                1.17.0.1   2021-01-13 [1] CRAN (R 4.0.3)                                     
   prodlim             2019.11.13 2019-11-17 [1] CRAN (R 4.0.3)                                     
   purrr             * 0.3.4      2020-04-17 [1] CRAN (R 4.0.3)                                     
   R6                  2.5.0      2020-10-28 [1] CRAN (R 4.0.3)                                     
   rappdirs            0.3.3      2021-01-31 [1] CRAN (R 4.0.3)                                     
   Rcpp                1.0.7      2021-07-07 [1] CRAN (R 4.0.5)                                     
 D RcppParallel        5.1.4      2021-05-04 [1] CRAN (R 4.0.5)                                     
   readr             * 2.0.0      2021-07-20 [1] CRAN (R 4.0.5)                                     
   readxl              1.3.1      2019-03-13 [1] CRAN (R 4.0.3)                                     
   recipes           * 0.1.16     2021-04-16 [1] CRAN (R 4.0.3)                                     
   reprex              2.0.1      2021-08-05 [1] CRAN (R 4.0.5)                                     
   reticulate          1.20       2021-05-03 [1] CRAN (R 4.0.5)                                     
   rlang               0.4.11     2021-04-30 [1] CRAN (R 4.0.5)                                     
   rpart               4.1-15     2019-04-12 [1] CRAN (R 4.0.3)                                     
   rsample           * 0.1.0      2021-05-08 [1] CRAN (R 4.0.3)                                     
   rstudioapi          0.13       2020-11-12 [1] CRAN (R 4.0.3)                                     
   rvest               1.0.1      2021-07-26 [1] CRAN (R 4.0.5)                                     
   scales            * 1.1.1      2020-05-11 [1] CRAN (R 4.0.3)                                     
   sessioninfo         1.1.1      2018-11-05 [1] CRAN (R 4.0.3)                                     
   StanHeaders         2.21.0-7   2020-12-17 [1] CRAN (R 4.0.3)                                     
   stringi             1.7.3      2021-07-16 [1] CRAN (R 4.0.5)                                     
   stringr           * 1.4.0      2019-02-10 [1] CRAN (R 4.0.3)                                     
   survival            3.2-11     2021-04-26 [1] CRAN (R 4.0.5)                                     
   tibble            * 3.1.3      2021-07-23 [1] CRAN (R 4.0.5)                                     
   tidymodels        * 0.1.3      2021-04-19 [1] CRAN (R 4.0.5)                                     
   tidyr             * 1.1.3      2021-03-03 [1] CRAN (R 4.0.4)                                     
   tidyselect          1.1.1      2021-04-30 [1] CRAN (R 4.0.5)                                     
   tidyverse         * 1.3.1      2021-04-15 [1] CRAN (R 4.0.3)                                     
   timeDate            3043.102   2018-02-21 [1] CRAN (R 4.0.3)                                     
   timetk              2.6.1      2021-01-18 [1] CRAN (R 4.0.3)                                     
   tinytex             0.33       2021-08-05 [1] CRAN (R 4.0.5)                                     
   tune              * 0.1.6      2021-07-21 [1] CRAN (R 4.0.5)                                     
   tzdb                0.1.2      2021-07-20 [1] CRAN (R 4.0.5)                                     
   utf8                1.2.2      2021-07-24 [1] CRAN (R 4.0.5)                                     
   vctrs               0.3.8      2021-04-29 [1] CRAN (R 4.0.5)                                     
   withr               2.4.2      2021-04-18 [1] CRAN (R 4.0.3)                                     
   workflows         * 0.2.3      2021-07-16 [1] CRAN (R 4.0.5)                                     
   workflowsets      * 0.1.0      2021-07-22 [1] CRAN (R 4.0.5)                                     
   xfun                0.25       2021-08-06 [1] CRAN (R 4.0.5)                                     
   xml2                1.3.2      2020-04-23 [1] CRAN (R 4.0.3)                                     
   xts                 0.12.1     2020-09-09 [1] CRAN (R 4.0.3)                                     
   yardstick         * 0.0.8      2021-03-28 [1] CRAN (R 4.0.5)                                     
   zoo                 1.8-9      2021-03-09 [1] CRAN (R 4.0.4)