quantifish / rlsd

The LSD R package
Other
1 stars 0 forks source link

Add posterior predictive distribution to LF plots #33

Closed quantifish closed 5 months ago

quantifish commented 6 months ago

Currently our LF plots show the data (points), the posterior median (line), and a credible interval (CI, shaded region) that can barely be seen. It would be better to show two shaded regions: the 95% CI of the predicted proportions, and the posterior predictive distribution. The posterior predictive distribution is easy to compute in R: for each posterior sample, draw one multinomial sample (e.g. with `rmultinom()`) using that sample's predicted proportions and the observation's sample size N, then convert the counts back to proportions. Below is some code where I do this for the CCSBT age-frequency data. The same concept could be rolled out to CPUE and other data types too.

#' Plot age-frequency (AF) fits by year for one fishery
#'
#' Draws observed AF proportions (red points) and the model's predicted
#' proportions from `object$report()` (dashed line), faceted by year. When
#' `posterior` is supplied, three layers are drawn underneath the data:
#' - a ribbon for the posterior predictive interval, formed by one multinomial
#'   draw of size `N` per posterior sample, rescaled to proportions;
#' - a ribbon for the credible interval of the predicted proportions;
#' - the posterior median of the predicted proportions (solid line).
#'
#' @param data Data list; uses `af_year`, `first_yr`, `af_fishery`, `af_n`,
#'   `af_min_age`, `af_max_age` and `af_obs`.
#' @param object Model object whose `report()` returns a list containing
#'   `af_pred`, a matrix with one row per AF record and one column per age
#'   (first column = age 0).
#' @param posterior Optional posterior passed to `get_posterior()`; `NULL`
#'   skips the posterior layers.
#' @param probs Length-2 vector of lower/upper quantiles used for both ribbons.
#' @param years Optional vector of years to plot; `NULL` plots all years.
#' @param fishery Fishery to plot, `"Indonesian"` or `"Australian"`.
#' @param ... Passed to `facet_wrap()`.
#' @return A ggplot object.
plot_af <- function(data, object, posterior = NULL, probs = c(0.025, 0.975),
                    years = NULL, fishery = "Indonesian", ...) {

  # Point estimate of the predicted proportions; its dimensions also define
  # how posterior element ids map onto (record, age) below.
  af_pred <- object$report()$af_pred
  n_af <- nrow(af_pred)
  ages <- seq_len(ncol(af_pred)) - 1

  # One row per AF record. NOTE(review): assumes af_fishery codes 5 and 6
  # denote the Indonesian and Australian fisheries respectively — confirm.
  specs <- data.frame(Year = data$af_year + data$first_yr,
                      Fishery = c("Indonesian", "Australian")[data$af_fishery - 4],
                      N = data$af_n, min = data$af_min_age, max = data$af_max_age) %>%
    mutate(id = seq_len(n()))

  # NOTE(review): observed ages are parsed from af_obs column names (made
  # syntactic by data.frame(), e.g. "X0"), whereas predicted ages below are
  # column index - 1. These only line up if af_obs carries 0-based age
  # column names — confirm.
  obs <- cbind(specs, data$af_obs) %>%
    data.frame() %>%
    pivot_longer(cols = starts_with("X"), names_to = "Age", values_to = "obs") %>%
    mutate(Age = parse_number(Age))

  pred <- cbind(specs, af_pred) %>%
    data.frame() %>%
    pivot_longer(cols = starts_with("X"), names_to = "Age", values_to = "pred") %>%
    mutate(Age = parse_number(Age) - 1)

  df <- full_join(obs, pred, by = join_by("Year", "Fishery", "N", "min", "max", "Age", "id")) %>%
    filter(Fishery == fishery, Age >= min, Age <= max)

  if (!is.null(years)) df <- df %>% filter(Year %in% years)

  p <- ggplot(data = df, aes(x = .data$Age, y = .data$obs))

  # Posterior layers are added before the data layers so the ribbons sit
  # underneath the observed points and the dashed prediction.
  if (!is.null(posterior)) {
    df0 <- get_posterior(object = object, posterior = posterior, pars = "af_pred") %>%
      # `id` indexes af_pred in column-major order: recover age, then row.
      mutate(Age = ages[(id - 1) %/% n_af + 1],
             id = (id - 1) %% n_af + 1) %>%
      left_join(specs, by = join_by("id")) %>%
      select(-id, -output) %>%
      rename(pred = value) %>%
      filter(Fishery == fishery)

    # Apply the same year filter as `df`; otherwise these layers would add
    # facets for excluded years.
    if (!is.null(years)) df0 <- df0 %>% filter(Year %in% years)

    # One row per posterior sample and record, one column per age.
    df1 <- df0 %>% pivot_wider(names_from = Age, values_from = pred)
    prob <- as.matrix(df1 %>% select(all_of(as.character(ages))))

    # Posterior predictive: one multinomial draw of size N per row, using
    # that row's predicted proportions, rescaled back to proportions.
    counts <- vapply(seq_len(nrow(prob)), function(i) {
      c(rmultinom(1, size = df1$N[i], prob = prob[i, ]))
    }, numeric(length(ages)))
    df_ppred <- t(counts)
    df_ppred <- df_ppred / rowSums(df_ppred)
    colnames(df_ppred) <- ages

    dfpp <- cbind(df1 %>% select(chain, iter, Year, Fishery, N, min, max), df_ppred) %>%
      pivot_longer(cols = !chain:max, names_to = "Age", values_to = "ppred") %>%
      mutate(Age = as.numeric(Age))

    df_mcmc <- full_join(df0, dfpp, by = join_by("chain", "iter", "Age", "Year", "Fishery", "N", "min", "max")) %>%
      filter(Age >= min, Age <= max)

    q_lower <- function(x) quantile(x, probs = probs[1])
    q_upper <- function(x) quantile(x, probs = probs[2])

    p <- p +
      stat_summary(data = df_mcmc, aes(y = .data$ppred), geom = "ribbon",
                   alpha = 0.5, fun.min = q_lower, fun.max = q_upper) +
      stat_summary(data = df_mcmc, aes(y = .data$pred), geom = "ribbon",
                   alpha = 0.5, fun.min = q_lower, fun.max = q_upper) +
      stat_summary(data = df_mcmc, aes(y = .data$pred), geom = "line", fun = median)
  }

  p +
    geom_point(colour = "red") +
    geom_line(aes(y = .data$pred), linetype = "dashed") +
    labs(x = "Age", y = "Proportion") +
    facet_wrap(Year ~ ., ...) +
    scale_x_continuous(breaks = pretty_breaks()) +
    scale_y_continuous(limits = c(0, NA), expand = expansion(mult = c(0, 0.05)))
}

lf_1