tidyverts / fabletools

General fable features useful for extension packages
http://fabletools.tidyverts.org/

new_data argument in forecast not behaving properly #312

Closed: PascalKieslich closed this issue 2 years ago.

PascalKieslich commented 3 years ago

Thanks a lot for your great work on the fable and fabletools packages!

I encountered an issue when working with the forecast() function. Specifically, it seems to ignore the specific rows (time points) that are provided in the new_data argument.

If I run the following code, adapted from the examples in the documentation:

library(fable)
library(tsibble)
library(tsibbledata)
library(dplyr)
# library(tidyr)
aus_economy <- global_economy %>% 
  filter(Country == "Australia")
fit <- aus_economy %>% 
  model(lm = ARIMA(log(GDP) ~ Population))
future_aus <- new_data(aus_economy, n = 10) %>% 
  mutate(Population = last(aus_economy$Population))
fit %>% 
  forecast(new_data = future_aus)

I get the following output:

# A fable: 10 x 6 [1Y]
# Key:     Country, .model [1]
   Country   .model  Year              GDP   .mean Population
   <fct>     <chr>  <dbl>           <dist>   <dbl>      <dbl>
 1 Australia lm      2018 t(N(28, 0.0091)) 1.35e12   24598933
 2 Australia lm      2019  t(N(28, 0.024)) 1.41e12   24598933
 3 Australia lm      2020   t(N(28, 0.04)) 1.47e12   24598933
 4 Australia lm      2021  t(N(28, 0.053)) 1.53e12   24598933
 5 Australia lm      2022  t(N(28, 0.064)) 1.59e12   24598933
 6 Australia lm      2023  t(N(28, 0.073)) 1.64e12   24598933
 7 Australia lm      2024   t(N(28, 0.08)) 1.69e12   24598933
 8 Australia lm      2025  t(N(28, 0.085)) 1.73e12   24598933
 9 Australia lm      2026  t(N(28, 0.089)) 1.77e12   24598933
10 Australia lm      2027  t(N(28, 0.092)) 1.80e12   24598933

If I change it to

fit %>% 
  forecast(new_data = future_aus[1:5,])

I get the following output:

# A fable: 5 x 6 [1Y]
# Key:     Country, .model [1]
  Country   .model  Year              GDP   .mean Population
  <fct>     <chr>  <dbl>           <dist>   <dbl>      <dbl>
1 Australia lm      2018 t(N(28, 0.0091)) 1.35e12   24598933
2 Australia lm      2019  t(N(28, 0.024)) 1.41e12   24598933
3 Australia lm      2020   t(N(28, 0.04)) 1.47e12   24598933
4 Australia lm      2021  t(N(28, 0.053)) 1.53e12   24598933
5 Australia lm      2022  t(N(28, 0.064)) 1.59e12   24598933

So far, that makes sense.

However, if I now run the following code (essentially predicting the years 2023-2027):

fit %>% 
  forecast(new_data = future_aus[6:10,])

I get the same predictions as I got for the years 2018-2022 above, only labelled as 2023-2027:

# A fable: 5 x 6 [1Y]
# Key:     Country, .model [1]
  Country   .model  Year              GDP   .mean Population
  <fct>     <chr>  <dbl>           <dist>   <dbl>      <dbl>
1 Australia lm      2023 t(N(28, 0.0091)) 1.35e12   24598933
2 Australia lm      2024  t(N(28, 0.024)) 1.41e12   24598933
3 Australia lm      2025   t(N(28, 0.04)) 1.47e12   24598933
4 Australia lm      2026  t(N(28, 0.053)) 1.53e12   24598933
5 Australia lm      2027  t(N(28, 0.064)) 1.59e12   24598933

This does not seem right to me. Shouldn't the predictions match those for the years 2023-2027 in the first output above?

Thanks again for your work!

Best, Pascal

mitchelloharawild commented 2 years ago

Thanks for pointing this issue out. The intended behaviour here is to produce forecasts for the requested time points, as you expected. This is possible and already works for some models, such as TSLM(), but it is not currently implemented for ARIMA(). Instead, the forecast(<ARIMA>) method should raise an error to safeguard against this issue. That check was added in https://github.com/tidyverts/fable/commit/efaf92567440e918263b7548f4846ccb9b405496 and will be included in the next release.


library(fable)
#> Loading required package: fabletools
library(tsibble)
#> 
#> Attaching package: 'tsibble'
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, union
library(tsibbledata)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
# library(tidyr)
aus_economy <- global_economy %>% 
  filter(Country == "Australia")
fit <- aus_economy %>% 
  model(lm = ARIMA(log(GDP) ~ Population))
#> Warning in sqrt(diag(best$var.coef)): NaNs produced
future_aus <- new_data(aus_economy, n = 10) %>% 
  mutate(Population = last(aus_economy$Population))
fit %>% 
  forecast(new_data = future_aus)
#> # A fable: 10 x 6 [1Y]
#> # Key:     Country, .model [1]
#>    Country   .model  Year              GDP   .mean Population
#>    <fct>     <chr>  <dbl>           <dist>   <dbl>      <dbl>
#>  1 Australia lm      2018 t(N(28, 0.0091)) 1.35e12   24598933
#>  2 Australia lm      2019  t(N(28, 0.024)) 1.41e12   24598933
#>  3 Australia lm      2020   t(N(28, 0.04)) 1.47e12   24598933
#>  4 Australia lm      2021  t(N(28, 0.053)) 1.53e12   24598933
#>  5 Australia lm      2022  t(N(28, 0.064)) 1.59e12   24598933
#>  6 Australia lm      2023  t(N(28, 0.073)) 1.64e12   24598933
#>  7 Australia lm      2024   t(N(28, 0.08)) 1.69e12   24598933
#>  8 Australia lm      2025  t(N(28, 0.085)) 1.73e12   24598933
#>  9 Australia lm      2026  t(N(28, 0.089)) 1.77e12   24598933
#> 10 Australia lm      2027  t(N(28, 0.092)) 1.80e12   24598933

fit %>% 
  forecast(new_data = future_aus[6:10,])
#> Error: Problem with `mutate()` column `lm`.
#> ℹ `lm = (function (object, ...) ...`.
#> x Forecasts from an ARIMA model must start one step beyond the end of the trained data.

Created on 2021-07-26 by the reprex package (v2.0.0)
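Until ARIMA() supports forecasting from an arbitrary starting point, one workaround consistent with the explanation above is to forecast from the first period after the training data and then subset the result. The sketch below is not from the original thread: it assumes the same packages and data as the reprex, and the model name `arima` and the objects `fc_full` and `fc_late` are illustrative names only. Because ARIMA forecasts for 2023-2027 build on the forecasts for 2018-2022, the full ten-year horizon is produced first and the later years are then kept with dplyr's filter().

```r
library(fable)
library(tsibble)
library(tsibbledata)
library(dplyr)

# Same data preparation as in the issue
aus_economy <- global_economy %>%
  filter(Country == "Australia")
future_aus <- new_data(aus_economy, n = 10) %>%
  mutate(Population = last(aus_economy$Population))

fit <- aus_economy %>%
  model(arima = ARIMA(log(GDP) ~ Population))

# ARIMA forecasts must start one step beyond the training data (2018),
# so forecast the full 10-year horizon first ...
fc_full <- fit %>%
  forecast(new_data = future_aus)

# ... then keep only the years of interest (2023-2027)
fc_late <- fc_full %>%
  filter(Year >= 2023)
fc_late
```

The rows in `fc_late` are the same as rows 6-10 of the first output in the issue. For models such as TSLM(), which the maintainer notes already handle the requested time points, the slice future_aus[6:10, ] can instead be passed to forecast() directly.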

PascalKieslich commented 2 years ago

Thanks for the clarification and the corresponding change in the package!