business-science / modeltime.gluonts

GluonTS Deep Learning with Modeltime
https://business-science.github.io/modeltime.gluonts/

Error in nbeats() using canonical example #8

Closed: spsanderson closed this issue 4 years ago

spsanderson commented 4 years ago

df_grouped_tbl.zip

I am using the canonical example from the documentation for the nbeats model. My data is attached. Here is a glimpse of the data, which is similar to the data in the example:

df_grouped_tbl
# A tibble: 70,757 x 3
   id      date                 value
   <fct>   <dttm>               <dbl>
 1 Medical 2011-09-13 00:00:00 -2.34 
 2 Medical 2011-09-14 00:00:00 -2.47 
 3 Medical 2011-09-15 00:00:00 -2.47 
 4 Medical 2011-09-16 00:00:00 -2.47 
 5 Medical 2011-09-17 00:00:00 -2.47 
 6 Medical 2011-09-18 00:00:00 -2.47 
 7 Medical 2011-09-19 00:00:00 -2.47 
 8 Medical 2011-09-20 00:00:00 -2.47 
 9 Medical 2011-09-21 00:00:00 -1.83 
10 Medical 2011-09-22 00:00:00 -0.674
# ... with 70,747 more rows

Here is the code I am running, which is the same as the example:

forcast_horizon <-  30

new_data <- df_grouped_tbl %>%
    group_by(id) %>%
    future_frame(.date_var = date, .length_out = forcast_horizon) %>%
    ungroup()

model_fit_nbeats_ensemble <- nbeats(
    id                    = "id",
    freq                  = "D",
    prediction_length     = forcast_horizon,
    lookback_length       = c(forcast_horizon, 4*forcast_horizon),
    epochs                = 5,
    num_batches_per_epoch = 15,
    batch_size            = 1 
) %>%
    set_engine("gluonts_nbeats") %>%
    fit(value ~ date + id, df_grouped_tbl)

Here is the Python error that comes back:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  ValidationError: 1 validation error for NBEATSEstimatorModel
context_length
  value is not a valid integer (type=type_error.integer)

Detailed traceback: 
  File "C:\Users\Steve\AppData\Local\R-MINI~1\envs\R-GLUO~1\lib\site-packages\gluonts\core\component.py", line 418, in init_wrapper
    model = PydanticModel(**{**nmargs, **kwargs})
  File "pydantic\main.py", line 362, in pydantic.main.BaseModel.__init__
Timing stopped at: 0.11 0 0.11

Converting the value column to an integer (using n() %>% as.integer() instead of standardize_vec()) produces exactly the same error.

mdancho84 commented 4 years ago

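Yes, this is because you are using the standard N-BEATS engine ("gluonts_nbeats"), which expects a single integer lookback_length; it is passed to GluonTS as context_length, which is why the validation error complains that context_length is not a valid integer. Passing a vector of lookback lengths such as c(forcast_horizon, 4*forcast_horizon) requires Ensemble N-BEATS ("gluonts_nbeats_ensemble"). I've just updated modeltime.gluonts to respond with an error telling you to switch engines. I've also updated the example on the website to use "gluonts_nbeats_ensemble", which is correct.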

Here is working code.


library(modeltime.gluonts)
library(tidymodels)
library(tidyverse)
library(timetk)

data <- read_rds("~/Downloads/df_grouped_tbl 2")

data %>%
    # filter(id %in% unique(id)[1:3]) %>%
    plot_time_series(date, value, .facet_vars = id, .facet_ncol = 3)

forcast_horizon <-  30

new_data <- data %>%
    # filter(id %in% unique(id)[1:3]) %>%
    group_by(id) %>%
    future_frame(.date_var = date, .length_out = forcast_horizon) %>%
    ungroup()

model_fit_nbeats_ensemble <- nbeats(
    id                    = "id",
    freq                  = "D",
    prediction_length     = forcast_horizon,
    # A vector of lookback lengths is only valid with the ensemble engine below
    lookback_length       = c(forcast_horizon, 4*forcast_horizon),
    epochs                = 5,
    num_batches_per_epoch = 20,
    bagging_size          = 1 
) %>%
    # Ensemble engine: trains one N-BEATS model per lookback length
    set_engine("gluonts_nbeats_ensemble") %>%
    fit(value ~ date + id, data)

preds <- modeltime_table(
    model_fit_nbeats_ensemble
) %>%
    modeltime_forecast(new_data, actual_data = data, keep_data = TRUE)

preds %>%
    filter(id %in% unique(id)[1:6]) %>%
    group_by(id) %>%
    plot_modeltime_forecast(
        .facet_ncol = 3
    )

spsanderson commented 4 years ago

Working now.

Is there a way to invert standardize_vec() to get the predictions back on the original scale? Can we simply run standardize_inv_vec() on the predictions?

mdancho84 commented 4 years ago

The best way is to use a recipe, because the prepped recipe stores the transformation parameters (the mean and standard deviation). Use recipes::step_normalize(), which performs the centering and scaling.
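Whichever route is used, inverting the standardization is just x = z * sd + mean, which is what timetk::standardize_inv_vec(x, mean, sd) computes. Below is a minimal base-R sketch, assuming you kept the mean and standard deviation used at transformation time (standardize_vec() reports them in a message, and tidy() on a prepped step_normalize() step lists them). The data values, the preds_std vector, and the invert_standardization() helper are hypothetical, for illustration only.

```r
# Hypothetical raw series; in practice this is the original `value` column.
x <- c(10, 12, 15, 11, 14)

# Parameters captured at standardization time (center = mean, scale = sd).
x_mean <- mean(x)
x_sd   <- sd(x)

# Forward transform: center by the mean, then divide by the standard deviation.
z <- (x - x_mean) / x_sd

# Hypothetical helper: undo the transform by scaling back up, then re-adding the mean.
invert_standardization <- function(z, mean, sd) {
  z * sd + mean
}

# Round trip recovers the original series.
x_back <- invert_standardization(z, x_mean, x_sd)

# Hypothetical standardized forecasts, e.g. values from the .value column of preds.
preds_std  <- c(-0.5, 0.0, 0.8)
preds_orig <- invert_standardization(preds_std, x_mean, x_sd)
print(preds_orig)
```

The same inverse would apply to the .value, .conf_lo and .conf_hi columns returned by modeltime_forecast(). If each id was standardized separately, each id needs its own mean and sd.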