vincentarelbundock / marginaleffects

R package to compute and plot predictions, slopes, marginal means, and comparisons (contrasts, risk ratios, odds, etc.) for over 100 classes of statistical and ML models. Conduct linear and non-linear hypothesis tests, or equivalence tests. Calculate uncertainty estimates using the delta method, bootstrapping, or simulation-based inference.
https://marginaleffects.com

Partial residual plot #194

Closed. vincentarelbundock closed this issue 1 year ago.

vincentarelbundock commented 2 years ago

https://strengejacke.github.io/ggeffects/articles/introduction_partial_residuals.html

vincentarelbundock commented 2 years ago

From Twitter:

That was my contribution to {ggeffects}

A version of that function can be found here (I've gone ahead and added support for marginaleffects::predictions()):

https://mattansb.github.io/MSBMisc/reference/get_data_for_grid.html

Feel free to port this!

vincentarelbundock commented 2 years ago

For now, this is out of scope. There's an infinite number of model diagnostics that could be implemented, and I want to keep the package simple and the number of functions limited by focusing on the 4 quantities of interest:

  1. predictions
  2. contrasts
  3. marginal effects
  4. marginal means

This would be rather trivial to implement, so I am still open to it, but I would need someone to convince me that it really belongs in this package.

mattansb commented 1 year ago

FYI, this is quite easy to implement: we can use ggeffects’s residualize_over_grid.data.frame() to do all the heavy lifting, so we just need a little wrapper:

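# S3 method so that ggeffects::residualize_over_grid() can dispatch on
# marginaleffects "predictions" objects. By default, the model is taken
# from the object's "model" attribute.
residualize_over_grid.predictions <- function(grid, model = attr(grid, "model"), ...) {
  # Grid variables, as stored by marginaleffects in an (internal) attribute
  cond_vars <- attr(grid, "newdata_variables_datagrid")

  # Hand a plain data frame (grid variables + predictions) to ggeffects
  ggeffects::residualize_over_grid(
    grid = as.data.frame(grid[c(cond_vars, "estimate")]),
    model = model,
    pred_name = "estimate"
  )
}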
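Assuming the heavy lifting works along the lines of the `.is_grid()` and `match_grid()` helpers posted later in this thread, it boils down to two steps: check that the prediction grid is complete, then snap each observed value to its nearest grid value. The base-R sketch below illustrates both steps; `is_full_grid()` and `snap_to_grid()` are hypothetical names, and this is not ggeffects' actual implementation.

```r
# Hypothetical helpers (not part of ggeffects or marginaleffects).

# TRUE if `df` has exactly one row for every combination of its
# columns' unique values, i.e. a complete, duplicate-free grid.
is_full_grid <- function(df) {
  n_combos <- prod(vapply(df, function(x) length(unique(x)), integer(1)))
  nrow(df) == n_combos && !anyDuplicated(df)
}

# Replace each value of `x` by the nearest value in `grid_values`
# (ties go to the first grid value, as with which.min()).
snap_to_grid <- function(x, grid_values) {
  vapply(x, function(v) grid_values[which.min(abs(v - grid_values))],
         FUN.VALUE = grid_values[1])
}

g <- expand.grid(cyl = c(4, 6, 8), am = c(0, 1))
is_full_grid(g)        # TRUE
is_full_grid(g[-1, ])  # FALSE: one combination is missing
snap_to_grid(c(52, 110, 335), grid_values = c(100, 200, 300))  # 100 100 300
```

Snapping is what lets each observation's residual be attached to a grid prediction, even when the grid only contains a few values of a numeric predictor (as in Example 2 below).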

Now prediction grids can be used with a mix of predictions() + plot_predictions() + some ggplot code:

library(marginaleffects)
library(ggeffects)

library(ggplot2)
library(patchwork)

# Linear models -----------------------------------------------------------

mod <- lm(mpg ~ hp + am * factor(cyl),
          mtcars)

## Example 1 -----

grid <- predictions(mod, newdata = datagrid(hp = unique)) |> 
  ggeffects::residualize_over_grid()

(p1 <- plot_predictions(mod, condition = "hp") + 
    geom_point(aes(hp, estimate), data = grid))


# Compare to ggeffects:
p1_gge <- ggpredict(mod, c("hp"), condition = c(cyl = 8)) |> 
  plot(residuals = TRUE)

patchwork::wrap_plots(p1, p1_gge)


## Example 2 -----
# discretize a numeric

grid <- predictions(mod, newdata = datagrid(cyl = unique, hp = c(100, 200))) |> 
  ggeffects::residualize_over_grid()

(p2 <- plot_predictions(mod, condition = list("cyl", "hp" = c(100, 200))) + 
    geom_point(aes(factor(cyl), estimate, color = factor(hp)), data = grid,
               position = position_dodge(0.4)))


# Compare to ggeffects:
p2_gge <- ggpredict(mod, c("cyl", "hp [100,200]"), condition = c(cyl = 8)) |> 
  plot(residuals = TRUE)

patchwork::wrap_plots(p2, p2_gge)


# GLM ---------------------------------------------------------------------

mod <- glm(mpg ~ hp + am * factor(cyl),
           family = poisson(),
           mtcars)
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 22.800000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 21.400000
# (26 further warnings of this kind omitted: one per observation whose
#  mpg value is not an integer)

grid <- predictions(mod, newdata = datagrid(hp = unique)) |> 
  residualize_over_grid()

(p3 <- plot_predictions(mod, condition = "hp") + 
    geom_point(aes(hp, estimate), data = grid))
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf


p3_gge <- ggpredict(mod, c("hp"), condition = c(cyl = 8)) |> 
  plot(residuals = TRUE)

patchwork::wrap_plots(p3, p3_gge) & coord_cartesian(ylim = c(5, 50))
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf

Created on 2023-02-10 with reprex v2.0.2
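For a linear model, a partial residual is simply the prediction at a reference point plus that observation's residual. The base-R sketch below computes partial residuals for `hp` by hand, without ggeffects or marginaleffects. The reference values `am = 0` and `cyl = 8` are illustrative choices, not the values `datagrid()` would pick.

```r
# Minimal base-R sketch of partial residuals for a linear model.
fit <- lm(mpg ~ hp + am * factor(cyl), data = mtcars)

# One reference row per observation: its own hp, other covariates fixed
# at illustrative values (am = 0, cyl = 8)
ref <- data.frame(hp = mtcars$hp, am = 0, cyl = 8)

# Partial residual = prediction at the reference row + the observation's residual
pred_ref <- unname(predict(fit, newdata = ref))
partial_resid <- pred_ref + unname(residuals(fit))

# Sanity check: with each observation's own covariates as the reference,
# prediction + residual reproduces the observed outcome
own <- unname(predict(fit, newdata = mtcars) + residuals(fit))
all.equal(own, mtcars$mpg)  # TRUE

# Partial residuals against hp, with the conditional prediction line
plot(mtcars$hp, partial_resid, xlab = "hp", ylab = "partial residual (mpg)")
ord <- order(ref$hp)
lines(ref$hp[ord], pred_ref[ord])
```

Because least-squares residuals are orthogonal to `hp` (and to the intercept), a straight line fitted through these points has a slope equal to `coef(fit)["hp"]`, which is why departures from that line are informative about misspecification.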

vincentarelbundock commented 1 year ago

Oh, this is cool!

Reopening to make sure I look at this in detail later.

mattansb commented 1 year ago

Okay, so I rewrote some functions to make them faster and work better with marginaleffects. Hope this helps!

library(marginaleffects)
library(ggplot2)

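# TRUE if `df` holds exactly one row for every combination of its
# columns' unique values (i.e. it is a complete grid)
.is_grid <- function(df) {
  unq <- lapply(df, unique)
  # A complete grid has as many rows as there are combinations
  if (prod(sapply(unq, length)) != nrow(df)) {
    return(FALSE)
  }
  # Full join against all combinations; duplicated rows in `df` would
  # inflate the number of matched rows
  df2 <- do.call(expand.grid, args = unq)
  df2$..1 <- 1
  res <- merge(df, df2, by = colnames(df), all = TRUE)
  return(sum(res$..1) == sum(df2$..1))
}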

match_grid <- function(grid, data) {
  # 1. Duplicate grid
  grid2 <- grid
  # 2. Remove from duplicate fixed columns and columns not in data
  is_fixed <- sapply(grid2, insight::has_single_value)
  grid2 <- grid2[!is_fixed]
  grid2 <- grid2[intersect(colnames(grid2), colnames(data))]
  # 3. Test if grid
  stopifnot(.is_grid(grid2))

  # 4. For each col in grid, get the unique values
  unqs <- lapply(grid2, unique)
  to_char <- sapply(unqs, function(x) is.character(x) || is.factor(x) || is.logical(x))
  unqs[to_char] <- lapply(unqs[to_char], as.character)

  # 5. For each row in data:
  for (i in seq_len(nrow(data))) {
    for (j in names(unqs)) {
      #   For each column
      if (is.factor(unqs[[j]]) || is.logical(unqs[[j]]) || is.character(unqs[[j]])) {
        # If (logical, factor, char), match exactly. If not matched, return NA.
        jidx <- which(as.character(data[[j]][i]) == as.character(unqs[[j]]))
        if (length(jidx) == 0L) {
          data[[j]][i] <- NA
          next
        }
      } else {
        # If numeric, find closest value in grid values
        jidx <- which.min(abs(data[[j]][i] - unqs[[j]]))
      }
      data[[j]][i] <- unqs[[j]][jidx]
    }
  }

  # 6. Add back columns from (2).
  fixed_cols <- names(is_fixed)[is_fixed]
  data[, fixed_cols] <- grid[1, fixed_cols]

  return(data)
}

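residualize_over_grid2 <- function(grid, model, ...) {
  # Working residuals live on the link scale
  e <- residuals(model, type = "working")
  inv <- insight::link_inverse(model)
  mi <- insight::model_info(model)
  # Predict on the response scale for identity links, otherwise on the link scale
  type <- if (isTRUE(mi$link_function == "identity")) "response" else "link"

  # Snap each observation in the model data to its nearest grid point
  data <- insight::get_data(model)
  grid <- match_grid(grid, data)

  # Predict at the matched points with the supplied model
  grid2 <- marginaleffects::predictions(model, newdata = grid, type = type, ...)
  grid2[c("type", "std.error", "statistic", "p.value", "conf.low", "conf.high")] <- NULL

  # Partial residuals: prediction + working residual, back-transformed
  grid2[["estimate"]] <- inv(grid2[["estimate"]] + e)

  grid2
}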

# Linear models -----------------------------------------------------------

mod <- lm(mpg ~ hp + am * factor(cyl),
          mtcars)

## Example 1 -----

grid <- datagrid(model = mod, hp = unique) |> 
  residualize_over_grid2(mod)

(p1 <- plot_predictions(mod, condition = "hp") + 
    geom_point(aes(hp, estimate), data = grid))


# Compare to ggeffects:
p1_gge <- ggeffects::ggpredict(mod, c("hp"), condition = c(cyl = 8)) |> 
  plot(residuals = TRUE)

patchwork::wrap_plots(p1, p1_gge)


## Example 2 -----
# discretize a numeric

grid <- datagrid(model = mod, cyl = unique, hp = c(100, 200)) |> 
  residualize_over_grid2(mod)

(p2 <- plot_predictions(mod, condition = list("cyl", "hp" = c(100, 200))) + 
    geom_point(aes(factor(cyl), estimate, color = factor(hp)), data = grid,
               position = position_dodge(0.4)))


# Compare to ggeffects:
p2_gge <- ggeffects::ggpredict(mod, c("cyl", "hp [100,200]"), condition = c(cyl = 8)) |> 
  plot(residuals = TRUE)

patchwork::wrap_plots(p2, p2_gge)


# GLM ---------------------------------------------------------------------

mod <- glm(mpg ~ hp + am * factor(cyl),
           family = poisson(),
           mtcars)
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 22.800000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 21.400000
# (26 further warnings of this kind omitted: one per observation whose
#  mpg value is not an integer)

grid <- datagrid(model = mod, hp = unique) |> 
  residualize_over_grid2(mod)

(p3 <- plot_predictions(mod, condition = "hp") + 
    geom_point(aes(hp, estimate), data = grid))
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf


# Compare to ggeffects:
p3_gge <- ggeffects::ggpredict(mod, c("hp"), condition = c(cyl = 8)) |> 
  plot(residuals = TRUE)

patchwork::wrap_plots(p3, p3_gge) & coord_cartesian(ylim = c(5, 50))
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf


# LMM ---------------------------------------------------------------------

set.seed(3)  # for a reproducible subsample
stroop <- afex::stroop |> 
  subset(study == 1 & acc == 1 & pno %in% sample(unique(pno[study == 1]), size = 10))

mod <- lme4::lmer(rt ~ condition + (condition | pno),
                  data = stroop)

grid2 <- datagrid(model = mod, condition = unique, pno = unique) |> 
  residualize_over_grid2(mod)

pj <- position_jitter(0.2, 0, seed = 3)
(p4 <- plot_predictions(mod, condition = "condition") + 
    geom_point(aes(condition, estimate), data = grid2, position = pj))


# Compare to effects
eff <- effects::predictorEffect("condition", mod, partial.residuals = TRUE)
plot(eff)

Created on 2023-02-12 with reprex v2.0.2
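The key GLM step in `residualize_over_grid2()` is `inv(estimate + e)`: add the working residual to the link-scale prediction, then apply the inverse link. The base-R sketch below reproduces that step by hand. It swaps `quasipoisson()` in for the thread's `poisson()` purely to avoid the non-integer warnings (the coefficient estimates are the same), and the reference values `am = 0`, `cyl = 8` are illustrative assumptions.

```r
# Minimal base-R sketch of GLM partial residuals on the link scale.
fit <- glm(mpg ~ hp + am * factor(cyl), family = quasipoisson(), data = mtcars)

mu <- fitted(fit)                       # fitted means (response scale)
w  <- residuals(fit, type = "working")  # (y - mu) / (d mu / d eta)

# For the log link d mu / d eta = mu, so working residuals are (y - mu) / mu
all.equal(unname(w), unname((mtcars$mpg - mu) / mu))  # TRUE

# Link-scale predictions at illustrative reference values
ref <- data.frame(hp = mtcars$hp, am = 0, cyl = 8)
eta_ref <- predict(fit, newdata = ref, type = "link")

# Add the working residual on the link scale, then apply the inverse link
linkinv <- family(fit)$linkinv
partial_resid <- unname(linkinv(eta_ref + w))
```

Unlike the linear case, back-transforming the link-scale prediction plus the working residual does not return the observed outcome exactly, even at an observation's own covariates, because the working residual is a first-order (linearized) approximation.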

vincentarelbundock commented 1 year ago

@mattansb I played with this code a bit and it is super cool. Thanks for taking the time.

I thought about it for a while, and I think that at this point I've half-concluded that this is still out of scope.

My overarching design goal is to keep the number of arguments and functions limited, but to make sure that they work for all supported models. This is important because I want to avoid scope creep, make sure I have enough free time to polish and maintain the full feature set, and because I'd like to have as consistent an interface as possible across models.

Partial residual plots are obviously super useful and important for lm and glm models, but it's not clear to me how one would generalize the functions above to support all 70+ model types.

Also, I've intentionally kept the number of dependencies quite small so that people can write wrappers with enhanced functionality. Daniel has already started to do that by integrating some of the delta method hypothesis testing in ggeffects.

Anyway, just wanted to write this to make sure you knew I thought about this seriously and that I really appreciate the time you took to write it all up.

(Maybe this could be in a vignette somewhere?)

mattansb commented 1 year ago

@vincentarelbundock Thanks for the update and the detailed response.

(Just BTW - the function I wrote should already work for any model supported by {marginaleffects} and {insight} that provides a residuals() method.)
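A rough way to check the condition above for a given model class is to look for a class-specific `residuals()` method. The sketch below uses a hypothetical helper name and base R only. It is a heuristic, not a guarantee: `stats` also provides `residuals.default()`, which just extracts `object$residuals`, and methods from packages are found only when those packages are loaded.

```r
# Hypothetical helper: does any class of `model` have its own residuals() method?
has_residuals_method <- function(model) {
  any(vapply(class(model), function(cl) {
    !is.null(utils::getS3method("residuals", cl, optional = TRUE))
  }, logical(1)))
}

has_residuals_method(lm(mpg ~ hp, data = mtcars))            # TRUE
has_residuals_method(structure(list(), class = "my_model"))  # FALSE
```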

I understand where you are coming from - I guess I can't expect {marginaleffects} to literally do everything (:

I think you leave me no choice but to write a "partial residual plots with {marginaleffects}" blog post!

vincentarelbundock commented 1 year ago

> I think you leave me no choice but to write a "partial residual plots with {marginaleffects}" blog post!

Yesssss!