tidyverts / fabletools

General fable features useful for extension packages
http://fabletools.tidyverts.org/

Inconsistent ability to generate prediction intervals for transformed variables via bootstrap #365

Closed: dcaseykc closed this issue 1 year ago

dcaseykc commented 1 year ago

I've been exploring how to get prediction intervals for transformed outcome variables using fable/fabletools (which is a delightful package!). Recently, I noticed that the dev branch includes code that appears to allow MinT reconciliation for certain classes of transformed outcome variables (e.g. log(Y)). However, other transforms (e.g. log(Y+1)) don't seem to work and instead return only the point forecast. Is this the expected behavior?

library('fable', quietly = TRUE)
library('tsibble', quietly = TRUE)
library('lubridate', quietly = TRUE)
library('dplyr', quietly = TRUE)
prison <- readr::read_csv("https://OTexts.com/fpp3/extrafiles/prison_population.csv") %>%
  mutate(Quarter = yearquarter(Date)) %>%
  select(-Date)  %>%
  as_tsibble(key = c(Gender, Legal, State, Indigenous),
             index = Quarter) %>%
  relocate(Quarter)

prison_gts <- prison %>%
  aggregate_key(Gender * Legal * State, Count = sum(Count)/1e3)

fit <- prison_gts %>%
  filter(year(Quarter) <= 2014) %>%
  model(baselog = ARIMA(log(Count)),
        baselogp1 = ARIMA(log(Count+1)))

rec = fit %>%
  reconcile(
    m1 = min_trace(baselog, method = "mint_shrink"),
    m2 = min_trace(baselogp1, method = "mint_shrink")
  )

fcast = rec %>% select(-baselog)  %>% forecast(h = 1, bootstrap = TRUE, times = 10)

fcast %>%
  filter(is_aggregated(State), is_aggregated(Gender),
         is_aggregated(Legal), Quarter == make_yearquarter(2015, 1))

#> # A fable: 3 x 7 [1Q]
#> # Key:     Gender, Legal, State, .model [3]
#>   Gender       Legal        State        .model    Quarter      Count .mean
#>   <chr*>       <chr*>       <chr*>       <chr>       <qtr>     <dist> <dbl>
#> 1 <aggregated> <aggregated> <aggregated> baselogp1 2015 Q1 sample[10]  35.3
#> 2 <aggregated> <aggregated> <aggregated> m1        2015 Q1 sample[10]  35.1
#> 3 <aggregated> <aggregated> <aggregated> m2        2015 Q1   35.01382  35.0

Created on 2022-09-19 by the reprex package (v2.0.1)

mitchelloharawild commented 1 year ago

Sorry, I missed this Stack Overflow post somehow. Thanks for your reprex, it is really helpful.

Transformed distributions are unlikely to be elliptical, which is currently a requirement for reconciling forecast distributions (see https://www.monash.edu/business/ebs/research/publications/ebs/wp26-2020.pdf), so you would need to use sample paths instead.
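To see why, here is a quick base-R simulation. It is an illustration only and does not use fable internals. It assumes a forecast that is Gaussian on the log(Y+1) scale and back-transforms the draws to the original scale. The result is right-skewed, so its mean and median differ. An elliptical distribution is symmetric, so this one cannot be elliptical.

```r
# Illustration only: a forecast that is Gaussian on the transformed scale,
# back-transformed by undoing log(Y + 1).
set.seed(365)
z <- rnorm(1e5, mean = 0, sd = 1)   # draws on the log(Y + 1) scale
y <- exp(z) - 1                     # draws on the original scale

# A symmetric (e.g. elliptical) distribution has mean equal to median;
# these back-transformed draws are right-skewed instead.
c(mean = mean(y), median = median(y))
```

The mean lands near exp(0.5) - 1 while the median stays near 0. This asymmetry is what rules out the closed-form distributional reconciliation here.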

As you point out, reconciliation of sample paths has been added in the dev version of this package. This approach should work for reconciling any forecast, regardless of the shape of its distribution.
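Here is a minimal base-R sketch of why sample paths sidestep the shape problem. It is not the {fabletools} implementation, and it makes several assumptions. It uses a hypothetical two-level hierarchy (Total = A + B) and skewed, made-up base sample paths. A fixed diagonal weight matrix `W` stands in for the shrinkage covariance estimate that `method = "mint_shrink"` would use. Each path is mapped with the MinT projection G = (S' W^-1 S)^-1 S' W^-1, and the reconciled paths are S G y. Because this is a linear map applied path by path, the distribution of the samples can be any shape.

```r
# Hypothetical hierarchy: rows are Total, A, B; columns are bottom series A, B.
S <- rbind(Total = c(1, 1),
           A     = c(1, 0),
           B     = c(0, 1))

# Hypothetical base-forecast sample paths (one column per path). The draws
# are skewed (exponentiated), and the total is forecast separately, so the
# base paths are not coherent.
set.seed(1)
n_paths    <- 10
bottom     <- exp(matrix(rnorm(2 * n_paths, mean = 2, sd = 0.5), nrow = 2))
total      <- exp(rnorm(n_paths, mean = 2.7, sd = 0.3))
base_paths <- rbind(total, bottom)
rownames(base_paths) <- rownames(S)

# Assumed error variances per series (stand-in for a shrinkage estimate).
W     <- diag(c(4, 1, 1))
W_inv <- solve(W)

# MinT mapping, applied to every sample path at once.
G <- solve(t(S) %*% W_inv %*% S) %*% t(S) %*% W_inv
reconciled_paths <- S %*% G %*% base_paths

# Every reconciled path is coherent: Total equals A + B.
all.equal(unname(reconciled_paths["Total", ]),
          unname(colSums(reconciled_paths[c("A", "B"), ])))
```

The reconciled paths can then be summarised (quantiles, means) exactly like any other sample distribution.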

There seems to be a bug in either the {distributional} or the {fabletools} package when producing sample distributions that are transformed multiple times (log(Count + 1) applies two transformations, whereas log(Count) applies one). This prevents the forecasting method from recognizing the result as a sample distribution.

mitchelloharawild commented 1 year ago

Reprex:

library(fabletools)
z <- as_tsibble(USAccDeaths) %>% 
  model(
    l = fable::SNAIVE(value),
    lp1 = fable::SNAIVE(log(value + 1))
  ) %>% 
  forecast(h=1, bootstrap = TRUE, times = 10)

str(z$value)
#> dist [1:2] 
#> $ :List of 1
#>  ..$ x: num [1:10] 8619 7720 7395 8071 7242 ...
#>  ..- attr(*, "class")= chr [1:2] "dist_sample" "dist_default"
#> $ : dist [1:1] 
#>  ..$ :List of 1
#>  .. ..$ x: num [1:10] 8154 8062 7334 8340 8381 ...
#>  .. ..- attr(*, "class")= chr [1:2] "dist_sample" "dist_default"
#> @ vars: chr "value"

Created on 2022-09-20 by the reprex package (v2.0.1)

The structure of these two objects should be the same.
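Here is one way to detect the problem programmatically. It is a sketch based only on the `str()` output above. The idea is to unclass the distribution vector and check whether any element is itself a `distribution` vector rather than a flat `dist_sample`. The helper `is_nested_dist()` is hypothetical. The mock objects below only imitate the printed structure, so the example runs without any packages.

```r
# Hypothetical helper: TRUE for elements that are themselves wrapped in a
# `distribution` vector (the nested case), FALSE for flat elements.
is_nested_dist <- function(d) {
  vapply(unclass(d), inherits, logical(1), what = "distribution")
}

# Mock objects imitating the str() output above.
flat   <- structure(list(x = c(8619, 7720, 7395)),
                    class = c("dist_sample", "dist_default"))
nested <- structure(list(structure(list(x = c(8154, 8062, 7334)),
                                   class = c("dist_sample", "dist_default"))),
                    class = c("distribution", "vctrs_vctr", "list"))
mock   <- structure(list(flat, nested),
                    class = c("distribution", "vctrs_vctr", "list"))

is_nested_dist(mock)
#> [1] FALSE  TRUE
```

Judging from the `str()` output, `is_nested_dist(z$value)` on the reprex above would presumably also flag the second element (the `log(value + 1)` model) before the fix.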

mitchelloharawild commented 1 year ago

Seems to be an issue in {distributional}. MRE:

library(distributional)
z <- dist_sample(list(rnorm(10)))
y <- z + 1 - 1
waldo::compare(z, y)
#> `class(old[[1]])`: "dist_sample"  "dist_default"       
#> `class(new[[1]])`: "distribution" "vctrs_vctr"   "list"
#> 
#> `names(old[[1]])` is a character vector ('x')
#> `names(new[[1]])` is absent
#> 
#> `old[[1]][[1]]` is a double vector (-0.150150416298483, 0.589858719451371, -1.30637657466461, -0.661723727293455, 1.54528081833613, ...)
#> `new[[1]][[1]]` is an S3 object of class <distribution/vctrs_vctr/list>, a list

Created on 2022-09-20 by the reprex package (v2.0.1)

mitchelloharawild commented 1 year ago

The {distributional} issue is fixed in this commit: https://github.com/mitchelloharawild/distributional/commit/2a8edf67828ed6da39c8deba0447ac4a1b03833a

The handling of distributions in {fabletools} still needs some work to accommodate this change.

mitchelloharawild commented 1 year ago

Should be okay now:

library('fable', quietly = TRUE)
library('tsibble', quietly = TRUE)
library('lubridate', quietly = TRUE)
library('dplyr', quietly = TRUE)
prison <- readr::read_csv("https://OTexts.com/fpp3/extrafiles/prison_population.csv") %>%
  mutate(Quarter = yearquarter(Date)) %>%
  select(-Date)  %>%
  as_tsibble(key = c(Gender, Legal, State, Indigenous),
             index = Quarter) %>%
  relocate(Quarter)

prison_gts <- prison %>%
  aggregate_key(Gender * Legal * State, Count = sum(Count)/1e3)

fit <- prison_gts %>%
  filter(year(Quarter) <= 2014) %>%
  model(baselog = ARIMA(log(Count)),
        baselogp1 = ARIMA(log(Count+1)))

rec = fit %>%
  reconcile(
    m1 = min_trace(baselog, method = "mint_shrink"),
    m2 = min_trace(baselogp1, method = "mint_shrink")
  )

fcast = rec %>% select(-baselog)  %>% forecast(h = 1, bootstrap = TRUE, times = 10)

fcast %>%
  filter(is_aggregated(State), is_aggregated(Gender),
         is_aggregated(Legal), Quarter == make_yearquarter(2015, 1))
#> # A fable: 3 x 7 [1Q]
#> # Key:     Gender, Legal, State, .model [3]
#>   Gender       Legal        State        .model    Quarter      Count .mean
#>   <chr*>       <chr*>       <chr*>       <chr>       <qtr>     <dist> <dbl>
#> 1 <aggregated> <aggregated> <aggregated> baselogp1 2015 Q1 sample[10]  35.3
#> 2 <aggregated> <aggregated> <aggregated> m1        2015 Q1 sample[10]  35.1
#> 3 <aggregated> <aggregated> <aggregated> m2        2015 Q1 sample[10]  35.0

Created on 2022-09-20 by the reprex package (v2.0.1)