Closed spsanderson closed 2 years ago
Make a helper function for this that is also exported for use:
#' Simulate future paths from a fitted ARIMA model
#'
#' Draws `.num_sims` simulated paths of length `.horizon` from `.Arima_Model`
#' with `stats::simulate()` and returns them in several shapes (list of data
#' frames, per-time-point median, wide `ts` object, long tibble).
#'
#' @param .data Not used by the body; kept so existing callers that pass it
#'   keep working.
#' @param .horizon Number of periods to simulate; coerced with `as.integer()`.
#' @param .Arima_Model A fitted model object that has a `simulate()` method.
#' @param .num_sims Number of simulated paths to draw. Defaults to 25.
#'
#' @return A list with one element, `data`, itself a list holding:
#'   `s` (list of per-simulation data frames with columns `x`, `y`, `n`),
#'   `s1` (median of `y` per `x`, column `p50`),
#'   `sim_output` (wide `ts` object, one column per simulation) and
#'   `s_tbl` (long tibble with `y` renamed to `value` and a per-simulation
#'   row `id`). The list carries a `model_method` attribute built from the
#'   `order` and `seasonal` entries of the model call.
#' @export
ts_arima_sim <- function(
    .data, .horizon = 12, .Arima_Model = NULL, .num_sims = 25
){

  # Checks ----
  if (is.null(.Arima_Model)) {
    stop(call. = FALSE, "You must provide a fitted model to '.Arima_Model'.")
  }
  if (!is.numeric(.horizon) || length(.horizon) != 1L ||
      is.na(.horizon) || .horizon < 1) {
    stop(call. = FALSE, "The value of the '.horizon' argument is not valid")
  }
  if (!is.numeric(.num_sims) || length(.num_sims) != 1L ||
      is.na(.num_sims) || .num_sims < 1) {
    stop(call. = FALSE, "The value of the '.num_sims' argument is not valid")
  }

  # Tidyeval ----
  horizon  <- as.integer(.horizon)
  num_sims <- as.integer(.num_sims)
  model    <- .Arima_Model

  # Simulate ----
  # One data frame per simulated path, tagged with its simulation label.
  s <- lapply(seq_len(num_sims), function(i) {
    sim <- stats::simulate(object = model, nsim = horizon)
    sim_df <- base::data.frame(
      x = base::as.numeric(stats::time(sim)),
      y = base::as.numeric(sim)
    )
    sim_df$n <- base::paste0("sim_", i)
    sim_df
  })

  # All simulations stacked; reused by the summaries below.
  s_long <- dplyr::bind_rows(s)

  # Median path across simulations at each time point
  s1 <- s_long %>%
    dplyr::group_by(x) %>%
    dplyr::summarise(p50 = stats::median(y))

  # A single reference draw supplies the start/frequency of the simulated
  # horizon, instead of simulating once for each of the two attributes.
  ref_sim <- stats::simulate(object = model, nsim = horizon)

  sim_output <- s_long %>%
    tidyr::pivot_wider(names_from = n, values_from = y) %>%
    dplyr::select(-x) %>%
    stats::ts(
      start     = stats::start(ref_sim),
      frequency = stats::frequency(ref_sim)
    )

  # Make s into a tibble
  s_tbl <- purrr::map_dfr(s, tibble::as_tibble) %>%
    dplyr::group_by(n) %>%
    dplyr::mutate(id = dplyr::row_number(n)) %>%
    dplyr::ungroup() %>%
    dplyr::rename(value = y)

  # Return ----
  output <- list(
    data = list(
      s          = s,
      s1         = s1,
      sim_output = sim_output,
      s_tbl      = s_tbl
    )
  )

  attr(output, "model_method") <- paste0(
    "ARIMA ",
    deparse(model$call$order),
    " ",
    deparse(model$call$seasonal)
  )

  return(output)
}
# Usage sketch: run the simulation helper, then build a descriptive label for
# a user-specified (list) model from the attributes of the result.
# NOTE(review): `data`, `nao` and `.model` are not defined in this snippet, and
# the `.model` / `.sd` attributes read below are assumed to be set by whatever
# produced `ts_a` -- confirm before using this outside a function body.
ts_a <- ts_arima_sim(.data = data, .horizon = 12, .Arima_Model = nao)

