mlr-org / mlr3temporal

Forecasting for mlr3
https://mlr3temporal.mlr-org.com
GNU Lesser General Public License v3.0

Proposal to add "forecast.persistence" learner for seasonal persistence prediction #83

Open MarkusLeyser opened 3 months ago

MarkusLeyser commented 3 months ago

Description: I am interested in contributing to the mlr3temporal package and would like to propose the addition of a simple "forecast.persistence" learner. (This would be my first contribution to the mlr3 project. 🙂)

Motivation and Idea: I am using mlr3 for a benchmarking study on the prediction of photovoltaic power output. In the context of photovoltaics, seasonal persistence prediction is a common (naive) baseline for evaluating forecasting models. This approach uses observations from the latest seasonal cycle for prediction. In contrast, regular persistence prediction would forecast the last available value. For example, in my case, a seasonal persistence prediction would correspond to the photovoltaic power output from 24 hours ago. While this approach looks simple, it can be surprisingly hard to beat.
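As a rough illustration, the two baselines differ only in how far back they look. Regular persistence repeats the last value y_T. Seasonal persistence with period m uses ŷ_{T+h} = y_{T+h-m·k}, where k = ⌈h/m⌉ is the smallest number of whole cycles that reaches back into the observed data. Setting m = 1 recovers regular persistence. The following is a minimal base-R sketch, not part of the proposal; the function name `persistence_forecast` is hypothetical, and it assumes at least one full cycle of history.

```r
# Hypothetical helper: (seasonal) persistence forecast for a numeric vector.
# period = 1 gives regular persistence (repeat the last value);
# period = m repeats the most recent complete seasonal cycle.
persistence_forecast <- function(x, h, period = 1L) {
  stopifnot(is.numeric(x), h >= 1, period >= 1, length(x) >= period)
  n <- length(x)
  steps <- seq_len(h)
  # number of whole cycles needed to reach back into the observed data
  k <- ceiling(steps / period)
  x[n + steps - period * k]
}

# Example: toy series with period 4
x <- c(1, 2, 3, 4, 5, 6, 7, 8)
persistence_forecast(x, h = 6, period = 4)  # 5 6 7 8 5 6
persistence_forecast(x, h = 3)              # 8 8 8
```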

Both regular and seasonal persistence predictions are valuable baselines for many time series modeling applications. Therefore, I see a benefit in providing a convenient and simple "forecast.persistence" learner in mlr3temporal.

Implementation: There are two viable options for implementing this simple learner:

  1. Using the data.table package:

    • Pros: No additional dependencies; core functionality can be realized using shift().
    • Cons: Limited functionality.
  2. Using the forecast package:

    • Pros: The (seasonal) persistence prediction can be represented as a suitably specified ARIMA model: ARIMA(0,0,0)(0,1,0) with the seasonal period set to the cycle length, i.e. seasonal differencing of order 1 without AR or MA terms. It also allows for further functionality expansion, in particular with forecast::findfrequency(x), which determines the dominant frequency of a time series.
    • Cons: Requires a package that is not distributed with R itself. However, since mlr3temporal already uses forecast for its ARIMA learner (forecast.arima), I believe option 2 is still viable.
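For option 1, the in-sample (fitted) persistence values are simply the series lagged by one period. The sketch below mimics what `data.table::shift(x, n = period, type = "lag", fill = fill)` returns, written in base R so it runs without data.table. The name `lag_persistence` is hypothetical, and filling the first `period` positions with the training mean follows the concept code further down; other fill rules are equally possible.

```r
# Hypothetical base-R stand-in for data.table::shift(x, n = period, type = "lag", fill = fill):
# position t holds x[t - period]; the first `period` positions have no lagged value
# and receive `fill` (here the training mean, as in the concept code).
lag_persistence <- function(x, period = 1L, fill = mean(x)) {
  n <- length(x)
  if (period >= n) {
    # no observation has a value exactly `period` steps earlier
    return(rep(fill, n))
  }
  c(rep(fill, period), x[seq_len(n - period)])
}

x <- c(10, 20, 30, 40, 50, 60)
lag_persistence(x, period = 2)  # 35 35 10 20 30 40
```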

In both cases, it is necessary to define how the learner behaves when a complete period is not present in the training data, i.e., when the seasonal period is greater than the number of training observations. The existing forecast.arima learner does not cover this case, so it is not sufficient on its own.
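One possible way to handle this edge case, sketched under the same assumption as the concept code (missing history is filled with the training mean), is to left-pad the last cycle so that every forecast step still looks back exactly `period` observations. The function name `persistence_forecast_padded` is hypothetical; when at least one full cycle is available, it reduces to plain seasonal persistence.

```r
# Hypothetical sketch: seasonal persistence that tolerates period > length(x).
# Steps that would reach back before the first observation receive `fill`.
persistence_forecast_padded <- function(x, h, period = 1L, fill = mean(x)) {
  n <- length(x)
  # last cycle, left-padded with `fill` when fewer than `period` values exist
  last_cycle <- c(rep(fill, max(0L, period - n)), x[max(1L, n - period + 1L):n])
  # step s reuses the value `period` steps (or a multiple of it) earlier
  last_cycle[(seq_len(h) - 1L) %% period + 1L]
}

persistence_forecast_padded(c(3, 5), h = 5, period = 4)  # 4 4 3 5 4
```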

Request: I have attached a few lines of concept code for reference (modified versions of forecast.arima and forecast.average, respectively). I would appreciate it if you could provide feedback on whether such a contribution would be welcomed by the maintainers.

Thank you for your time.

LearnerRegrForecastPersistenceShift:

#' @title Lag-based Persistence Forecast Learner
#'
#' @name mlr_learners_regr.persistence_shift
#'
#' @description
#' Persistence model using data.table::shift()
#' Calls [data.table::shift] from package \CRANpkg{data.table} to return a (seasonal) persistence forecast.
#'
#' @templateVar id forecast.persistence_shift
#' @template learner
#'
#' @template seealso_learner
#' @export
#' @template example
LearnerRegrForecastPersistenceShift = R6::R6Class("LearnerRegrForecastPersistenceShift",
                                                  inherit = LearnerForecast,

                                                  public = list(

                                                    #' @description
                                                    #' Creates a new instance of this [R6][R6::R6Class] class.
                                                    initialize = function() {
                                                      ps = ps(
                                                        period = p_int(default = 1, lower = 1, tags = "train")
                                                      )

                                                      super$initialize(
                                                        id = "forecast.persistence_shift",
                                                        feature_types = "numeric",
                                                        predict_types = c("response"),
                                                        packages = "data.table",
                                                        param_set = ps,
                                                        properties = c("univariate"),
                                                        man = "mlr3temporal::mlr_learners_regr.persistence_shift"
                                                      )
                                                    },

                                                    #' @description
                                                    #' Returns forecasts after the last training instance.
                                                    #'
                                                    #' @param h (`numeric(1)`)\cr
                                                    #'   Number of steps ahead to forecast. Default is 10.
                                                    #'
                                                    #' @param task ([Task]).
                                                    #'
                                                    #' @param newdata ([data.frame()])\cr
                                                    #'   Ignored
                                                    #'
                                                    #' @return [Prediction].
                                                    forecast = function(h = 10, task, newdata = NULL) {
                                                      h = assert_int(h, lower = 1, coerce = TRUE)
                                                      indices = 1:h%%length(self$model$last_values)
                                                      indices[indices == 0] = length(self$model$last_values)
                                                      forecast = self$model$last_values[indices]
                                                      response = as.data.table(forecast)
                                                      colnames(response) = task$target_names

                                                      truth = copy(response)
                                                      truth[, colnames(truth) := 0]
                                                      p = PredictionForecast$new(task,
                                                                                 response = response, truth = truth,
                                                                                 row_ids = (self$date_span$end$row_id + 1):(self$date_span$end$row_id + h)
                                                      )
                                                    }
                                                  ),

                                                  private = list(
                                                    .train = function(task) {
                                                      span = range(task$date()[[task$date_col]])
                                                      self$date_span = list(
                                                        begin = list(time = span[1], row_id = task$row_ids[1]),
                                                        end = list(time = span[2], row_id = task$row_ids[task$nrow])
                                                      )
                                                      pv = self$param_set$get_values(tags = "train")
                                                      period = pv$period %??% self$param_set$default$period
                                                      x = task$data(cols = task$target_names)[[1L]]
                                                      mean_x = mean(x)
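                                                      # most recent seasonal cycle; if the period exceeds the training data,
                                                      # pad at the front with the training mean so that forecast step h still
                                                      # looks back exactly `period` observations (matching the lagged `fitted`)
                                                      last_values = c(
                                                        rep(mean_x, times = max(0, period - task$nrow)),
                                                        x[max(task$nrow - period + 1, 1):task$nrow]
                                                      )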
                                                      list(
                                                        "fill_value" = mean_x,
                                                        "period" = period,
                                                        "fitted" = data.table::shift(x, n = period, type = "lag", fill = mean_x),
                                                        "last_values" = last_values
                                                      )
                                                    },

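                                                    .predict = function(task) {
                                                      end_id = self$date_span$end$row_id
                                                      fitted = self$model$fitted
                                                      last_values = self$model$last_values
                                                      ids = task$row_ids
                                                      response = numeric(length(ids))
                                                      in_sample = ids <= end_id
                                                      # in-sample rows: lagged training values (assumes contiguous training row ids ending at end_id)
                                                      response[in_sample] = fitted[ids[in_sample] - end_id + length(fitted)]
                                                      # out-of-sample rows: repeat the last seasonal cycle instead of wrapping back into the fitted values
                                                      steps = ids[!in_sample] - end_id
                                                      response[!in_sample] = last_values[(steps - 1L) %% length(last_values) + 1L]
                                                      list(response = response)
                                                    }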
                                                  )
)

#' @include aaa.R
learners[["forecast.persistence_shift"]] = LearnerRegrForecastPersistenceShift

LearnerRegrForecastPersistenceArima:

#' @title Arima-based Persistence Forecast Learner
#'
#' @name mlr_learners_regr.persistence_arima
#'
#' @description
#' Persistence model as an ARIMA model
#' Calls [forecast::Arima] from package \CRANpkg{forecast} with suitable parameter values to return a (seasonal) persistence forecast.
#'
#' @templateVar id forecast.persistence_arima
#' @template learner
#'
#' @template seealso_learner
#' @export
#' @template example
LearnerRegrForecastPersistenceArima = R6::R6Class("LearnerRegrForecastPersistenceArima",
                                       inherit = LearnerForecast,

                                       public = list(

                                         #' @description
                                         #' Creates a new instance of this [R6][R6::R6Class] class.
                                         initialize = function() {
                                           ps = ps(
                                             period = p_int(default = 1, lower = 1, tags = "train")
                                           )

                                           super$initialize(
                                             id = "forecast.persistence_arima",
                                             feature_types = "numeric",
                                             predict_types = c("response", "se"),
                                             packages = "forecast",
                                             param_set = ps,
                                             properties = c("univariate"),
                                             man = "mlr3temporal::mlr_learners_regr.persistence_arima"
                                           )
                                         },

                                         #' @description
                                         #' Returns forecasts after the last training instance.
                                         #'
                                         #' @param h (`numeric(1)`)\cr
                                         #'   Number of steps ahead to forecast. Default is 10.
                                         #'
                                         #' @param task ([Task]).
                                         #'
                                         #' @param newdata ([data.frame()])\cr
                                         #'   New data to predict on.
                                         #'
                                         #' @return [Prediction].
                                         forecast = function(h = 10, task, newdata = NULL) {
                                           h = assert_int(h, lower = 1, coerce = TRUE)
                                           if (length(task$feature_names) > 0) {
                                             newdata = as.matrix(newdata)
                                             forecast = invoke(forecast::forecast, self$model, xreg = newdata)
                                           } else {
                                             forecast = invoke(forecast::forecast, self$model, h = h)
                                           }
                                           response = as.data.table(as.numeric(forecast$mean))
                                           colnames(response) = task$target_names

                                           se = as.data.table(as.numeric(
                                             ci_to_se(width = forecast$upper[, 1] - forecast$lower[, 1], level = forecast$level[1])
                                           ))
                                           colnames(se) = task$target_names

                                           truth = copy(response)
                                           truth[, colnames(truth) := 0]
                                           p = PredictionForecast$new(task,
                                                                      response = response, se = se, truth = truth,
                                                                      row_ids = (self$date_span$end$row_id + 1):(self$date_span$end$row_id + h)
                                           )
                                         }
                                       ),

                                       private = list(
                                         .train = function(task) {
                                           span = range(task$date()[[task$date_col]])
                                           self$date_span = list(
                                             begin = list(time = span[1], row_id = task$row_ids[1]),
                                             end = list(time = span[2], row_id = task$row_ids[task$nrow])
                                           )
                                           pv = self$param_set$get_values(tags = "train")
                                           seasonal = list(order = c(0L, 1L, 0L), period = pv$period)
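                                           # ARIMA(0,0,0)(0,1,0) with the given period and no mean reproduces seasonal persistence
                                           invoke(forecast::Arima,
                                                  y = task$data(
                                                    rows = task$row_ids,
                                                    cols = task$target_names
                                                  )[[1L]], # univariate series as a plain numeric vector
                                                  order = c(0L, 0L, 0L),
                                                  seasonal = seasonal,
                                                  include.mean = FALSE)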
                                         },

                                         .predict = function(task) {
                                           se = NULL
                                           fitted_ids = task$row_ids[task$row_ids <= self$date_span$end$row_id]
                                           predict_ids = setdiff(task$row_ids, fitted_ids)

                                           if (length(predict_ids) > 0) {
                                             if (length(task$feature_names) > 0) {
                                               newdata = as.matrix(task$data(cols = task$feature_names, rows = predict_ids))
                                               response_predict = invoke(forecast::forecast, self$model, xreg = newdata)
                                             } else {
                                               response_predict = invoke(forecast::forecast, self$model, h = length(predict_ids))
                                             }

                                             predict_mean = as.data.table(as.numeric(response_predict$mean))
                                             colnames(predict_mean) = task$target_names
                                             fitted_mean = self$fitted_values(fitted_ids)
                                             colnames(fitted_mean) = task$target_names
                                             response = rbind(fitted_mean, predict_mean)
                                             if (self$predict_type == "se") {
                                               predict_se = as.data.table(as.numeric(
                                                 ci_to_se(width = response_predict$upper[, 1] - response_predict$lower[, 1],
                                                          level = response_predict$level[1])
                                               ))
                                               colnames(predict_se) = task$target_names
                                               fitted_se = as.data.table(
                                                 sapply(task$target_names, function(x) rep(sqrt(self$model$sigma2), length(fitted_ids)), simplify = FALSE)
                                               )
                                               se = rbind(fitted_se, predict_se)
                                             }
                                           } else {
                                             response = self$fitted_values(fitted_ids)
                                             if (self$predict_type == "se") {
                                               se = as.data.table(
                                                 sapply(task$target_names, function(x) rep(sqrt(self$model$sigma2), length(fitted_ids)), simplify = FALSE)
                                               )
                                             }
                                           }

                                           list(response = response, se = se)
                                         }
                                       )
)

#' @include aaa.R
learners[["forecast.persistence_arima"]] = LearnerRegrForecastPersistenceArima

Example using the included airpassengers task:

library(mlr3verse)
library(mlr3temporal)
library(R6)
library(mlr3misc)
library(checkmate)
library(data.table)

source("aaa.R")
source("helper.R")
source("LearnerRegrForecastPersistenceArima.R")
source("LearnerRegrForecastPersistenceShift.R")
source("zzz.R")

task <- tsk("airpassengers")

l_arima <- LearnerRegrForecastPersistenceArima$new()
l_arima$param_set$set_values(period = 10)
l_arima$train(task, row_ids = 1:134)
l_arima$predict(task, row_ids = 135:144)
(forecast_arima <- l_arima$forecast(h = 40, task = task))

l_shift <- LearnerRegrForecastPersistenceShift$new()
l_shift$param_set$set_values(period = 10)
l_shift$train(task, row_ids = 1:134)
l_shift$predict(task, row_ids = 135:144)
(forecast_shift <- l_shift$forecast(h = 40, task = task))

library(ggplot2)
ggplot() +
  geom_line(aes(x = task$row_ids, y = task$data()$target)) +
  geom_line(data = as.data.table(forecast_arima)[,c(1,3)], aes(x = row_ids, y = target), linetype = "dashed") +
  labs(x = "row_ids", y = "target", caption = "dashed: prediction")
sebffischer commented 3 months ago

@m-muecke

m-muecke commented 3 months ago

@MarkusLeyser Cheers for the detailed issue! What you're describing is just a naive and a seasonal naive model, i.e. a random walk (ARIMA(0,1,0)) and its seasonal counterpart. These methods are already implemented natively in the forecast package and in its successor, the fable package, e.g. forecast::naive(x) and forecast::snaive(x).

Regarding mlr3temporal: the package is in an experimental stage and is currently not being actively developed, so there might be some bugs and it should be used with caution. There is some experimental development in a successor package (mlr3forecast), whose focus will be on ML forecasting instead. Depending on the progress, the aim is to integrate some native forecasting methods as well, but in that case I would add the implementations from the fable package and a few others instead.

In the meantime, you can achieve the same by using ARIMA(0,1,0) and ARIMA(0,0,0)(0,1,0) for naive and snaive, respectively. Lastly, you're welcome to open PRs; I just can't guarantee quick reviews.
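The equivalence can be checked without installing forecast by using base R's `stats::arima` as a stand-in for `forecast::Arima`. In this sketch the toy series, the period m = 12 and the horizon are made-up illustrative choices: the ARIMA(0,1,0) forecasts should repeat the last value, and the ARIMA(0,0,0)(0,1,0) forecasts with period m should repeat the last m values. A tolerance is used because the Kalman filter behind `arima` works in floating point.

```r
# Base R only: compare ARIMA-based forecasts with plain persistence.
set.seed(1)
m <- 12L                      # seasonal period (illustrative choice)
y <- 100 + 10 * sin(2 * pi * seq_len(60) / m) + rnorm(60)
h <- 18L

# Naive / random walk: ARIMA(0,1,0)
fit_naive <- arima(y, order = c(0L, 1L, 0L))
fc_naive <- as.numeric(predict(fit_naive, n.ahead = h)$pred)

# Seasonal naive: ARIMA(0,0,0)(0,1,0) with period m
fit_snaive <- arima(y, order = c(0L, 0L, 0L),
                    seasonal = list(order = c(0L, 1L, 0L), period = m))
fc_snaive <- as.numeric(predict(fit_snaive, n.ahead = h)$pred)

# Plain persistence for comparison
expected_naive <- rep(y[length(y)], h)
expected_snaive <- rep(tail(y, m), length.out = h)

all.equal(fc_naive, expected_naive, tolerance = 1e-6)
all.equal(fc_snaive, expected_snaive, tolerance = 1e-6)
```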

MarkusLeyser commented 3 months ago

Thank you for the quick response (although the package is not currently under active development, as you write). Good to hear that a successor package is planned.

Naive models are exactly what I am referring to above. Given that they can be created using forecast.arima, the benefit of the suggested implementation is quite limited in the first place. It lies mostly in its convenience. (Only some edge cases with limited training data remain unaccounted for with forecast.arima.)

In light of the future development of mlr3forecast, it might be more valuable to shift my focus to other areas.