tidyverts / fabletools

General fable features useful for extension packages
http://fabletools.tidyverts.org/
90 stars 31 forks source link

refit() drops lst_mdl class from refitted object #187

Closed: ghost closed this issue 4 years ago.

ghost commented 4 years ago

refit() returns the refitted model column as a plain list rather than an lst_mdl object. As a result, generics such as forecast() have no method for it and fail.

(Edited to use a better reprex that more closely reflects my actual use case.)

library(fable)
#> Loading required package: fabletools

lung_deaths_female <- as_tsibble(fdeaths)

partial_fit <- lung_deaths_female %>%
  tsibble::filter_index(. ~ "1978-12-31") %>%
  model(mdl = ARIMA(value))

full_fit <- partial_fit %>%
  refit(lung_deaths_female)

partial_fit %>%
  report
#> Series: value 
#> Model: ARIMA(0,0,0)(1,1,0)[12] 
#> 
#> Coefficients:
#>          sar1
#>       -0.5046
#> s.e.   0.1098
#> 
#> sigma^2 estimated as 9500:  log likelihood=-289.18
#> AIC=582.37   AICc=582.63   BIC=586.11

full_fit %>%
  report
#> Series: value 
#> Model: ARIMA(0,0,0)(1,1,0)[12] 
#> 
#> Coefficients:
#>          sar1
#>       -0.5046
#> s.e.   0.1098
#> 
#> sigma^2 estimated as 7713:  log likelihood=-355.42
#> AIC=712.84   AICc=712.9   BIC=714.93

class(partial_fit$mdl)
#> [1] "lst_mdl" "list"

class(full_fit$mdl)
#> [1] "list"

forecast(partial_fit)
#> # A fable: 24 x 4 [1M]
#> # Key:     .model [1]
#>    .model    index value .distribution
#>    <chr>     <mth> <dbl> <dist>       
#>  1 mdl    1979 Jan  829. N(829, 9500) 
#>  2 mdl    1979 Feb  756. N(756, 9500) 
#>  3 mdl    1979 Mar  700. N(700, 9500) 
#>  4 mdl    1979 Apr  595. N(595, 9500) 
#>  5 mdl    1979 May  516. N(516, 9500) 
#>  6 mdl    1979 Jun  419. N(419, 9500) 
#>  7 mdl    1979 Jul  421. N(421, 9500) 
#>  8 mdl    1979 Aug  355. N(355, 9500) 
#>  9 mdl    1979 Sep  387  N(387, 9500) 
#> 10 mdl    1979 Oct  407. N(407, 9500) 
#> # ... with 14 more rows

forecast(full_fit)
#> Error in UseMethod("forecast"): no applicable method for 'forecast' applied to an object of class "list"

Created on 2020-04-13 by the reprex package (v0.3.0)

Session info

``` r
sessionInfo()
#> R version 3.6.2 (2019-12-12)
#> Platform: x86_64-w64-mingw32/x64 (64-bit)
#> Running under: Windows 10 x64 (build 17763)
#>
#> Matrix products: default
#>
#> locale:
#> [1] LC_COLLATE=English_United States.1252
#> [2] LC_CTYPE=English_United States.1252
#> [3] LC_MONETARY=English_United States.1252
#> [4] LC_NUMERIC=C
#> [5] LC_TIME=English_United States.1252
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] fable_0.1.2 fabletools_0.1.2
#>
#> loaded via a namespace (and not attached):
#> [1] Rcpp_1.0.3 urca_1.3-0 pillar_1.4.3 compiler_3.6.2
#> [5] highr_0.8 tools_3.6.2 digest_0.6.23 lattice_0.20-38
#> [9] nlme_3.1-142 tsibble_0.8.6 lubridate_1.7.4 evaluate_0.14
#> [13] tibble_2.1.3 lifecycle_0.1.0 gtable_0.3.0 anytime_0.3.7
#> [17] pkgconfig_2.0.3 rlang_0.4.3 cli_2.0.1 yaml_2.2.0
#> [21] xfun_0.12 dplyr_0.8.3 stringr_1.4.0 knitr_1.27
#> [25] feasts_0.1.2 generics_0.0.2 vctrs_0.2.2 grid_3.6.2
#> [29] tidyselect_0.2.5 glue_1.3.1 R6_2.4.1 fansi_0.4.1
#> [33] rmarkdown_2.1 purrr_0.3.3 ggplot2_3.2.1 tidyr_1.0.2
#> [37] magrittr_1.5 ellipsis_0.3.0 scales_1.1.0 htmltools_0.4.0
#> [41] assertthat_0.2.1 colorspace_1.4-1 utf8_1.1.4 stringi_1.4.4
#> [45] lazyeval_0.2.2 munsell_0.5.0 crayon_1.3.4
```
mitchelloharawild commented 4 years ago