# inherits() gives a single TRUE/FALSE even for objects with several classes,
# whereas comparing class() with == would give a vector condition.
if (inherits(.model, "list")) {
  atb <- attributes(ts_a)
  paste0(
    "User generated ARIMA(",
    toString(atb$.model$order),
    ") AR: ",
    toString(atb$.model$ar),
    " MA: ",
    toString(atb$.model$ma),
    " SD: ",
    round(atb$.sd, 2)
  )
}
Use the above function inside ts_forecast_simulator()
#' Simulate forecast paths from a fitted model and plot them
#'
#' Validates the inputs, simulates future paths from `.model`, attaches future
#' time stamps derived from the index of `.data`, and builds a ggplot and a
#' plotly chart of the actual series, the simulated paths and their median.
#'
#' @param .model A fitted model whose class includes one of `"ARIMA"`,
#'   `"ets"`, `"nnetar"`, `"Arima"` or `"list"`. Objects inheriting from
#'   `"forecast_ARIMA"` are simulated through `ts_arima_sim()`; all others are
#'   simulated here with `stats::simulate()`.
#' @param .horizon Whole number >= 1; number of periods to simulate.
#' @param .iterations Number >= 1; number of paths simulated in the
#'   non-`forecast_ARIMA` branch, and the count shown in the ggplot title.
#' @param .sim_color Colour used for the simulated paths.
#' @param .alpha Opacity of the simulated paths, between 0 and 1.
#' @param .data A data.frame/tibble whose time index is used to build the
#'   future time stamps.
#'
#' @return A list with elements `plotly_plot`, `ggplot`, `forecast_sim`,
#'   `forecast_sim_tbl`, `time_series`, `input_data` and `sim_ts_tbl`.
ts_forecast_simulator <- function(.model,
                                  .horizon = 4,
                                  .iterations = 25,
                                  .sim_color = "steelblue",
                                  .alpha = 0.328,
                                  .data) {

  # Setting variables
  x <- y <- s <- s1 <- sim_output <- p <- output <- NULL

  # Checks ----
  if (!inherits(.model, c("ARIMA", "ets", "nnetar", "Arima", "list"))) {
    stop("The .model argument is not valid")
  }

  if (!is.numeric(.alpha) || length(.alpha) != 1L ||
      .alpha < 0 || .alpha > 1) {
    stop("The value of the '.alpha' argument is invalid")
  }

  # A count below 1 would make 1:.iterations run backwards (1, 0).
  if (!is.numeric(.iterations) || length(.iterations) != 1L ||
      is.na(.iterations) || .iterations < 1) {
    stop("The value of the '.iterations' argument is not valid")
  }

  if (!is.numeric(.horizon)) {
    stop("The value of the '.horizon' argument is not valid")
  } else if (.horizon %% 1 != 0) {
    stop("The '.horizon' argument is not integer")
  } else if (.horizon < 1) {
    stop("The value of the '.horizon' argument is not valid")
  }

  if (!is.data.frame(.data)) {
    stop(call. = FALSE, "You must provide a data.frame/tibble to this function.")
  }

  # Data ----
  data_tbl <- .data
  n_iter   <- as.integer(.iterations)

  # Get index of provided data
  data_ts_index <- timetk::tk_index(data = data_tbl)

  # Make future tbl: one row per simulated period, keyed by `id`
  future_tbl <- timetk::tk_make_future_timeseries(
    idx = data_ts_index
    , length_out = .horizon
  ) %>%
    tibble::as_tibble() %>%
    dplyr::mutate(id = dplyr::row_number())

  is_forecast_arima <- inherits(.model, "forecast_ARIMA")

  # Manipulation ----
  # Simulation
  if (is_forecast_arima) {
    # `.iterations` is not forwarded: only the three arguments below are
    # passed, so the number of paths is whatever ts_arima_sim() produces.
    ts_a <- ts_arima_sim(
      .data = data_tbl,
      .horizon = .horizon,
      .Arima_Model = .model
    )

    # Make a model time series tibble ----
    model_ts_tbl <- timetk::tk_tbl(data_tbl, timetk_idx = TRUE, silent = TRUE)
    names(model_ts_tbl)[1] <- "index"

    # Pull every piece the return value and the plots need
    s          <- ts_a$data$s
    s1         <- ts_a$data$s1
    sim_output <- ts_a$data$sim_output
    s_tbl      <- ts_a$data$s_tbl

    # Join future_tbl to ts_a. Both sides have a `value` column, so the join
    # suffixes them: value.x is the simulated value, value.y the future date.
    s_joined_tbl <- s_tbl %>%
      dplyr::left_join(future_tbl, by = c("id" = "id")) %>%
      dplyr::select(value.y, dplyr::everything()) %>%
      dplyr::rename(index = value.y) %>%
      dplyr::rename(y = value.x)

    s1_tbl <- s1 %>%
      dplyr::mutate(id = dplyr::row_number()) %>%
      dplyr::left_join(future_tbl, by = c("id" = "id")) %>%
      dplyr::select(value, dplyr::everything()) %>%
      dplyr::rename(index = value)
  } else {
    s <- lapply(seq_len(n_iter), function(i) {
      sim <- stats::simulate(.model, nsim = .horizon)
      sim_df <- base::data.frame(
        x = base::as.numeric(stats::time(sim)),
        y = base::as.numeric(sim)
      )
      sim_df$n <- base::paste0("sim_", i)
      sim_df
    })

    # All simulations stacked; reused by the summaries below.
    s_long <- dplyr::bind_rows(s)

    s1 <- s_long %>%
      dplyr::group_by(x) %>%
      dplyr::summarise(p50 = stats::median(y))

    # Simulation Output
    # One reference draw supplies both start and frequency.
    ref_sim <- stats::simulate(.model, nsim = 1)
    sim_output <- s_long %>%
      tidyr::pivot_wider(names_from = n, values_from = y) %>%
      dplyr::select(-x) %>%
      stats::ts(
        start     = stats::start(ref_sim),
        frequency = stats::frequency(ref_sim)
      )

    # Make s into a tibble
    s_tbl <- purrr::map_dfr(s, tibble::as_tibble) %>%
      dplyr::group_by(n) %>%
      dplyr::mutate(id = dplyr::row_number(n)) %>%
      dplyr::ungroup()

    # Make a model time series tibble
    model_ts_tbl <- timetk::tk_tbl(.model$x, timetk_idx = TRUE)

    # future_tbl built above is reused here; it depends only on .data and
    # .horizon.
    s_joined_tbl <- s_tbl %>%
      dplyr::left_join(future_tbl, by = c("id" = "id")) %>%
      dplyr::select(value, dplyr::everything()) %>%
      dplyr::rename(index = value)

    s1_tbl <- s1 %>%
      dplyr::mutate(id = dplyr::row_number()) %>%
      dplyr::left_join(future_tbl, by = c("id" = "id")) %>%
      dplyr::select(value, dplyr::everything()) %>%
      dplyr::rename(index = value)
  }

  # ggplot object
  model_method <- if (is_forecast_arima) {
    attr(ts_a, "model_method")
  } else {
    model_extraction_helper(.fit_object = .model)
  }

  g <- ggplot2::ggplot(
    data = model_ts_tbl,
    ggplot2::aes(x = index, y = value)
  ) +
    ggplot2::geom_line() +
    ggplot2::geom_line(
      data = s1_tbl,
      ggplot2::aes(x = index, y = p50),
      color = "red",
      size = 1
    ) +
    ggplot2::geom_line(
      data = s_joined_tbl,
      ggplot2::aes(x = index, y = y, group = n),
      alpha = .alpha,
      color = .sim_color
    ) +
    ggplot2::theme_minimal() +
    ggplot2::labs(
      title = glue::glue("Model: {model_method}, Iterations: {.iterations}")
    )

  # Plotly Plot
  p <- plotly::plot_ly()

  # Iterate over the simulations actually present so indexing stays in bounds
  # in both branches.
  for (i in seq_along(s)) {
    p <- p %>%
      plotly::add_lines(
        x = s[[i]]$x,
        y = s[[i]]$y,
        line = list(color = .sim_color),
        opacity = .alpha,
        showlegend = FALSE,
        name = paste("Sim", i, sep = " ")
      )
  }

  p <- p %>% plotly::add_lines(
    x = s1$x,
    y = s1$p50,
    line = list(
      color = "#00526d",
      dash = "dash",
      width = 3
    ), name = "Median"
  )

  p <- p %>%
    plotly::add_lines(
      x = stats::time(.model$x),
      y = .model$x,
      line = list(color = "#00526d"),
      name = "Actual"
    )

  output <- list(
    plotly_plot      = p,
    ggplot           = g,
    forecast_sim     = sim_output,
    forecast_sim_tbl = s_tbl,
    time_series      = .model$x,
    input_data       = model_ts_tbl,
    sim_ts_tbl       = s_joined_tbl
  )

  # Return ----
  return(output)
}