business-science / modeltime.gluonts

GluonTS Deep Learning with Modeltime
https://business-science.github.io/modeltime.gluonts/

Unexpected behavior using a different dataset than the example #7

Closed: spsanderson closed this issue 3 years ago

spsanderson commented 3 years ago

data_tbl.zip

I am using the attached dataset to try out the package with the base example provided in the README.

Session Info:

> sessionInfo()
R version 4.0.2 (2020-06-22)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 18363)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252   
[3] LC_MONETARY=English_United States.1252 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] timetk_2.5.0            healthyR.data_1.0.0     forcats_0.5.0           stringr_1.4.0          
 [5] readr_1.4.0             tidyverse_1.3.0         yardstick_0.0.7         workflows_0.2.1        
 [9] tune_0.1.2              tidyr_1.1.2             tibble_3.0.4            rsample_0.0.8          
[13] recipes_0.1.15          purrr_0.3.4             parsnip_0.1.4           modeldata_0.1.0        
[17] infer_0.5.3             ggplot2_3.3.2           dplyr_1.0.2             dials_0.0.9            
[21] scales_1.1.1            broom_0.7.2             tidymodels_0.1.1        modeltime.gluonts_0.1.0
[25] modeltime_0.3.1        

loaded via a namespace (and not attached):
 [1] fs_1.5.0             xts_0.12.1           lubridate_1.7.9.2    DiceDesign_1.8-1     httr_1.4.2          
 [6] tools_4.0.2          backports_1.2.0      R6_2.5.0             rpart_4.1-15         DBI_1.1.0           
[11] colorspace_2.0-0     nnet_7.3-14          withr_2.3.0          tidyselect_1.1.0     compiler_4.0.2      
[16] cli_2.2.0            rvest_0.3.6          xml2_1.3.2           rappdirs_0.3.1       digest_0.6.27       
[21] StanHeaders_2.21.0-6 rmarkdown_2.5        pkgconfig_2.0.3      htmltools_0.5.0      parallelly_1.21.0   
[26] lhs_1.1.1            dbplyr_2.0.0         rlang_0.4.8          readxl_1.3.1         rstudioapi_0.13     
[31] generics_0.1.0       zoo_1.8-8            jsonlite_1.7.1       magrittr_2.0.1       Matrix_1.2-18       
[36] Rcpp_1.0.5           munsell_0.5.0        fansi_0.4.1          GPfit_1.0-8          reticulate_1.18     
[41] lifecycle_0.2.0      furrr_0.2.1          stringi_1.5.3        pROC_1.16.2          yaml_2.2.1          
[46] MASS_7.3-51.6        plyr_1.8.6           grid_4.0.2           parallel_4.0.2       listenv_0.8.0       
[51] crayon_1.3.4         lattice_0.20-41      haven_2.3.1          splines_4.0.2        hms_0.5.3           
[56] knitr_1.30           pillar_1.4.7         codetools_0.2-16     reprex_0.3.0         glue_1.4.2          
[61] packrat_0.5.0        evaluate_0.14        RcppParallel_5.0.2   modelr_0.1.8         vctrs_0.3.5         
[66] foreach_1.5.1        cellranger_1.1.0     gtable_0.3.0         future_1.20.1        assertthat_0.2.1    
[71] xfun_0.19            gower_0.2.2          prodlim_2019.11.13   class_7.3-17         survival_3.1-12     
[76] timeDate_3043.102    iterators_1.0.13     lava_1.6.8.1         globals_0.13.1       ellipsis_0.3.1      
[81] ipred_0.9-9  

I run the following:

library(modeltime.gluonts)
library(tidymodels)
library(tidyverse)
library(healthyR.data)
library(timetk)

df_tbl <- healthyR_data %>%
    filter(ip_op_flag == "I") %>%
    mutate(id = as.factor("IP")) %>%
    select(id, visit_end_date_time, length_of_stay) %>%
    arrange(visit_end_date_time) %>%
    filter_by_time(.date_var = visit_end_date_time, .start_date = "2012") %>%
    summarise_by_time(
        .date_var = visit_end_date_time
        , .by = "month"
        , value = round(mean(length_of_stay), 2)
    ) %>%
    mutate(id = "IP") %>%
    select(id, visit_end_date_time, value) %>%
    set_names("id", "date", "value")

dfsplits <- initial_time_split(df_tbl, .8)

# Fit a GluonTS DeepAR Model
model_fit_deepar <- deep_ar(
    id                    = "id",
    freq                  = "M",
    prediction_length     = 12,
    lookback_length       = 8,
    epochs                = 10, 
    num_batches_per_epoch = 50,
    learn_rate            = 0.001,
    num_layers            = 2,
    dropout               = 0.10
) %>%
    set_engine("gluonts_deepar") %>%
    fit(value ~ ., training(dfsplits))

# Forecast with 95% Confidence Interval
modeltime_table(
    model_fit_deepar
) %>%
    modeltime_calibrate(new_data = testing(dfsplits)) %>%
    modeltime_forecast(
        new_data      = testing(dfsplits),
        actual_data   = df_tbl,
        conf_interval = 0.95
    ) %>%
    plot_modeltime_forecast(.interactive = FALSE)

The result shows that the prediction does not extend to the end of the data as expected, and no confidence intervals are drawn.


Here is the resulting forecast data:

> # Forecast with 95% Confidence Interval
> modeltime_table(
+     model_fit_deepar
+ ) %>%
+     modeltime_calibrate(new_data = testing(dfsplits)) %>%
+     modeltime_forecast(
+         new_data      = testing(dfsplits),
+         actual_data   = df_tbl,
+         conf_interval = 0.95
+     ) %>% dput()
WARNING:root:You have set `num_workers` to a non zero value, however, currently multiprocessing is not supported on windows and therefore`num_workers will be set to 0.
WARNING:root:You have set `num_workers` to a non zero value, however, currently multiprocessing is not supported on windows and therefore`num_workers will be set to 0.
structure(list(.model_id = c(NA, NA, NA, NA, NA, NA, NA, NA, 
NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, 
NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, 
NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, 
NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, 
NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, 
NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, 
NA, NA, NA, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L), .model_desc = c("ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", "ACTUAL", 
"ACTUAL", "DEEPAR", "DEEPAR", "DEEPAR", "DEEPAR", "DEEPAR", "DEEPAR", 
"DEEPAR", "DEEPAR", "DEEPAR", "DEEPAR", "DEEPAR", "DEEPAR", "DEEPAR", 
"DEEPAR", "DEEPAR", "DEEPAR", "DEEPAR", "DEEPAR", "DEEPAR", "DEEPAR", 
"DEEPAR", "DEEPAR"), .key = structure(c(1L, 1L, 1L, 1L, 1L, 1L, 
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 
1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 
2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L), .Label = c("actual", 
"prediction"), class = "factor"), .index = structure(c(1325376000, 
1328054400, 1330560000, 1333238400, 1335830400, 1338508800, 1341100800, 
1343779200, 1346457600, 1349049600, 1351728000, 1354320000, 1356998400, 
1359676800, 1362096000, 1364774400, 1367366400, 1370044800, 1372636800, 
1375315200, 1377993600, 1380585600, 1383264000, 1385856000, 1388534400, 
1391212800, 1393632000, 1396310400, 1398902400, 1401580800, 1404172800, 
1406851200, 1409529600, 1412121600, 1414800000, 1417392000, 1420070400, 
1422748800, 1425168000, 1427846400, 1430438400, 1433116800, 1435708800, 
1438387200, 1441065600, 1443657600, 1446336000, 1448928000, 1451606400, 
1454284800, 1456790400, 1459468800, 1462060800, 1464739200, 1467331200, 
1470009600, 1472688000, 1475280000, 1477958400, 1480550400, 1483228800, 
1485907200, 1488326400, 1491004800, 1493596800, 1496275200, 1498867200, 
1501545600, 1504224000, 1506816000, 1509494400, 1512086400, 1514764800, 
1517443200, 1519862400, 1522540800, 1525132800, 1527811200, 1530403200, 
1533081600, 1535760000, 1538352000, 1541030400, 1543622400, 1546300800, 
1548979200, 1551398400, 1554076800, 1556668800, 1559347200, 1561939200, 
1564617600, 1567296000, 1569888000, 1572566400, 1575158400, 1577836800, 
1580515200, 1583020800, 1585699200, 1588291200, 1590969600, 1593561600, 
1596240000, 1598918400, 1601510400, 1604188800, 1548979200, 1551398400, 
1554076800, 1556668800, 1559347200, 1561939200, 1564617600, 1567296000, 
1569888000, 1572566400, 1575158400, 1577836800, 1580515200, 1583020800, 
1585699200, 1588291200, 1590969600, 1593561600, 1596240000, 1598918400, 
1601510400, 1604188800), tzone = "UTC", class = c("POSIXct", 
"POSIXt")), .value = c(5.82, 6.1, 5.57, 5.13, 5.38, 5.01, 5.43, 
5.16, 5.4, 5.94, 5.99, 5.75, 6.19, 6.36, 6.35, 6.09, 5.83, 5.89, 
6.21, 5.92, 6.03, 6.25, 6.3, 6.86, 6.5, 6.11, 6.38, 6.15, 5.8, 
6.06, 5.9, 6.28, 6.01, 6.32, 6.08, 6.35, 6.79, 7.44, 6.38, 5.98, 
6.12, 6.04, 5.9, 6.19, 6.15, 6.47, 6.15, 6.16, 6.56, 6.58, 6.6, 
6.2, 6.91, 6.25, 6.4, 6.9, 6.45, 6.98, 6.79, 7.07, 7.23, 6.42, 
6.99, 5.92, 6.1, 6.34, 6.1, 5.96, 5.98, 6.63, 6, 5.68, 6.49, 
7.12, 6.26, 6.58, 6.33, 5.98, 5.58, 6.28, 5.99, 6.33, 6.11, 6.04, 
6.36, 6.34, 6.8, 6.65, 6.06, 5.88, 6.09, 6.03, 6.3, 6.38, 6.14, 
6.81, 6.81, 6.48, 6.68, 7.6, 8.47, 6.32, 6.39, 6.03, 5.91, 6.09, 
5.59, 6.53306770324707, 6.10734748840332, 6.02545738220215, 6.08266544342041, 
6.00361347198486, 5.81131219863892, 5.9768648147583, 6.05223703384399, 
5.97466373443604, 6.11306142807007, 6.21186113357544, 6.56806468963623, 
NA, NA, NA, NA, NA, NA, NA, NA, NA, NA), .conf_lo = c(NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_), .conf_hi = c(NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, 
NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_, NA_real_
)), row.names = c(NA, -129L), class = c("tbl_df", "tbl", "data.frame"
))
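You can read the truncation straight off the dput output above: only the first part of the test window has non-missing prediction values. A small base R helper makes that count explicit. The function name forecast_coverage() is hypothetical and is not part of modeltime. It only assumes a data frame with the .key and .value columns that modeltime_forecast() returns. The toy table stands in for the real forecast.

```r
# Hypothetical helper (not a modeltime function):
# rows and non-missing .value entries per .key level
forecast_coverage <- function(fc) {
  data.frame(
    key    = levels(fc$.key),
    rows   = as.vector(table(fc$.key)),
    non_na = as.vector(tapply(!is.na(fc$.value), fc$.key, sum)),
    stringsAsFactors = FALSE
  )
}

# Toy stand-in for modeltime_forecast() output:
# 5 actuals, then 4 prediction rows of which 2 are NA
toy <- data.frame(
  .key   = factor(c(rep("actual", 5), rep("prediction", 4)),
                  levels = c("actual", "prediction")),
  .value = c(5.82, 6.10, 5.57, 5.13, 5.38, 6.53, 6.11, NA, NA)
)

forecast_coverage(toy)
```

Run on the forecast table above, the same helper would report 107 actual rows, all non-missing. It would also report 22 prediction rows, of which only 12 are non-missing.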
mdancho84 commented 3 years ago

Hi Steven,

Thanks for this.

Regarding your prediction, I see an inconsistency between your prediction_length and your testing(splits), which is defined by a proportion of 0.8.

The prediction length should equal the number of time stamps in your forecast data (test data). Count the time stamps and use that value as the prediction length; your forecast will then cover the full length.

Regarding the CI, I'd start by fixing the prediction length and go from there.
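The arithmetic behind this advice can be sketched in base R. The sketch assumes that initial_time_split() with a proportion of 0.8 keeps the first floor(n * 0.8) rows for training. Instead of loading healthyR_data, it rebuilds the monthly index of the reported data (January 2012 through November 2020).

```r
# Monthly time stamps matching the reported data: Jan 2012 through Nov 2020
dates   <- seq(as.Date("2012-01-01"), as.Date("2020-11-01"), by = "month")
n_total <- length(dates)

# Assumption: a proportion-based time split keeps the first
# floor(n * prop) rows for training and the rest for testing
prop    <- 0.8
n_train <- floor(n_total * prop)
n_test  <- n_total - n_train

# First time stamp the forecast has to cover
first_test_date <- dates[n_train + 1]

# prediction_length should equal the number of test time stamps
prediction_length <- n_test

# With prediction_length = 12, the remaining test months get no forecast (NA)
uncovered <- n_test - 12

cat("total:", n_total, "train:", n_train, "test:", n_test,
    "first test date:", format(first_test_date),
    "uncovered with 12:", uncovered, "\n")
```

This yields 107 time stamps, 85 training rows and 22 test rows starting on 2019-02-01. With prediction_length = 12, 10 test months are left without a forecast, which matches the 10 trailing NA values in the dput output above. In the original code, the equivalent is passing prediction_length = nrow(testing(dfsplits)) to deep_ar().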

spsanderson commented 3 years ago

OK, I think I know what you mean. I will revisit this soon on my PC.


spsanderson commented 3 years ago

OK, I've got it now. Is there a way to select the prediction length automatically from the split object? Maybe something like:

testing(dfsplits) %>% 
nrow()

Or maybe a little nicer:

nrow(testing(dfsplits))
mdancho84 commented 3 years ago

Yes, this is how I would do it for a single time series group. For multiple groups, you will need to select it based on the forecast horizon instead.
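With several series, nrow() on the test set counts every group's rows, so it overstates the horizon. One reading of "select based on a horizon" is to count the distinct time stamps instead. The base R sketch below illustrates this with a hypothetical test set of two series ("A" and "B") that share the same 22 monthly dates.

```r
# Hypothetical panel: two series sharing the same 22 monthly test dates
test_dates <- seq(as.Date("2019-02-01"), by = "month", length.out = 22)
test_tbl <- data.frame(
  id   = rep(c("A", "B"), each = length(test_dates)),
  date = rep(test_dates, times = 2)
)

# nrow() counts every group's rows, so it overstates the horizon
n_rows <- nrow(test_tbl)

# The horizon is the number of distinct time stamps in the test window
horizon <- length(unique(test_tbl$date))

# Sanity check: every group should cover the same number of test time stamps
per_group <- tapply(test_tbl$date, test_tbl$id, function(d) length(unique(d)))
stopifnot(all(per_group == horizon))

cat("rows:", n_rows, "horizon:", horizon, "\n")
```

Here nrow() returns 44 while the horizon is 22. The stopifnot() check guards the assumption that every group covers the same test window.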