This should already be fixed in the development version.

library(fable)
#> Loading required package: fabletools

lung_deaths_female <- as_tsibble(fdeaths)

partial_fit <- lung_deaths_female %>%
  tsibble::filter_index(. ~ "1978-12-31") %>%
  model(mdl = ARIMA(value))

full_fit <- partial_fit %>%
  refit(lung_deaths_female)

partial_fit %>%
  report
#> Series: value 
#> Model: ARIMA(0,0,0)(1,1,0)[12] 
#> 
#> Coefficients:
#>          sar1
#>       -0.5058
#> s.e.   0.1086
#> 
#> sigma^2 estimated as 9299:  log likelihood=-294.67
#> AIC=593.34   AICc=593.6   BIC=597.12

full_fit %>%
  report
#> Series: value 
#> Model: ARIMA(0,0,0)(1,1,0)[12] 
#> 
#> Coefficients:
#>          sar1
#>       -0.5058
#> s.e.   0.1086
#> 
#> sigma^2 estimated as 7710:  log likelihood=-355.42
#> AIC=712.83   AICc=712.9   BIC=714.93

class(partial_fit$mdl)
#> [1] "lst_mdl"    "vctrs_vctr"

class(full_fit$mdl)
#> [1] "lst_mdl"    "vctrs_vctr"

forecast(partial_fit)
#> # A fable: 24 x 4 [1M]
#> # Key:     .model [1]
#>    .model    index        value .mean
#>    <chr>     <mth>       <dist> <dbl>
#>  1 mdl    1979 Feb N(755, 9299)  755.
#>  2 mdl    1979 Mar N(700, 9299)  700.
#>  3 mdl    1979 Apr N(595, 9299)  595.
#>  4 mdl    1979 May N(516, 9299)  516.
#>  5 mdl    1979 Jun N(419, 9299)  419.
#>  6 mdl    1979 Jul N(421, 9299)  421.
#>  7 mdl    1979 Aug N(355, 9299)  355.
#>  8 mdl    1979 Sep N(387, 9299)  387 
#>  9 mdl    1979 Oct N(407, 9299)  407.
#> 10 mdl    1979 Nov N(418, 9299)  418.
#> # … with 14 more rows

forecast(full_fit)
#> # A fable: 24 x 4 [1M]
#> # Key:     .model [1]
#>    .model    index        value .mean
#>    <chr>     <mth>       <dist> <dbl>
#>  1 mdl    1980 Jan N(808, 7710)  808.
#>  2 mdl    1980 Feb N(819, 7710)  819.
#>  3 mdl    1980 Mar N(732, 7710)  732.
#>  4 mdl    1980 Apr N(579, 7710)  579.
#>  5 mdl    1980 May N(504, 7710)  504.
#>  6 mdl    1980 Jun N(438, 7710)  438.
#>  7 mdl    1980 Jul N(418, 7710)  418.
#>  8 mdl    1980 Aug N(370, 7710)  370.
#>  9 mdl    1980 Sep N(390, 7710)  390.
#> 10 mdl    1980 Oct N(421, 7710)  421.
#> # … with 14 more rows

Created on 2020-04-14 by the reprex package (v0.3.0)