mlverse / tft

R implementation of Temporal Fusion Transformers
https://mlverse.github.io/tft/

Minimum number of observations in relation to lookback #36

Closed vidarsumo closed 1 year ago

vidarsumo commented 2 years ago

How is the minimum number of observations determined?

Here I have 91 observations per ID with lookback = 52.

suppressPackageStartupMessages(library(tidymodels))
library(tft)
set.seed(1)
torch::torch_manual_seed(1)

# 2.0 Preparing data ----
data_tbl <- timetk::walmart_sales_weekly %>%
  select(id, Dept, Date, Weekly_Sales, IsHoliday) %>% 
  mutate(
    Dept = paste0("Dept_", Dept),
    IsHoliday = ifelse(IsHoliday, "yes", "no"))

date_filter <- "2011-10-28"

fit_data <- data_tbl %>% 
  filter(Date <= date_filter)

fit_data %>% 
  group_by(id) %>% 
  summarise(n = n_distinct(Date))

# 91 obs per id

# TFT
rec <- recipe(Weekly_Sales ~ ., data = fit_data) %>%
  timetk::step_timeseries_signature(Date) %>%
  step_zv(all_predictors()) %>% 
  step_normalize(all_numeric_predictors())

spec <- tft_dataset_spec(rec, fit_data) %>%
  spec_covariate_index(Date) %>%
  spec_covariate_key(id) %>%
  spec_covariate_known(starts_with("Date_"), IsHoliday) %>%
  spec_covariate_static(Dept) %>% 
  spec_time_splits(lookback = 52, horizon = 52) %>%
  prep()

tft_model <- temporal_fusion_transformer(spec)

Error in `slice_df()`:
! No group has enough observations to statisfy the requested `lookback`.

cregouby commented 1 year ago

Hello @vidarsumo

Because your tft_dataset_spec() includes known covariates, you must provide them for the model to train, so each group needs at least lookback + horizon observations. With lookback = 52 and horizon = 52 that is 104 observations, and your data has only 91 per id.

So in your example, if you set date_filter <- "2012-01-28", you will get n = 104 observations for each id and you will be able to instantiate the model.
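
For anyone who wants to check this before building the spec, here is a minimal base R sketch of the rule described above (each group needs at least lookback + horizon rows). The helper `check_min_obs()` is hypothetical and not part of tft. The toy data only mimics the weekly layout of the example, and counting rows per id is an assumption about how the package counts observations.

```r
# Hypothetical helper (not part of tft): flag groups that have fewer than
# lookback + horizon observations, following the rule stated in the reply.
check_min_obs <- function(data, key, lookback, horizon) {
  required <- lookback + horizon
  counts <- table(data[[key]])          # rows per group, sorted by group name
  data.frame(
    id = names(counts),
    n = as.integer(counts),
    required = required,
    enough = as.integer(counts) >= required
  )
}

# Toy data (assumption): weekly dates starting 2010-02-05, one short series
# with 91 rows and one long series with 104 rows.
weeks <- seq(as.Date("2010-02-05"), by = "week", length.out = 104)
toy <- rbind(
  data.frame(id = "1_1", Date = weeks[1:91]),
  data.frame(id = "1_3", Date = weeks)
)

res <- check_min_obs(toy, key = "id", lookback = 52, horizon = 52)
print(res)
```

On the real data, something like `check_min_obs(fit_data, "id", 52, 52)` should show which ids fall short before `temporal_fusion_transformer()` is called.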